决策树实验

Python应用 专栏收录该内容
3 篇文章 0 订阅

题目

Kaggle的一个比赛题目,被拿来当日常作业了。

原题目链接

做法

  1. 先把数据集中的各个取值转换成数字编号(从0开始编号),这样便于处理。
  2. (为了简化问题完成作业)只选择一部分取值较好处理的变量做自变量。
  3. 训练ID3决策树,并记录。
  4. 采用ID3决策树进行预测。

实现细节

  1. train和test中相同元素转化的数字一定要一样。
  2. 训练的时候元素取值可能不全面,需要补充默认值。
  3. convert.py负责处理数据,ID3.py负责跑决策树。

代码

"""
convert.py
"""
# Convert the raw CSV data to enumerated (integer-coded) data.
# DEBUG gates the diagnostic printing done by DEBUG_OUTPUT.
DEBUG = False
# DEV is not referenced anywhere in the visible code of this script.
DEV = False

def open_data(path):
    """Read the text file at *path* and return its lines, stripped.

    Leading and trailing whitespace (including the line terminator) is
    removed from every line; blank lines are kept as empty strings.
    """
    with open(path) as handle:
        return [raw_line.strip() for raw_line in handle.readlines()]

def split_data(origin_data):
    """Split raw CSV lines into a header list and a list of data rows.

    The first line is the header; every following line is one record.
    Each record is cut on commas, and then pieces 3 and 4 are glued back
    together with a comma so that a fourth column that itself contains one
    comma ends up as a single cell again.

    Returns a ``(header, data)`` tuple where ``header`` is a list of column
    names and ``data`` is a list of rows (lists of strings).
    """
    header = origin_data[0].split(',')

    data = []
    for record in origin_data[1:]:
        cells = record.split(',')
        # Undo the split that happened inside the fourth column.
        rejoined = cells[3] + ',' + cells[4]
        data.append(cells[:3] + [rejoined] + cells[5:])
    return header, data

# Per-column lookup tables shared by every conversion call. Entry i maps the
# values seen in column i to their integer code; the extra key 'count' holds
# the highest code handed out so far and '' (missing value) is always -1.
occure_to_id = None
def convert_data(data, header=None):
    """Encode every cell of *data* as a number and return the encoded rows.

    Each column has its own table in the module-level ``occure_to_id``; a
    value seen for the first time in a column receives the next free code,
    starting from 0, and an empty cell is coded as -1. The tables are
    created on the first call and reused (and extended) by later calls, so
    the same value in the same column position always gets the same code.

    In the active AGE_SPLIT_MODE the 'Age' column is handled differently:
    an empty age is treated as '0' and the age is replaced by its 20-year
    bucket, ``int(float(age) // 20)``.

    Args:
        data: list of rows, each a list of strings.
        header: list of column names; required in AGE_SPLIT_MODE because it
            is used to locate the 'Age' column.

    Returns:
        A new list of rows with every cell replaced by its integer code.
        An empty *data* yields an empty list.

    Raises:
        ValueError: if *header* is missing in AGE_SPLIT_MODE, if it has no
            'Age' entry, or if an age cell is not a number.
    """
    global occure_to_id
    ORIGIN_MODE = 0
    AGE_SPLIT_MODE = 1
    AVERAGE_MODE = 2

    mode = AGE_SPLIT_MODE

    # Only AGE_SPLIT_MODE treats a column specially; the other two modes
    # encode every column through the lookup tables.
    age_idx = None
    if mode == AGE_SPLIT_MODE:
        if header is None:
            raise ValueError(
                "convert_data needs `header` to locate the 'Age' column")
        age_idx = header.index('Age')

    if occure_to_id is None:
        if not data:
            # Nothing to size the tables from and nothing to encode.
            return []
        occure_to_id = [{'count': -1, '': -1} for _ in range(len(data[0]))]

    new_data = []
    for data_line in data:
        new_data_line = []
        for idx, element in enumerate(data_line):
            table = occure_to_id[idx]
            if idx == age_idx:
                # Missing ages fall into bucket 0.
                if element == '':
                    element = '0'
                element = int(float(element) // 20)
                # Buckets code to themselves; record them for the debug dump.
                table[element] = element
            elif element not in table:
                # NOTE: a cell whose text is literally 'count' collides with
                # the bookkeeping key and is never given a code of its own.
                table['count'] += 1
                table[element] = table['count']
            new_data_line.append(table[element])
        new_data.append(new_data_line)

    if age_idx is not None:
        # DEBUG the age buckets.
        DEBUG_OUTPUT(occure_to_id[age_idx])

    DEBUG_OUTPUT([x['count'] for x in occure_to_id])
    return new_data


def DEBUG_OUTPUT(data):
    """Print *data* for debugging; does nothing unless DEBUG is truthy.

    A list is printed one item per line, each prefixed by its position, and
    only the first 15 items are shown. Any other value is printed as-is.
    """
    if DEBUG:
        # Exact type match: list subclasses take the plain-print branch.
        if type(data) == list:
            for position, item in enumerate(data[:15]):
                print(position, ':', item)
        else:
            print(data)

def select_data(header, data):
    """Keep only the columns used as model inputs/target.

    Columns are kept when their header name is one of 'Survived', 'Pclass',
    'Sex', 'Age', 'SibSp', 'Parch' or 'Embarked'; the original column order
    is preserved.

    Returns a ``(use_header, use_data)`` tuple: the kept column names and the
    rows reduced to those columns.
    """
    wanted = ['Survived', 'Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Embarked']

    # Positions (and names) of the columns that survive the selection.
    kept = [(pos, name) for pos, name in enumerate(header) if name in wanted]
    positions = [pos for pos, _ in kept]
    use_header = [name for _, name in kept]

    use_data = [[row[pos] for pos in positions] for row in data]
    return use_header, use_data


def train(input_file, output_file):
    """Encode the CSV at *input_file* and write the result to *output_file*.

    The file is read, split into header and rows, encoded with
    convert_data, reduced with select_data, and written back as
    comma-separated text with the kept header as the first line.
    """
    raw_lines = open_data(input_file)
    DEBUG_OUTPUT(raw_lines)

    header, rows = split_data(raw_lines)

    # These column slices are built whether or not DEBUG is on, so every row
    # must have at least six cells and the header at least six names.
    DEBUG_OUTPUT([row[4] for row in rows])
    DEBUG_OUTPUT([row[5] for row in rows])
    DEBUG_OUTPUT(header[5])
    DEBUG_OUTPUT(header)
    DEBUG_OUTPUT(rows)

    encoded_rows = convert_data(rows, header)
    use_header, use_data = select_data(header, encoded_rows)
    DEBUG_OUTPUT(use_data)

    with open(output_file, 'w') as out:
        out.write(','.join(map(str, use_header)) + '\n')
        for row in use_data:
            out.write(','.join(map(str, row)) + '\n')

    # Read the result back purely for the debug dump.
    DEBUG_OUTPUT(open_data(output_file))
    
def freeze_convert(data, header=None):
    """Encode *data* with the existing lookup tables without extending them.

    Works like convert_data but never hands out new codes: a value that has
    no entry in its column's table is encoded as the missing-value code
    (the code stored under ''). In the active AGE_SPLIT_MODE the 'Age'
    column is replaced by its 20-year bucket, ``int(float(age) // 20)``,
    with an empty age treated as '0'.

    Args:
        data: list of rows, each a list of strings.
        header: list of column names; required in AGE_SPLIT_MODE because it
            is used to locate the 'Age' column.

    Returns:
        A new list of rows with every cell replaced by its integer code.

    Raises:
        ValueError: if *header* is missing in AGE_SPLIT_MODE, if it has no
            'Age' entry, or if an age cell is not a number.
        SystemExit: with status -1 if no lookup tables have been built yet.
    """
    global occure_to_id
    ORIGIN_MODE = 0
    AGE_SPLIT_MODE = 1
    AVERAGE_MODE = 2

    mode = AGE_SPLIT_MODE

    # Only AGE_SPLIT_MODE treats a column specially.
    age_idx = None
    if mode == AGE_SPLIT_MODE:
        if header is None:
            raise ValueError(
                "freeze_convert needs `header` to locate the 'Age' column")
        age_idx = header.index('Age')

    if occure_to_id is None:
        # sys.exit instead of the interactive-only exit() helper, and say
        # why the program stops instead of quitting silently.
        import sys
        print('freeze_convert: no lookup tables yet, run convert_data first',
              file=sys.stderr)
        sys.exit(-1)

    new_data = []
    for data_line in data:
        new_data_line = []
        for idx, element in enumerate(data_line):
            if idx == age_idx:
                # Age buckets code to themselves, so the frozen tables do
                # not need to be consulted (or modified) for this column.
                if element == '':
                    element = '0'
                new_data_line.append(int(float(element) // 20))
                continue

            table = occure_to_id[idx]
            if element not in table:
                # Unknown value: fall back to the missing-value entry.
                element = ''
            new_data_line.append(table[element])
        new_data.append(new_data_line)

    if age_idx is not None:
        # DEBUG the age buckets.
        DEBUG_OUTPUT(occure_to_id[age_idx])

    DEBUG_OUTPUT([x['count'] for x in occure_to_id])
    return new_data

def test(input_file, output_file):
    """Encode the CSV at *input_file* and write the result to *output_file*.

    Follows the same read / split / encode / select / write steps as train,
    without the intermediate debug dumps.
    """
    raw_lines = open_data(input_file)
    DEBUG_OUTPUT(raw_lines)

    header, rows = split_data(raw_lines)

    # NOTE(review): this goes through convert_data, which assigns fresh codes
    # to values it has not seen before; freeze_convert is not used here.
    # Confirm which of the two is intended for the test file.
    encoded_rows = convert_data(rows, header)
    use_header, use_data = select_data(header, encoded_rows)

    with open(output_file, 'w') as out:
        out.write(','.join(map(str, use_header)) + '\n')
        for row in use_data:
            out.write(','.join(map(str, row)) + '\n')

    # Read the result back purely for the debug dump.
    DEBUG_OUTPUT(open_data(output_file))

if __name__ == '__main__':
    # Encode the training file first: convert_data builds its per-column
    # lookup tables on the first call and reuses them for the test file.
    train('./train.csv', './train_num.csv')
    test('./test.csv', './test_num.csv')

"""
ID3.py
这个写法可能跑的很慢...有时间优化一下
"""
# ID3 decision tree: builds a tree from an integer-coded CSV and predicts
# the 'Survived' label with it.
import math
# DEBUG and DEV are not referenced in the visible code of this script.
DEBUG = False
DEV = False
# Entropy reduction a split has to exceed: ID3 turns a node into a leaf when
# the best achievable reduction is <= tolerance.
tolerance = 0.0

def open_data(path):
    """Return the lines of the text file at *path*, whitespace-stripped.

    Each line loses its leading and trailing whitespace, including the line
    terminator; blank lines come back as empty strings.
    """
    with open(path) as source:
        return [text.strip() for text in source.readlines()]

def split_data(origin_data):
    """Split CSV lines into ``(header, data)``.

    The first line becomes the list of column names; every remaining line
    becomes one row, cut on commas into a list of strings.
    """
    header = origin_data[0].split(',')
    data = [record.split(',') for record in origin_data[1:]]
    return header, data

def calc_SHANG(data, target_idx):
    """Return the Shannon entropy (in bits) of the label column of *data*.

    "SHANG" is the pinyin for entropy. Labels are read as
    ``int(row[target_idx])`` and may be any integers: the labels are counted
    per distinct value instead of being used as indexes into a fixed
    two-slot list, so a label >= 2 no longer raises IndexError and a
    negative label is no longer silently merged into another class.

    Args:
        data: list of rows; each row is indexable and holds the label at
            position *target_idx* in a form ``int()`` accepts.
        target_idx: position of the label inside each row.

    Returns:
        The entropy as a float; 0.0 when *data* is empty or holds a single
        distinct label.
    """
    from collections import Counter

    label_counts = Counter(int(row[target_idx]) for row in data)

    # Zero or one distinct label: nothing is uncertain.
    if len(label_counts) <= 1:
        return 0.0

    total = sum(label_counts.values())

    # Sum in ascending label order so the result is deterministic and, for
    # 0/1 labels, adds the terms in the same order as a [count0, count1] list.
    neg_SHANG = 0.0
    for _, occurrences in sorted(label_counts.items()):
        share = occurrences / total
        neg_SHANG += share * math.log2(share)

    return -neg_SHANG

def most_label(data, target_idx):
    """Return the most frequent label in the label column of *data*.

    Labels are read as ``int(row[target_idx])`` and may be any integers: the
    labels are counted per distinct value instead of being used as indexes
    into a fixed two-slot list, so a label >= 2 no longer raises IndexError
    and a negative label is no longer counted as another class.

    Args:
        data: list of rows; each row is indexable and holds the label at
            position *target_idx* in a form ``int()`` accepts.
        target_idx: position of the label inside each row.

    Returns:
        The label with the highest count. Ties go to the smallest label,
        and empty *data* yields 0 (both match what a ``[0, 0]`` counter
        list gives for 0/1 labels).
    """
    from collections import Counter

    label_counts = Counter(int(row[target_idx]) for row in data)
    if not label_counts:
        return 0

    highest = max(label_counts.values())
    return min(label for label, seen in label_counts.items() if seen == highest)

# Every node built by ID3 is stored here in creation order; function_count is
# the number of slots handed out so far.
function_count = 0
function_list = []
def ID3(data, target_idx, dep=0):
    """Build an ID3 decision (sub)tree for *data* and return it as a callable.

    The returned callable takes one row and returns the predicted label. The
    same callable is also stored in the module-level ``function_list`` at the
    slot reserved when this call started.

    Args:
        data: list of rows (sequences of comparable cell values).
        target_idx: position of the label column inside each row.
        dep: recursion depth; passed down as ``dep + 1`` but otherwise unused.

    Returns:
        A callable ``row -> label``. It answers -1 for an empty *data*, the
        majority label for a leaf, and otherwise dispatches on the chosen
        split column, falling back to this node's majority label for a value
        that did not occur in *data*.
    """
    global function_list
    global function_count
    global tolerance

    # Reserve this node's slot up front so parents come before children.
    slot = function_count
    function_list.append(lambda x: -1)
    function_count += 1

    majority = most_label(data, target_idx)

    def _store(node):
        # Put the finished node into its reserved slot and hand it back.
        function_list[slot] = node
        return node

    # No rows at all: answer "don't know".
    if len(data) == 0:
        return _store(lambda x: -1)

    # All rows agree on the label: this is a leaf.
    entropy_here = calc_SHANG(data, target_idx)
    if entropy_here == 0:
        return _store(lambda x: majority)

    # Weighted entropy left over after splitting on each column. The label
    # column keeps the current entropy so it can never look like a gain.
    row_total = len(data)
    entropy_after = []
    for col in range(len(data[0])):
        if col == target_idx:
            entropy_after.append(entropy_here)
            continue
        weighted = 0.0
        for value in set(row[col] for row in data):
            subset = [row for row in data if row[col] == value]
            weighted += calc_SHANG(subset, target_idx) * (len(subset) / row_total)
        entropy_after.append(weighted)

    # Stop when even the best split does not reduce entropy enough.
    best_entropy = min(entropy_after)
    if entropy_here - best_entropy <= tolerance:
        return _store(lambda x: majority)

    split_col = entropy_after.index(best_entropy)

    # One subtree per distinct value of the chosen column.
    children = {}
    for value in set(row[split_col] for row in data):
        subset = [row for row in data if row[split_col] == value]
        children[value] = ID3(subset, target_idx, dep + 1)

    return _store(
        lambda x: children.get(x[split_col], lambda x: majority)(x))


def main(input_path):
    """Train an ID3 tree on the CSV at *input_path* and report its fit.

    The file must have a 'Survived' column holding 0/1 labels. After
    training, every training row is predicted again and the counts are
    printed as a 2x2 matrix indexed ``[predicted][actual]``.

    Args:
        input_path: path of the integer-coded training CSV.

    Returns:
        The trained tree as a callable ``row -> label``.
    """
    origin_data = open_data(input_path)
    header, data = split_data(origin_data)
    target_idx = header.index('Survived')

    # Use the root returned by this call. function_list[0] is the root of
    # the first tree ever built, so it would be stale on a repeated call.
    predict_func = ID3(data, target_idx=target_idx)

    recall_matrix = [[0, 0], [0, 0]]
    for data_line in data:
        predict = predict_func(data_line)
        ans = int(data_line[target_idx])
        recall_matrix[predict][ans] += 1

    print('----------------------------------------')
    print('Recall Matrix : (tolerance: %.4f)'%(tolerance))
    print(recall_matrix[0], recall_matrix[1], sep='\n')
    print('----------------------------------------')
    return predict_func

def test(input_path, predict_func):
    """Predict a label for every data row of the CSV at *input_path*.

    Each prediction is printed together with the row's position, and all
    predictions are returned as a list in row order.
    """
    _, rows = split_data(open_data(input_path))

    targets = []
    for position, row in enumerate(rows):
        label = predict_func(row)
        targets.append(label)
        print(position, ' : ', label)
    return targets

if __name__ == '__main__':
    # Train on the encoded training file, then predict the encoded test file
    # with the resulting tree.
    predict_func = main('./train_num.csv')
    targets = test('./test_num.csv', predict_func)
    # Write one predicted label per line, in test-row order.
    with open('ans.csv', 'w') as f:
        for target in targets:
            f.write(str(target) + '\n')
  • 0
    点赞
  • 0
    评论
  • 0
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

打赏
文章很值,打赏犒劳作者一下
相关推荐
©️2020 CSDN 皮肤主题: 点我我会动 设计师:白松林 返回首页

打赏

SofanHe

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值