1. 程式人生 > 實用技巧 >MATLAB k-means聚類

MATLAB k-means聚類

聚類演算法,不是分類演算法。

分類演算法是給一個數據,然後判斷這個資料屬於已分好的類中的具體哪一類。

聚類演算法是給一大堆原始資料,然後通過演算法將其中具有相似特徵的資料聚為一類。

這裡的k-means聚類,是事先給出原始資料所含的類數,然後將含有相似特徵的資料聚為一個類中。

所有資料中還是Andrew Ng介紹的明白。

首先給出原始資料{x1,x2,...,xn},這些資料沒有被標記的。

初始化k個隨機資料u1,u2,...,uk。這些xn和uk都是向量。

根據下面兩個公式迭代就能求出最終所有的u,這些u就是最終所有類的中心位置。

公式一:

意思就是求出所有資料和初始化的隨機資料的距離,然後找出距離每個初始資料最近的資料。

公式二:

意思就是求出所有和這個初始資料最近原始資料的距離的均值。

然後不斷迭代兩個公式,直到所有的u都不怎麼變化了,就算完成了。

先看看一些結果:

用三個二維高斯分佈資料畫出的圖:

通過對沒有標記的原始資料進行kmeans聚類得到的分類,十字是最終迭代位置:

下面是Matlab程式碼,這裡我把測試資料改為了三維了,函式是可以處理各種維度的。

main.m

 1 clear all;
 2 close all;
 3 clc;
 4 
 5 %第一類資料
 6 mu1=[0 0 0];  %均值
 7 S1=[0.3 0 0;0 0.35 0;0 0 0.3];  %協方差
 8 data1=mvnrnd(mu1,S1,100
); %產生高斯分佈資料 9 10 %%第二類資料 11 mu2=[1.25 1.25 1.25]; 12 S2=[0.3 0 0;0 0.35 0;0 0 0.3]; 13 data2=mvnrnd(mu2,S2,100); 14 15 %第三個類資料 16 mu3=[-1.25 1.25 -1.25]; 17 S3=[0.3 0 0;0 0.35 0;0 0 0.3]; 18 data3=mvnrnd(mu3,S3,100); 19 20 %顯示資料 21 plot3(data1(:,1),data1(:,2),data1(:,3),'+'); 22 hold on; 23 plot3(data2(:,1),data2(:,2
),data2(:,3),'r+'); 24 plot3(data3(:,1),data3(:,2),data3(:,3),'g+'); 25 grid on; 26 27 %三類資料合成一個不帶標號的資料類 28 data=[data1;data2;data3]; %這裡的data是不帶標號的 29 30 %k-means聚類 31 [u re]=KMeans(data,3); %最後產生帶標號的資料,標號在所有資料的最後,意思就是資料再加一維度 32 [m n]=size(re); 33 34 %最後顯示聚類後的資料 35 figure; 36 hold on; 37 for i=1:m 38 if re(i,4)==1 39 plot3(re(i,1),re(i,2),re(i,3),'ro'); 40 elseif re(i,4)==2 41 plot3(re(i,1),re(i,2),re(i,3),'go'); 42 else 43 plot3(re(i,1),re(i,2),re(i,3),'bo'); 44 end 45 end 46 grid on;

KMeans.m

 1 %N是資料一共分多少類
 2 %data是輸入的不帶分類標號的資料
 3 %u是每一類的中心
 4 %re是返回的帶分類標號的資料
 5 function [u re]=KMeans(data,N)   
 6     [m n]=size(data);   %m是資料個數,n是資料維數
 7     ma=zeros(n);        %每一維最大的數
 8     mi=zeros(n);        %每一維最小的數
 9     u=zeros(N,n);       %隨機初始化,最終迭代到每一類的中心位置
10     for i=1:n
11        ma(i)=max(data(:,i));    %每一維最大的數
12        mi(i)=min(data(:,i));    %每一維最小的數
13        for j=1:N
14             u(j,i)=ma(i)+(mi(i)-ma(i))*rand();  %隨機初始化,不過還是在每一維[min max]中初始化好些
15        end      
16     end
17    
18     while 1
19         pre_u=u;            %上一次求得的中心位置
20         for i=1:N
21             tmp{i}=[];      % 公式一中的x(i)-uj,為公式一實現做準備
22             for j=1:m
23                 tmp{i}=[tmp{i};data(j,:)-u(i,:)];
24             end
25         end
26         
27         quan=zeros(m,N);
28         for i=1:m        %公式一的實現
29             c=[];
30             for j=1:N
31                 c=[c norm(tmp{j}(i,:))];
32             end
33             [junk index]=min(c);
34             quan(i,index)=norm(tmp{index}(i,:));           
35         end
36         
37         for i=1:N            %公式二的實現
38            for j=1:n
39                 u(i,j)=sum(quan(:,i).*data(:,j))/sum(quan(:,i));
40            end           
41         end
42         
43         if norm(pre_u-u)<0.1  %不斷迭代直到位置不再變化
44             break;
45         end
46     end
47     
48     re=[];
49     for i=1:m
50         tmp=[];
51         for j=1:N
52             tmp=[tmp norm(data(i,:)-u(j,:))];
53         end
54         [junk index]=min(tmp);
55         re=[re;data(i,:) index];
56     end
57     
58 end