服务器之家

服务器之家 > 正文

Python实现的KMeans聚类算法实例分析

时间:2021-05-09 00:58     来源/作者:njulpy

本文实例讲述了python实现的kmeans聚类算法。分享给大家供大家参考,具体如下:

菜鸟一枚,编程初学者,最近想使用python3实现几个简单的机器学习分析方法,记录一下自己的学习过程。

关于kmeans算法本身就不做介绍了,下面记录一下自己遇到的问题。

一 、关于初始聚类中心的选取

初始聚类中心的选择一般有:

(1)随机选取

(2)随机选取样本中一个点作为中心点,在通过这个点选取距离其较大的点作为第二个中心点,以此类推。

(3)使用层次聚类等算法更新出初始聚类中心

我一开始是使用numpy随机产生k个聚类中心

?
1
center = np.random.randn(k,n)

但是发现聚类的时候迭代几次以后聚类中心会出现nan,有点搞不清楚怎么回事

所以我分别尝试了:

(1)选择数据集的前k个样本做初始中心点

(2)选择随机k个样本点作为初始聚类中心

发现两者都可以完成聚类,我是用的是iris.csv数据集,在选择前k个样本点做数据集时,迭代次数是固定的,选择随机k个点时,迭代次数和随机种子的选取有关,而且聚类效果也不同,有的随机种子聚类快且好,有的慢且差。

?
1
2
3
4
5
6
7
8
9
def initcenter(k,m,x_train):
  #center = np.random.randn(k,n)
  #center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心
  center = np.zeros([k,n])         #从样本中随机取k个点做初始聚类中心
  np.random.seed(5)            #设置随机数种子
  for i in range(k):
    x = np.random.randint(m)
    center[i] = np.array(x_train.iloc[x])
  return center

二 、关于类间距离的选取

为了简单,我直接采用了欧氏距离,目前还没有尝试其他的距离算法。

?
1
2
3
4
5
6
7
8
9
10
def getdistense(x_train, k, m, center):
  distence=[]
  for j in range(k):
    for i in range(m):
      x = np.array(x_train.iloc[i, :])
      a = x.t - center[j]
      dist = np.sqrt(np.sum(np.square(a))) # dist = np.linalg.norm(x.t - center)
      distence.append(dist)
  dis_array = np.array(distence).reshape(k,m)
  return dis_array

三 、关于终止聚类条件的选取

关于聚类的终止条件有很多选择方法:

(1)迭代一定次数

(2)聚类中心的更新小于某个给定的阈值

(3)类中的样本不再变化

我用的是前两种方法,第一种很简单,但是聚类效果不好控制,针对不同数据集,稳健性也不够。第二种比较合适,稳健性也强。第三种方法我还没有尝试,以后可以试着用一下,可能聚类精度会更高一点。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def kmcluster(x_train,k,n,m,threshold):
  global axis_x, axis_y
  center = initcenter(k,m,x_train)
  initcenter = center
  centerchanged = true
  t=0
  while centerchanged:
    dis_array = getdistense(x_train, k, m, center)
    center ,axis_x,axis_y,axis_z= getnewcenter(x_train,k,n,dis_array)
    err = np.linalg.norm(initcenter[-k:] - center)
    print(err)
    t+=1
    plt.figure(1)
    p=plt.subplot(3, 3, t)
    p1,p2,p3 = plt.scatter(axis_x[0], axis_y[0], c='r'),plt.scatter(axis_x[1], axis_y[1], c='g'),plt.scatter(axis_x[2], axis_y[2], c='b')
    plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
    p.set_title('iteration'+ str(t))
    if err < threshold:
      centerchanged = false
    else:
      initcenter = np.concatenate((initcenter, center), axis=0)
  plt.show()
  return center, axis_x, axis_y,axis_z, initcenter

err是本次聚类中心点和上次聚类中心点之间的欧氏距离。

threshold是人为设定的终止聚类的阈值,我个人一般设置为0.1或者0.01。

为了将每次迭代产生的类别显示出来我修改了上述代码,使用matplotlib展示每次迭代的散点图。

下面附上我测试数据时的图,子图设置的个数要根据迭代次数来定。

Python实现的KMeans聚类算法实例分析

我测试了几个数据集,聚类的精度还是可以的。

使用iris数据集分析的结果为:

err of iteration 1 is 3.11443180281
err of iteration 2 is 1.27568813621
err of iteration 3 is 0.198909381512
err of iteration 4 is 0.0
final cluster center is  [[ 6.85        3.07368421  5.74210526  2.07105263]
 [ 5.9016129   2.7483871   4.39354839  1.43387097]
 [ 5.006       3.428       1.462       0.246     ]]

最后附上全部代码,错误之处还请多多批评,谢谢。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#encoding:utf-8
"""
  author:   njulpy
  version:   1.0
  data:   2018/04/11
  project: using python to implement kmeans clustering algorithm
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from sklearn.cluster import kmeans
def initcenter(k,m,x_train):
  #center = np.random.randn(k,n)
  #center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心
  center = np.zeros([k,n])         #从样本中随机取k个点做初始聚类中心
  np.random.seed(15)            #设置随机数种子
  for i in range(k):
    x = np.random.randint(m)
    center[i] = np.array(x_train.iloc[x])
  return center
def getdistense(x_train, k, m, center):
  distence=[]
  for j in range(k):
    for i in range(m):
      x = np.array(x_train.iloc[i, :])
      a = x.t - center[j]
      dist = np.sqrt(np.sum(np.square(a))) # dist = np.linalg.norm(x.t - center)
      distence.append(dist)
  dis_array = np.array(distence).reshape(k,m)
  return dis_array
def getnewcenter(x_train,k,n, dis_array):
  cen = []
  axisx ,axisy,axisz= [],[],[]
  cls = np.argmin(dis_array, axis=0)
  for i in range(k):
    train_i=x_train.loc[cls == i]
    xx,yy,zz = list(train_i.iloc[:,1]),list(train_i.iloc[:,2]),list(train_i.iloc[:,3])
    axisx.append(xx)
    axisy.append(yy)
    axisz.append(zz)
    meanc = np.mean(train_i,axis=0)
    cen.append(meanc)
  newcent = np.array(cen).reshape(k,n)
  newcent=np.nan_to_num(newcent)
  return newcent,axisx,axisy,axisz
def kmcluster(x_train,k,n,m,threshold):
  global axis_x, axis_y
  center = initcenter(k,m,x_train)
  initcenter = center
  centerchanged = true
  t=0
  while centerchanged:
    dis_array = getdistense(x_train, k, m, center)
    center ,axis_x,axis_y,axis_z= getnewcenter(x_train,k,n,dis_array)
    err = np.linalg.norm(initcenter[-k:] - center)
    t+=1
    print('err of iteration '+str(t),'is',err)
    plt.figure(1)
    p=plt.subplot(2, 3, t)
    p1,p2,p3 = plt.scatter(axis_x[0], axis_y[0], c='r'),plt.scatter(axis_x[1], axis_y[1], c='g'),plt.scatter(axis_x[2], axis_y[2], c='b')
    plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
    p.set_title('iteration'+ str(t))
    if err < threshold:
      centerchanged = false
    else:
      initcenter = np.concatenate((initcenter, center), axis=0)
  plt.show()
  return center, axis_x, axis_y,axis_z, initcenter
if __name__=="__main__":
  #x=pd.read_csv("8.advertising.csv")  # 两组测试数据
  #x=pd.read_table("14.bipartition.txt")
  x=pd.read_csv("iris.csv")
  x_train=x.iloc[:,1:5]
  m,n = np.shape(x_train)
  k = 3
  threshold = 0.1
  km,ax,ay,az,ddd = kmcluster(x_train, k, n, m, threshold)
  print('final cluster center is ', km)
  #2-dplot
  plt.figure(2)
  plt.scatter(km[0,1],km[0,2],c = 'r',s = 550,marker='x')
  plt.scatter(km[1,1],km[1,2],c = 'g',s = 550,marker='x')
  plt.scatter(km[2,1],km[2,2],c = 'b',s = 550,marker='x')
  p1, p2, p3 = plt.scatter(axis_x[0], axis_y[0], c='r'), plt.scatter(axis_x[1], axis_y[1], c='g'), plt.scatter(axis_x[2], axis_y[2], c='b')
  plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
  plt.title('2-d scatter')
  plt.show()
  #3-dplot
  plt.figure(3)
  treed = plt.subplot(111, projection='3d')
  treed.scatter(ax[0],ay[0],az[0],c='r')
  treed.scatter(ax[1],ay[1],az[1],c='g')
  treed.scatter(ax[2],ay[2],az[2],c='b')
  treed.set_zlabel('z') # 坐标轴
  treed.set_ylabel('y')
  treed.set_xlabel('x')
  treed.set_title('3-d scatter')
  plt.show()

Python实现的KMeans聚类算法实例分析

Python实现的KMeans聚类算法实例分析

附:上述示例中的iris.csv文件。

希望本文所述对大家python程序设计有所帮助。

原文链接:https://blog.csdn.net/njulpy/article/details/79895750

相关文章

热门资讯

2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全
2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全 2019-12-26
yue是什么意思 网络流行语yue了是什么梗
yue是什么意思 网络流行语yue了是什么梗 2020-10-11
背刺什么意思 网络词语背刺是什么梗
背刺什么意思 网络词语背刺是什么梗 2020-05-22
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总 2020-11-13
2021德云社封箱演出完整版 2021年德云社封箱演出在线看
2021德云社封箱演出完整版 2021年德云社封箱演出在线看 2021-03-15
返回顶部