一、k均值聚类的简单介绍
假设样本分为c类,每个类均存在一个中心点,通过随机生成c个中心点进行迭代,计算每个样本点到类中心的距离(可以自定义、常用的是欧氏距离)
将该样本点归入到最短距离所在的类,重新计算聚类中心,进行下次的重新划分样本,最终类中心不改变时,聚类完成
二、伪代码
三、python代码实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python
# coding=utf-8
"""Simple k-means (Lloyd's algorithm) demo on synthetic 2-D Gaussian blobs."""
import numpy as np
import random


def k_means(data, k):
    """Cluster `data` into `k` groups with Lloyd's algorithm.

    Parameters
    ----------
    data : numpy.ndarray, shape (n_samples, n_features)
        Dataset, one sample per row.
    k : int
        Number of clusters (must be <= n_samples).

    Returns
    -------
    cat : numpy.ndarray, shape (n_samples,)
        Cluster label of each sample, in 1..k (1-based, matching the
        original implementation).
    cluster_cen : numpy.ndarray, shape (k, n_features)
        Final cluster centers.
    """
    sample_num = data.shape[0]
    # Forgy initialization: pick k distinct samples as the starting centers.
    center_index = random.sample(range(sample_num), k)
    cluster_cen = data[center_index, :]  # fancy indexing -> a copy, safe to mutate

    cat = np.zeros(sample_num)  # per-sample label; 0 means "not assigned yet"
    is_change = True
    while is_change:
        is_change = False
        # Assignment step: attach each sample to its nearest center
        # (squared Euclidean distance; the sqrt is monotonic, so omitted).
        for i in range(sample_num):
            # float('inf') instead of the original magic 100000, which would
            # silently misassign points farther than sqrt(100000) from all centers.
            min_distance = float('inf')
            min_index = 0
            for j in range(k):
                diff = data[i, :] - cluster_cen[j, :]
                distance = np.inner(diff, diff)
                if distance < min_distance:
                    min_distance = distance
                    min_index = j + 1  # labels are 1-based
            if cat[i] != min_index:
                is_change = True
                cat[i] = min_index
        # Update step: move every center to the mean of its members.
        for j in range(k):
            cluster_cen[j] = np.mean(data[cat == (j + 1)], axis=0)
    return cat, cluster_cen


if __name__ == '__main__':
    # Plotting is only needed when run as a script; importing lazily keeps
    # the module importable (e.g. for reuse of k_means) without matplotlib.
    import matplotlib.pyplot as plt

    # Generate five 2-D Gaussian blobs, 200 points each, unit covariance.
    cov = [[1, 0], [0, 1]]
    means = [[1, -1], [5.5, -4.5], [1, 4], [6, 4.5], [9, 0.0]]
    clouds = [np.random.multivariate_normal(m, cov, 200) for m in means]
    X = np.vstack(clouds)

    markers = ['o', '+', 'x', '*', '+']
    colors = ['r', 'm', 'b', 'g', 'y']
    labels = ['x1', 'x2', 'x3', 'x4', 'x5']

    # Plot the raw data, one marker/color per generating distribution.
    plt.figure(1)
    for pts, m, c, lab in zip(clouds, markers, colors, labels):
        # BUG FIX: the original plotted x5's x-coordinates against x4's
        # y-coordinates (plt.scatter(x5[:,0], x4[:,1], ...)).
        plt.scatter(pts[:, 0], pts[:, 1], marker=m, color=c, label=lab)
    plt.title('original data')
    plt.legend(loc='upper right')

    cat, cluster_cen = k_means(X, 5)
    for j in range(1, 6):
        # Python 3 print() (original used Python 2 print statements).
        print('the number of cluster %d:' % j, np.sum(cat == j))

    # Plot the clustering result, colored by assigned label.
    plt.figure(2)
    for i, m, c, lab in zip(range(5), markers, colors, labels):
        plt.scatter(X[cat == (i + 1), 0], X[cat == (i + 1), 1],
                    marker=m, color=c, label=lab)
    plt.legend(loc='upper right')
    plt.title('the clustering result')
    plt.show()
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。
原文链接:http://blog.csdn.net/Jason____zhou/article/details/50283035