服务器之家

服务器之家 > 正文

python实现决策树分类

时间:2021-03-31 00:08     来源/作者:momaojia

上一篇博客主要介绍了决策树的原理,这篇主要介绍他的实现,代码环境python 3.4,实现的是id3算法,首先为了后面matplotlib的绘图方便,我把原来的中文数据集变成了英文。

原始数据集:

python实现决策树分类

变化后的数据集在程序代码中体现,这就不截图了

构建决策树的代码如下:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#coding :utf-8
'''
2017.6.25 author :erin
   function: "decesion tree" id3
   
'''
import numpy as np
import pandas as pd
from math import log
import operator
def load_data():
 
 #data=np.array(data)
 data=[['teenager' ,'high', 'no' ,'same', 'no'],
   ['teenager', 'high', 'no', 'good', 'no'],
   ['middle_aged' ,'high', 'no', 'same', 'yes'],
   ['old_aged', 'middle', 'no' ,'same', 'yes'],
   ['old_aged', 'low', 'yes', 'same' ,'yes'],
   ['old_aged', 'low', 'yes', 'good', 'no'],
   ['middle_aged', 'low' ,'yes' ,'good', 'yes'],
   ['teenager' ,'middle' ,'no', 'same', 'no'],
   ['teenager', 'low' ,'yes' ,'same', 'yes'],
   ['old_aged' ,'middle', 'yes', 'same', 'yes'],
   ['teenager' ,'middle', 'yes', 'good', 'yes'],
   ['middle_aged' ,'middle', 'no', 'good', 'yes'],
   ['middle_aged', 'high', 'yes', 'same', 'yes'],
   ['old_aged', 'middle', 'no' ,'good' ,'no']]
 features=['age','input','student','level']
 return data,features
 
def cal_entropy(dataset):
 '''
 输入data ,表示带最后标签列的数据集
 计算给定数据集总的信息熵
 {'是': 9, '否': 5}
 0.9402859586706309
 '''
 
 numentries = len(dataset)
 labelcounts = {}
 for featvec in dataset:
  label = featvec[-1]
  if label not in labelcounts.keys():
   labelcounts[label] = 0
  labelcounts[label] += 1
 entropy = 0.0
 for key in labelcounts.keys():
  p_i = float(labelcounts[key]/numentries)
  entropy -= p_i * log(p_i,2)#log(x,10)表示以10 为底的对数
 return entropy
 
def split_data(data,feature_index,value):
 '''
 划分数据集
 feature_index:用于划分特征的列数,例如“年龄”
 value:划分后的属性值:例如“青少年”
 '''
 data_split=[]#划分后的数据集
 for feature in data:
  if feature[feature_index]==value:
   refeature=feature[:feature_index]
   refeature.extend(feature[feature_index+1:])
   data_split.append(refeature)
 return data_split
def choose_best_to_split(data):
 
 '''
 根据每个特征的信息增益,选择最大的划分数据集的索引特征
 '''
 
 count_feature=len(data[0])-1#特征个数4
 #print(count_feature)#4
 entropy=cal_entropy(data)#原数据总的信息熵
 #print(entropy)#0.9402859586706309
 
 max_info_gain=0.0#信息增益最大
 split_fea_index = -1#信息增益最大,对应的索引号
 
 for i in range(count_feature):
  
  feature_list=[fe_index[i] for fe_index in data]#获取该列所有特征值
  #######################################
  '''
  print('feature_list')
  ['青少年', '青少年', '中年', '老年', '老年', '老年', '中年', '青少年', '青少年', '老年',
  '青少年', '中年', '中年', '老年']
  0.3467680694480959 #对应上篇博客中的公式 =(1)*5/14
  0.3467680694480959
  0.6935361388961918
  '''
  # print(feature_list)
  unqval=set(feature_list)#去除重复
  pro_entropy=0.0#特征的熵
  for value in unqval:#遍历改特征下的所有属性
   sub_data=split_data(data,i,value)
   pro=len(sub_data)/float(len(data))
   pro_entropy+=pro*cal_entropy(sub_data)
   #print(pro_entropy)
   
  info_gain=entropy-pro_entropy
  if(info_gain>max_info_gain):
   max_info_gain=info_gain
   split_fea_index=i
 return split_fea_index
  
  
##################################################
def most_occur_label(labels):
 #sorted_label_count[0][0] 次数最多的类标签
 label_count={}
 for label in labels:
  if label not in label_count.keys():
   label_count[label]=0
  else:
   label_count[label]+=1
  sorted_label_count = sorted(label_count.items(),key = operator.itemgetter(1),reverse = true)
 return sorted_label_count[0][0]
def build_decesion_tree(dataset,featnames):
 '''
 字典的键存放节点信息,分支及叶子节点存放值
 '''
 featname = featnames[:]    ################
 classlist = [featvec[-1] for featvec in dataset] #此节点的分类情况
 if classlist.count(classlist[0]) == len(classlist): #全部属于一类
  return classlist[0]
 if len(dataset[0]) == 1:   #分完了,没有属性了
  return vote(classlist)  #少数服从多数
 # 选择一个最优特征进行划分
 bestfeat = choose_best_to_split(dataset)
 bestfeatname = featname[bestfeat]
 del(featname[bestfeat])  #防止下标不准
 decisiontree = {bestfeatname:{}}
 # 创建分支,先找出所有属性值,即分支数
 allvalue = [vec[bestfeat] for vec in dataset]
 specvalue = sorted(list(set(allvalue))) #使有一定顺序
 for v in specvalue:
  copyfeatname = featname[:]
  decisiontree[bestfeatname][v] = build_decesion_tree(split_data(dataset,bestfeat,v),copyfeatname)
 return decisiontree

绘制可视化图的代码如下:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def getnumleafs(mytree):
 '计算决策树的叶子数'
 
 # 叶子数
 numleafs = 0
 # 节点信息
 sides = list(mytree.keys())
 firststr =sides[0]
 # 分支信息
 seconddict = mytree[firststr]
 
 for key in seconddict.keys(): # 遍历所有分支
  # 子树分支则递归计算
  if type(seconddict[key]).__name__=='dict':
   numleafs += getnumleafs(seconddict[key])
  # 叶子分支则叶子数+1
  else: numleafs +=1
  
 return numleafs
 
 
def gettreedepth(mytree):
 '计算决策树的深度'
 
 # 最大深度
 maxdepth = 0
 # 节点信息
 sides = list(mytree.keys())
 firststr =sides[0]
 # 分支信息
 seconddict = mytree[firststr]
 
 for key in seconddict.keys(): # 遍历所有分支
  # 子树分支则递归计算
  if type(seconddict[key]).__name__=='dict':
   thisdepth = 1 + gettreedepth(seconddict[key])
  # 叶子分支则叶子数+1
  else: thisdepth = 1
  
  # 更新最大深度
  if thisdepth > maxdepth: maxdepth = thisdepth
  
 return maxdepth
 
import matplotlib.pyplot as plt
 
decisionnode = dict(boxstyle="sawtooth", fc="0.8")
leafnode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
 
# ==================================================
# 输入:
#  nodetxt:  终端节点显示内容
#  centerpt: 终端节点坐标
#  parentpt: 起始节点坐标
#  nodetype: 终端节点样式
# 输出:
#  在图形界面中显示输入参数指定样式的线段(终端带节点)
# ==================================================
def plotnode(nodetxt, centerpt, parentpt, nodetype):
 '画线(末端带一个点)'
  
 createplot.ax1.annotate(nodetxt, xy=parentpt, xycoords='axes fraction', xytext=centerpt, textcoords='axes fraction', va="center", ha="center", bbox=nodetype, arrowprops=arrow_args )
 
# =================================================================
# 输入:
#  cntrpt:  终端节点坐标
#  parentpt: 起始节点坐标
#  txtstring: 待显示文本内容
# 输出:
#  在图形界面指定位置(cntrpt和parentpt中间)显示文本内容(txtstring)
# =================================================================
def plotmidtext(cntrpt, parentpt, txtstring):
 '在指定位置添加文本'
 
 # 中间位置坐标
 xmid = (parentpt[0]-cntrpt[0])/2.0 + cntrpt[0]
 ymid = (parentpt[1]-cntrpt[1])/2.0 + cntrpt[1]
 
 createplot.ax1.text(xmid, ymid, txtstring, va="center", ha="center", rotation=30)
 
# ===================================
# 输入:
#  mytree: 决策树
#  parentpt: 根节点坐标
#  nodetxt: 根节点坐标信息
# 输出:
#  在图形界面绘制决策树
# ===================================
def plottree(mytree, parentpt, nodetxt):
 '绘制决策树'
 
 # 当前树的叶子数
 numleafs = getnumleafs(mytree)
 # 当前树的节点信息
 sides = list(mytree.keys())
 firststr =sides[0]
 
 # 定位第一棵子树的位置(这是蛋疼的一部分)
 cntrpt = (plottree.xoff + (1.0 + float(numleafs))/2.0/plottree.totalw, plottree.yoff)
 
 # 绘制当前节点到子树节点(含子树节点)的信息
 plotmidtext(cntrpt, parentpt, nodetxt)
 plotnode(firststr, cntrpt, parentpt, decisionnode)
 
 # 获取子树信息
 seconddict = mytree[firststr]
 # 开始绘制子树,纵坐标-1。 
 plottree.yoff = plottree.yoff - 1.0/plottree.totald
  
 for key in seconddict.keys(): # 遍历所有分支
  # 子树分支则递归
  if type(seconddict[key]).__name__=='dict':
   plottree(seconddict[key],cntrpt,str(key))
  # 叶子分支则直接绘制
  else:
   plottree.xoff = plottree.xoff + 1.0/plottree.totalw
   plotnode(seconddict[key], (plottree.xoff, plottree.yoff), cntrpt, leafnode)
   plotmidtext((plottree.xoff, plottree.yoff), cntrpt, str(key))
  
 # 子树绘制完毕,纵坐标+1。
 plottree.yoff = plottree.yoff + 1.0/plottree.totald
 
# ==============================
# 输入:
#  mytree: 决策树
# 输出:
#  在图形界面显示决策树
# ==============================
def createplot(intree):
 '显示决策树'
 
 # 创建新的图像并清空 - 无横纵坐标
 fig = plt.figure(1, facecolor='white')
 fig.clf()
 axprops = dict(xticks=[], yticks=[])
 createplot.ax1 = plt.subplot(111, frameon=false, **axprops)
 
 # 树的总宽度 高度
 plottree.totalw = float(getnumleafs(intree))
 plottree.totald = float(gettreedepth(intree))
 
 # 当前绘制节点的坐标
 plottree.xoff = -0.5/plottree.totalw;
 plottree.yoff = 1.0;
 
 # 绘制决策树
 plottree(intree, (0.5,1.0), '')
 
 plt.show()
 
if __name__ == '__main__':
 data,features=load_data()
 split_fea_index=choose_best_to_split(data)
 newtree=build_decesion_tree(data,features)
 print(newtree)
 createplot(newtree)
 '''
 {'age': {'old_aged': {'level': {'same': 'yes', 'good': 'no'}}, 'teenager': {'student': {'no': 'no', 'yes': 'yes'}}, 'middle_aged': 'yes'}}
 '''

结果如下:

python实现决策树分类

怎么用决策树分类,将会在下一章

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/momaojia/article/details/73744456

标签:

相关文章

热门资讯

2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全
2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全 2019-12-26
yue是什么意思 网络流行语yue了是什么梗
yue是什么意思 网络流行语yue了是什么梗 2020-10-11
背刺什么意思 网络词语背刺是什么梗
背刺什么意思 网络词语背刺是什么梗 2020-05-22
Intellij idea2020永久破解,亲测可用!!!
Intellij idea2020永久破解,亲测可用!!! 2020-07-29
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总 2020-11-13
返回顶部