本文实例讲述了python实现的kmeans聚类算法。分享给大家供大家参考,具体如下:
菜鸟一枚,编程初学者,最近想使用python3实现几个简单的机器学习分析方法,记录一下自己的学习过程。
关于kmeans算法本身就不做介绍了,下面记录一下自己遇到的问题。
一 、关于初始聚类中心的选取
初始聚类中心的选择一般有:
(1)随机选取
(2)随机选取样本中一个点作为中心点,在通过这个点选取距离其较大的点作为第二个中心点,以此类推。
(3)使用层次聚类等算法更新出初始聚类中心
我一开始是使用numpy随机产生k个聚类中心
1
|
center = np.random.randn(k,n) |
但是发现聚类的时候迭代几次以后聚类中心会出现nan,有点搞不清楚怎么回事
所以我分别尝试了:
(1)选择数据集的前k个样本做初始中心点
(2)选择随机k个样本点作为初始聚类中心
发现两者都可以完成聚类,我是用的是iris.csv数据集,在选择前k个样本点做数据集时,迭代次数是固定的,选择随机k个点时,迭代次数和随机种子的选取有关,而且聚类效果也不同,有的随机种子聚类快且好,有的慢且差。
1
2
3
4
5
6
7
8
9
|
def initcenter(k,m,x_train): #center = np.random.randn(k,n) #center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心 center = np.zeros([k,n]) #从样本中随机取k个点做初始聚类中心 np.random.seed( 5 ) #设置随机数种子 for i in range (k): x = np.random.randint(m) center[i] = np.array(x_train.iloc[x]) return center |
二 、关于类间距离的选取
为了简单,我直接采用了欧氏距离,目前还没有尝试其他的距离算法。
1
2
3
4
5
6
7
8
9
10
|
def getdistense(x_train, k, m, center): distence = [] for j in range (k): for i in range (m): x = np.array(x_train.iloc[i, :]) a = x.t - center[j] dist = np.sqrt(np. sum (np.square(a))) # dist = np.linalg.norm(x.t - center) distence.append(dist) dis_array = np.array(distence).reshape(k,m) return dis_array |
三 、关于终止聚类条件的选取
关于聚类的终止条件有很多选择方法:
(1)迭代一定次数
(2)聚类中心的更新小于某个给定的阈值
(3)类中的样本不再变化
我用的是前两种方法,第一种很简单,但是聚类效果不好控制,针对不同数据集,稳健性也不够。第二种比较合适,稳健性也强。第三种方法我还没有尝试,以后可以试着用一下,可能聚类精度会更高一点。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
def kmcluster(x_train,k,n,m,threshold): global axis_x, axis_y center = initcenter(k,m,x_train) initcenter = center centerchanged = true t = 0 while centerchanged: dis_array = getdistense(x_train, k, m, center) center ,axis_x,axis_y,axis_z = getnewcenter(x_train,k,n,dis_array) err = np.linalg.norm(initcenter[ - k:] - center) print (err) t + = 1 plt.figure( 1 ) p = plt.subplot( 3 , 3 , t) p1,p2,p3 = plt.scatter(axis_x[ 0 ], axis_y[ 0 ], c = 'r' ),plt.scatter(axis_x[ 1 ], axis_y[ 1 ], c = 'g' ),plt.scatter(axis_x[ 2 ], axis_y[ 2 ], c = 'b' ) plt.legend(handles = [p1, p2, p3], labels = [ '0' , '1' , '2' ], loc = 'best' ) p.set_title( 'iteration' + str (t)) if err < threshold: centerchanged = false else : initcenter = np.concatenate((initcenter, center), axis = 0 ) plt.show() return center, axis_x, axis_y,axis_z, initcenter |
err是本次聚类中心点和上次聚类中心点之间的欧氏距离。
threshold是人为设定的终止聚类的阈值,我个人一般设置为0.1或者0.01。
为了将每次迭代产生的类别显示出来我修改了上述代码,使用matplotlib展示每次迭代的散点图。
下面附上我测试数据时的图,子图设置的个数要根据迭代次数来定。
我测试了几个数据集,聚类的精度还是可以的。
使用iris数据集分析的结果为:
err of iteration 1 is 3.11443180281
err of iteration 2 is 1.27568813621
err of iteration 3 is 0.198909381512
err of iteration 4 is 0.0
final cluster center is [[ 6.85 3.07368421 5.74210526 2.07105263]
[ 5.9016129 2.7483871 4.39354839 1.43387097]
[ 5.006 3.428 1.462 0.246 ]]
最后附上全部代码,错误之处还请多多批评,谢谢。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
|
#encoding:utf-8 """ author: njulpy version: 1.0 data: 2018/04/11 project: using python to implement kmeans clustering algorithm """ import numpy as np import pandas as pd import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import axes3d from sklearn.cluster import kmeans def initcenter(k,m,x_train): #center = np.random.randn(k,n) #center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心 center = np.zeros([k,n]) #从样本中随机取k个点做初始聚类中心 np.random.seed( 15 ) #设置随机数种子 for i in range (k): x = np.random.randint(m) center[i] = np.array(x_train.iloc[x]) return center def getdistense(x_train, k, m, center): distence = [] for j in range (k): for i in range (m): x = np.array(x_train.iloc[i, :]) a = x.t - center[j] dist = np.sqrt(np. sum (np.square(a))) # dist = np.linalg.norm(x.t - center) distence.append(dist) dis_array = np.array(distence).reshape(k,m) return dis_array def getnewcenter(x_train,k,n, dis_array): cen = [] axisx ,axisy,axisz = [],[],[] cls = np.argmin(dis_array, axis = 0 ) for i in range (k): train_i = x_train.loc[ cls = = i] xx,yy,zz = list (train_i.iloc[:, 1 ]), list (train_i.iloc[:, 2 ]), list (train_i.iloc[:, 3 ]) axisx.append(xx) axisy.append(yy) axisz.append(zz) meanc = np.mean(train_i,axis = 0 ) cen.append(meanc) newcent = np.array(cen).reshape(k,n) newcent = np.nan_to_num(newcent) return newcent,axisx,axisy,axisz def kmcluster(x_train,k,n,m,threshold): global axis_x, axis_y center = initcenter(k,m,x_train) initcenter = center centerchanged = true t = 0 while centerchanged: dis_array = getdistense(x_train, k, m, center) center ,axis_x,axis_y,axis_z = getnewcenter(x_train,k,n,dis_array) err = np.linalg.norm(initcenter[ - k:] - center) t + = 1 print ( 'err of iteration ' + str (t), 'is' ,err) plt.figure( 1 ) p = plt.subplot( 2 , 3 , t) p1,p2,p3 = plt.scatter(axis_x[ 0 ], axis_y[ 0 ], c = 'r' ),plt.scatter(axis_x[ 1 ], axis_y[ 1 ], c = 'g' ),plt.scatter(axis_x[ 2 ], axis_y[ 2 ], c = 'b' ) plt.legend(handles = [p1, p2, p3], labels = [ '0' , '1' , '2' ], loc = 'best' ) p.set_title( 'iteration' + str (t)) if err < threshold: centerchanged = false else : initcenter = np.concatenate((initcenter, center), axis = 0 ) plt.show() return center, axis_x, axis_y,axis_z, initcenter if __name__ = = "__main__" : #x=pd.read_csv("8.advertising.csv") # 两组测试数据 #x=pd.read_table("14.bipartition.txt") x = pd.read_csv( "iris.csv" ) x_train = x.iloc[:, 1 : 5 ] m,n = np.shape(x_train) k = 3 threshold = 0.1 km,ax,ay,az,ddd = kmcluster(x_train, k, n, m, threshold) print ( 'final cluster center is ' , km) #2-dplot plt.figure( 2 ) plt.scatter(km[ 0 , 1 ],km[ 0 , 2 ],c = 'r' ,s = 550 ,marker = 'x' ) plt.scatter(km[ 1 , 1 ],km[ 1 , 2 ],c = 'g' ,s = 550 ,marker = 'x' ) plt.scatter(km[ 2 , 1 ],km[ 2 , 2 ],c = 'b' ,s = 550 ,marker = 'x' ) p1, p2, p3 = plt.scatter(axis_x[ 0 ], axis_y[ 0 ], c = 'r' ), plt.scatter(axis_x[ 1 ], axis_y[ 1 ], c = 'g' ), plt.scatter(axis_x[ 2 ], axis_y[ 2 ], c = 'b' ) plt.legend(handles = [p1, p2, p3], labels = [ '0' , '1' , '2' ], loc = 'best' ) plt.title( '2-d scatter' ) plt.show() #3-dplot plt.figure( 3 ) treed = plt.subplot( 111 , projection = '3d' ) treed.scatter(ax[ 0 ],ay[ 0 ],az[ 0 ],c = 'r' ) treed.scatter(ax[ 1 ],ay[ 1 ],az[ 1 ],c = 'g' ) treed.scatter(ax[ 2 ],ay[ 2 ],az[ 2 ],c = 'b' ) treed.set_zlabel( 'z' ) # 坐标轴 treed.set_ylabel( 'y' ) treed.set_xlabel( 'x' ) treed.set_title( '3-d scatter' ) plt.show() |
附:上述示例中的iris.csv文件。
希望本文所述对大家python程序设计有所帮助。
原文链接:https://blog.csdn.net/njulpy/article/details/79895750