本文实例为大家分享了基于信息增益的决策树归纳的Python实现代码,供大家参考,具体内容如下
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
|
# -*- coding: utf-8 -*- import numpy as np import matplotlib.mlab as mlab import matplotlib.pyplot as plt from copy import copy #加载训练数据 #文件格式:属性标号,是否连续【yes|no】,属性说明 attribute_file_dest = 'F:\\bayes_categorize\\attribute.dat' attribute_file = open (attribute_file_dest) #文件格式:rec_id,attr1_value,attr2_value,...,attrn_value,class_id trainning_data_file_dest = 'F:\\bayes_categorize\\trainning_data.dat' trainning_data_file = open (trainning_data_file_dest) #文件格式:class_id,class_desc class_desc_file_dest = 'F:\\bayes_categorize\\class_desc.dat' class_desc_file = open (class_desc_file_dest) root_attr_dict = {} for line in attribute_file : line = line.strip() fld_list = line.split( ',' ) root_attr_dict[ int (fld_list[ 0 ])] = tuple (fld_list[ 1 :]) class_dict = {} for line in class_desc_file : line = line.strip() fld_list = line.split( ',' ) class_dict[ int (fld_list[ 0 ])] = fld_list[ 1 ] trainning_data_dict = {} class_member_set_dict = {} for line in trainning_data_file : line = line.strip() fld_list = line.split( ',' ) rec_id = int (fld_list[ 0 ]) a1 = int (fld_list[ 1 ]) a2 = int (fld_list[ 2 ]) a3 = float (fld_list[ 3 ]) c_id = int (fld_list[ 4 ]) if c_id not in class_member_set_dict : class_member_set_dict[c_id] = set () class_member_set_dict[c_id].add(rec_id) trainning_data_dict[rec_id] = (a1 , a2 , a3 , c_id) attribute_file.close() class_desc_file.close() trainning_data_file.close() class_possibility_dict = {} for c_id in class_member_set_dict : class_possibility_dict[c_id] = ( len (class_member_set_dict[c_id]) + 0.0 ) / len (trainning_data_dict) #等待分类的数据 data_to_classify_file_dest = 'F:\\bayes_categorize\\trainning_data_new.dat' data_to_classify_file = open (data_to_classify_file_dest) data_to_classify_dict = {} for line in data_to_classify_file : line = line.strip() fld_list = line.split( ',' ) rec_id = int (fld_list[ 0 ]) a1 = int (fld_list[ 1 ]) a2 = int (fld_list[ 2 ]) a3 = float (fld_list[ 3 ]) c_id = int (fld_list[ 4 ]) data_to_classify_dict[rec_id] = (a1 , a2 , a3 , c_id) data_to_classify_file.close() ''' 决策树的表达 结点的需求: 1、指示出是哪一种分区 一共3种 一是离散穷举 二是连续有分裂点 三是离散有判别集合 零是叶子结点 2、保存分类所需信息 3、子结点列表 每个结点用Tuple类型表示 元素一是整形,取值123 分别对应两种分裂类型 元素二是集合类型 对于1保存所有的离散值 对于2保存分裂点 对于3保存判别集合 对于0保存分类结果类标号 元素三是dict key对于1来说是某个的离散值 对于23来说只有12两种 对于2来说1代表小于等于分裂点 对于3来说1代表属于判别集合 ''' #对于一个成员列表,计算其熵 #公式为 Info_D = - sum(pi * log2 (pi)) pi为一个元素属于Ci的概率,用|Ci|/|D|计算 ,对所有分类求和 def get_entropy( member_list ) : #成员总数 mem_cnt = len (member_list) #首先找出member中所包含的分类 class_dict = {} for mem_id in member_list : c_id = trainning_data_dict[mem_id][ 3 ] if c_id not in class_dict : class_dict[c_id] = set () class_dict[c_id].add(mem_id) tmp_sum = 0.0 for c_id in class_dict : pi = ( len (class_dict[c_id]) + 0.0 ) / mem_cnt tmp_sum + = pi * mlab.log2(pi) tmp_sum = - tmp_sum return tmp_sum def attribute_selection_method( member_list , attribute_dict ) : #先计算原始的熵 info_D = get_entropy(member_list) max_info_Gain = 0.0 attr_get = 0 split_point = 0.0 for attr_id in attribute_dict : #对于每一个属性计算划分后的熵 #信息增益等于原始的熵减去划分后的熵 info_D_new = 0 #如果是连续属性 if attribute_dict[attr_id][ 0 ] = = 'yes' : #先得到memberlist中此属性的取值序列,把序列中每一对相邻项的中值作为划分点计算熵 #找出其中最小的,作为此连续属性的划分点 value_list = [] for mem_id in member_list : value_list.append(trainning_data_dict[mem_id][attr_id - 1 ]) #获取相邻元素的中值序列 mid_value_list = [] value_list.sort() #print value_list last_value = None for value in value_list : if value = = last_value : continue if last_value is not None : mid_value_list.append((last_value + value) / 2 ) last_value = value #print mid_value_list #对于中值序列做循环 #计算以此值做为划分点的熵 #总的熵等于两个划分的熵乘以两个划分的比重 min_info = 1000000000.0 total_mens = len (member_list) + 0.0 for mid_value in mid_value_list : #小于mid_value的mem less_list = [] #大于 more_list = [] for tmp_mem_id in member_list : if trainning_data_dict[tmp_mem_id][attr_id - 1 ] < = mid_value : less_list.append(tmp_mem_id) else : more_list.append(tmp_mem_id) sum_info = len (less_list) / total_mens * get_entropy(less_list) \ + len (more_list) / total_mens * get_entropy(more_list) if sum_info < min_info : min_info = sum_info split_point = mid_value info_D_new = min_info #如果是离散属性 else : #计算划分后的熵 #采用循环累加的方式 attr_value_member_dict = {} #键为attribute value , 值为memberlist for tmp_mem_id in member_list : attr_value = trainning_data_dict[tmp_mem_id][attr_id - 1 ] if attr_value not in attr_value_member_dict : attr_value_member_dict[attr_value] = [] attr_value_member_dict[attr_value].append(tmp_mem_id) #将每个离散值的熵乘以比重加到这上面 total_mens = len (member_list) + 0.0 sum_info = 0.0 for a_value in attr_value_member_dict : sum_info + = len (attr_value_member_dict[a_value]) / total_mens \ * get_entropy(attr_value_member_dict[a_value]) info_D_new = sum_info info_Gain = info_D - info_D_new if info_Gain > max_info_Gain : max_info_Gain = info_Gain attr_get = attr_id #如果是离散的 #print 'attr_get ' + str(attr_get) if attribute_dict[attr_get][ 0 ] = = 'no' : return ( 1 , attr_get , split_point) else : return ( 2 , attr_get , split_point) #第三类先不考虑 def get_decision_tree(father_node , key , member_list , attr_dict ) : #最终的结果是新建一个结点,并且添加到father_node的sub_node_dict,对key为键 #检查memberlist 如果都是同类的,则生成一个叶子结点,set里面保存类标号 class_set = set () for mem_id in member_list : class_set.add(trainning_data_dict[mem_id][ 3 ]) if len (class_set) = = 1 : father_node[ 2 ][key] = ( 0 , ( 1 , class_set) , {} ) return #检查attribute_list,如果为空,产生叶子结点,类标号为memberlist中多数元素的类标号 #如果几个类的成员等量,则打印提示,并且全部添加到set里面 if not attr_dict : class_cnt_dict = {} for mem_id in member_list : c_id = trainning_data_dict[mem_id][ 3 ] if c_id not in class_cnt_dict : class_cnt_dict[c_id] = 1 else : class_cnt_dict[c_id] + = 1 class_set = set () max_cnt = 0 for c_id in class_cnt_dict : if class_cnt_dict[c_id] > max_cnt : max_cnt = class_cnt_dict[c_id] class_set.clear() class_set.add(c_id) elif class_cnt_dict[c_id] = = max_cnt : class_set.add(c_id) if len (class_set) > 1 : print 'more than one class !' father_node[ 2 ][key] = ( 0 , ( 1 , class_set ) , {} ) return #找出最好的分区方案 , 暂不考虑第三种划分方法 #比较所有离散属性和所有连续属性的所有中值点划分的信息增益 split_criterion = attribute_selection_method(member_list , attr_dict) #print split_criterion selected_plan_id = split_criterion[ 0 ] selected_attr_id = split_criterion[ 1 ] #如果采用的是离散属性做为分区方案,删除这个属性 new_attr_dict = copy(attr_dict) if attr_dict[selected_attr_id][ 0 ] = = 'no' : del new_attr_dict[selected_attr_id] #建立一个结点new_node,father_node[2][key] = new_node #然后对new node的每一个key , sub_member_list, #调用 get_decision_tree(new_node , new_key , sub_member_list , new_attribute_dict) #实现递归 ele2 = ( selected_attr_id , set () ) #如果是1 , ele2保存所有离散值 if selected_plan_id = = 1 : for mem_id in member_list : ele2[ 1 ].add(trainning_data_dict[mem_id][selected_attr_id - 1 ]) #如果是2,ele2保存分裂点 elif selected_plan_id = = 2 : ele2[ 1 ].add(split_criterion[ 2 ]) #如果是3则保存判别集合,先不管 else : print 'not completed' pass new_node = ( selected_plan_id , ele2 , {} ) father_node[ 2 ][key] = new_node #生成KEY,并递归调用 if selected_plan_id = = 1 : #每个attr_value是一个key attr_value_member_dict = {} for mem_id in member_list : attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ] if attr_value not in attr_value_member_dict : attr_value_member_dict[attr_value] = [] attr_value_member_dict[attr_value].append(mem_id) for attr_value in attr_value_member_dict : get_decision_tree(new_node , attr_value , attr_value_member_dict[attr_value] , new_attr_dict) pass elif selected_plan_id = = 2 : #key 只有12 , 小于等于分裂点的是1 , 大于的是2 less_list = [] more_list = [] for mem_id in member_list : attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ] if attr_value < = split_criterion[ 2 ] : less_list.append(mem_id) else : more_list.append(mem_id) #if len(less_list) != 0 : get_decision_tree(new_node , 1 , less_list , new_attr_dict) #if len(more_list) != 0 : get_decision_tree(new_node , 2 , more_list , new_attr_dict) pass #如果是3则保存判别集合,先不管 else : print 'not completed' pass def get_class_sub(node , tp ) : # attr_id = node[ 1 ][ 0 ] plan_id = node[ 0 ] key = 0 if plan_id = = 0 : return node[ 1 ][ 1 ] elif plan_id = = 1 : key = tp[attr_id - 1 ] elif plan_id = = 2 : split_point = tuple (node[ 1 ][ 1 ])[ 0 ] attr_value = tp[attr_id - 1 ] if attr_value < = split_point : key = 1 else : key = 2 else : print 'error' return set () return get_class_sub(node[ 2 ][key] , tp ) def get_class(r_node , tp) : #tp为一组属性值 if r_node[ 0 ] ! = - 1 : print 'error' return set () if 1 in r_node[ 2 ] : return get_class_sub(r_node[ 2 ][ 1 ] , tp) else : print 'error' return set () if __name__ = = '__main__' : root_node = ( - 1 , set () , {} ) mem_list = trainning_data_dict.keys() get_decision_tree(root_node , 1 , mem_list , root_attr_dict ) #测试分类器的准确率 diff_cnt = 0 for mem_id in data_to_classify_dict : c_id = get_class(root_node , data_to_classify_dict[mem_id][ 0 : 3 ]) if tuple (c_id)[ 0 ] ! = data_to_classify_dict[mem_id][ 3 ] : print tuple (c_id)[ 0 ] print data_to_classify_dict[mem_id][ 3 ] print 'different' diff_cnt + = 1 print diff_cnt |
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/conggova/article/details/77528966