1. 程式人生 > >kmeans演算法詳解與spark實戰

kmeans演算法詳解與spark實戰

1.標準kmeans演算法

kmeans演算法是實際中最常用的聚類演算法,沒有之一。kmeans演算法的原理簡單,實現起來不是很複雜,實際中使用的效果一般也不錯,所以深受廣大人民群眾的喜愛。
kmeans演算法的原理介紹方面的paper多如牛毛,而且理解起來確實也不是很複雜,這裡使用wiki上的版本:
已知觀測集(x1,x2,,xn),其中每個觀測都是一個d維實向量,kmeans聚類要把這n個觀測值劃分到k個集合中(kn),使得組內平方和(WCSS within-cluster sum of squares)最小。換句話說,它的目標是找到使得下式滿足的聚類Si

argminSi
=1
k
xSixμi2

其中μiSi中所有點的均值。

標準kmeans演算法的步驟一般如下:
1.先隨機挑選k個初始聚類中心。
2.計算資料集中每個點到每個聚類中心的距離,然後將這個點分配到離該點最近的聚類中心。
3.重新計算每個類中所有點的座標的平均值,並將得到的這個新的點作為新的聚類中心。
重複上面第2、3步,知道聚類中心點不再大範圍移動(精度自己定義)或者迭代的總次數達到最大。

2.標準kmeans演算法的優缺點

標準的kmeans演算法的優缺點都很突出。這裡挑幾個最重要的點總結一下。

主要優點:

1.原理簡單,易於理解。
2.實現簡單
3.計算速度較快
4.聚類效果還不錯。

主要缺點:

1.需要確定k值。
2.對初始中心點的選擇敏感。
3.對異常值敏感,因為異常值很很大程度影響聚類中心的位置。
4.無法增量計算。這點在資料量大的時候尤為突出。

3.spark中對kmeans的優化

作為經典的聚類演算法,一般的機器學習框架裡都實現由kmeans,spark自然也不例外。前面我們已經講了標準kmeans的流程以及優缺點,那麼針對標準kmeans中的不足,spark裡主要做了如下的優化:

1.選擇合適的K值。

k的選擇是kmeans演算法的關鍵。Spark MLlib在KMeansModel裡實現了computeCost方法,這個方法通過計算資料集中所有的點到最近中心點的平方和來衡量聚類的效果。一般來說,同樣的迭代次數,這個cost值越小,說明聚類的效果越好。但在實際使用過程中,必須還要考慮聚類結果的可解釋性,不能一味地選擇cost值最小的那個k。比如我們如果考慮極限情況,如果資料集有n個點,如果令k=n,每個點都是聚類中心,每個類都只有一個點,此時cost值最小為0。但是這樣的聚類結果顯然是沒有實際意義的。

2.選擇合適的初始中心點

大部分迭代演算法都對初始值很敏感,kmeans也是如此。spark MLlib在初始中心點的選擇上,使用了k-means++的演算法。想要詳細瞭解k-means++的同學們,可以參考k-means++在wiki上的介紹:https://en.wikipedia.org/wiki/K-means%2B%2B
kmeans++的基本思想是是初始中心店的相互距離儘可能遠。為了實現這個初衷,採取如下步驟:
1.從初始資料集中隨機選擇一個點作為第一個聚類中心點。
2.計算資料集中所有點到最近一箇中心點的距離D(x)並存在一個數組裡,然後將所有這些距離加起來得到Sum(D(x))。
3.然後再取一個隨機值,用權重的方式計算下一個中心點。具體的實現方法:先取一個在Sum(D(x))範圍內的隨機值,然後領Random -= D(x),直至Random <= 0,此時這個D(x)對應的點為下一個中心點。
4.重複2、3步直到k個聚類中心點被找出。
5.利用找出的k個聚類中心點,執行標準的kmeans演算法。

演算法的關鍵是在第三步。有兩個小點需要說明:
1.不能直接取距離最大的那個點當中心店。因為這個點很可能是離群點。
2.這種取隨機值的方法能保證距離最大的那個點被選中的概率最大。給大家舉個很簡單的例子:假設有四個點A、B、C、D,分別離最近中心的距離D(x)為1、2、3、4,那麼Sum(D(x))=10。然後在[0,10]之間取一隨機數假設為random,然後用random與D(x)依次相減,直至random<0為止。應該不難發現,D被選中的概率最大。

4.spark實戰kmeans演算法

前面講了這麼多理論,照例咱們需要實踐一把。talk is cheap,show me the code!

1.準備資料

將資料下載下來以後檢視一把,第一行相當於是表頭,是對資料的相關說明。將此行去掉,還剩440行。將前400行作為訓練集,後40行作為測試集。

2.將程式碼run起來

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors

object KmeansTest {
  def main(args: Array[String]) {

    val conf = new
        SparkConf().setAppName("K-Means Clustering").setMaster("spark://your host:7077").setJars(List("your jar file"))
    val sc = new SparkContext(conf)

    val rawTrainingData = sc.textFile("file:///Users/lei.wang/data/data_training")
    val parsedTrainingData =
      rawTrainingData.filter(!isColumnNameLine(_)).map(line => {
        Vectors.dense(line.split(",").map(_.trim).filter(!"".equals(_)).map(_.toDouble))
      }).cache()

    // Cluster the data into two classes using KMeans

    val numClusters = 8
    val numIterations = 30
    val runTimes = 3
    var clusterIndex: Int = 0
    val clusters: KMeansModel =
      KMeans.train(parsedTrainingData, numClusters, numIterations, runTimes)

    println("Cluster Number:" + clusters.clusterCenters.length)

    println("Cluster Centers Information Overview:")
    clusters.clusterCenters.foreach(
      x => {
        println("Center Point of Cluster " + clusterIndex + ":")
        println(x)
        clusterIndex += 1
      })

    //begin to check which cluster each test data belongs to based on the clustering result

    val rawTestData = sc.textFile("file:///Users/lei.wang/data/data_test")
    val parsedTestData = rawTestData.map(line => {
      Vectors.dense(line.split(",").map(_.trim).filter(!"".equals(_)).map(_.toDouble))

    })
    parsedTestData.collect().foreach(testDataLine => {
      val predictedClusterIndex:
      Int = clusters.predict(testDataLine)
      println("The data " + testDataLine.toString + " belongs to cluster " +
        predictedClusterIndex)
    })

    println("Spark MLlib K-means clustering test finished.")
  }

  private def isColumnNameLine(line: String): Boolean = {
    if (line != null && line.contains("Channel")) true
    else false
  }
}

在本地將程式碼跑起來以後,輸出如下:

...
Cluster Number:8
Cluster Centers Information Overview:
Center Point of Cluster 0:
[1.103448275862069,2.5517241379310343,39491.1724137931,4220.6551724137935,5250.172413793103,4478.103448275862,870.9655172413793,2152.8275862068967]
Center Point of Cluster 1:
[2.0,2.4210526315789473,7905.894736842105,20288.052631578947,30969.263157894737,2002.0526315789473,14125.105263157893,3273.4736842105262]
Center Point of Cluster 2:
[1.0,2.5,34782.0,30367.0,16898.0,48701.5,755.5,26776.0]
Center Point of Cluster 3:
[1.2190476190476192,2.5142857142857147,17898.97142857143,3221.7904761904765,4525.866666666667,3639.419047619048,1061.152380952381,1609.9047619047622]
Center Point of Cluster 4:
[1.8987341772151898,2.481012658227848,4380.5822784810125,9389.151898734177,14524.556962025315,1508.4556962025317,6457.683544303797,1481.1772151898733]
Center Point of Cluster 5:
[1.0817610062893082,2.4716981132075473,5098.270440251573,2804.295597484277,3309.0943396226417,2416.37106918239,901.1886792452831,803.0062893081762]
Center Point of Cluster 6:
[1.0,3.0,85779.66666666666,12503.666666666666,12619.666666666666,13991.666666666666,2159.0,3958.0]
Center Point of Cluster 7:
[2.0,3.0,29862.5,53080.75,60015.75,3262.25,27942.25,3082.25]
...

此部分內容為聚類中心點相關資訊,我們將k設為8,所以一共有8箇中心點。

...
The data [1.0,3.0,4446.0,906.0,1238.0,3576.0,153.0,1014.0] belongs to cluster 5
The data [1.0,3.0,27167.0,2801.0,2128.0,13223.0,92.0,1902.0] belongs to cluster 3
The data [1.0,3.0,26539.0,4753.0,5091.0,220.0,10.0,340.0] belongs to cluster 3
The data [1.0,3.0,25606.0,11006.0,4604.0,127.0,632.0,288.0] belongs to cluster 3
The data [1.0,3.0,18073.0,4613.0,3444.0,4324.0,914.0,715.0] belongs to cluster 3
The data [1.0,3.0,6884.0,1046.0,1167.0,2069.0,593.0,378.0] belongs to cluster 5
The data [1.0,3.0,25066.0,5010.0,5026.0,9806.0,1092.0,960.0] belongs to cluster 3
The data [2.0,3.0,7362.0,12844.0,18683.0,2854.0,7883.0,553.0] belongs to cluster 4
The data [2.0,3.0,8257.0,3880.0,6407.0,1646.0,2730.0,344.0] belongs to cluster 5
The data [1.0,3.0,8708.0,3634.0,6100.0,2349.0,2123.0,5137.0] belongs to cluster 5
The data [1.0,3.0,6633.0,2096.0,4563.0,1389.0,1860.0,1892.0] belongs to cluster 5
The data [1.0,3.0,2126.0,3289.0,3281.0,1535.0,235.0,4365.0] belongs to cluster 5
The data [1.0,3.0,97.0,3605.0,12400.0,98.0,2970.0,62.0] belongs to cluster 4
The data [1.0,3.0,4983.0,4859.0,6633.0,17866.0,912.0,2435.0] belongs to cluster 5
The data [1.0,3.0,5969.0,1990.0,3417.0,5679.0,1135.0,290.0] belongs to cluster 5
The data [2.0,3.0,7842.0,6046.0,8552.0,1691.0,3540.0,1874.0] belongs to cluster 5
The data [2.0,3.0,4389.0,10940.0,10908.0,848.0,6728.0,993.0] belongs to cluster 4
The data [1.0,3.0,5065.0,5499.0,11055.0,364.0,3485.0,1063.0] belongs to cluster 4
The data [2.0,3.0,660.0,8494.0,18622.0,133.0,6740.0,776.0] belongs to cluster 4
The data [1.0,3.0,8861.0,3783.0,2223.0,633.0,1580.0,1521.0] belongs to cluster 5
The data [1.0,3.0,4456.0,5266.0,13227.0,25.0,6818.0,1393.0] belongs to cluster 4
The data [2.0,3.0,17063.0,4847.0,9053.0,1031.0,3415.0,1784.0] belongs to cluster 3
The data [1.0,3.0,26400.0,1377.0,4172.0,830.0,948.0,1218.0] belongs to cluster 3
The data [2.0,3.0,17565.0,3686.0,4657.0,1059.0,1803.0,668.0] belongs to cluster 3
The data [2.0,3.0,16980.0,2884.0,12232.0,874.0,3213.0,249.0] belongs to cluster 3
The data [1.0,3.0,11243.0,2408.0,2593.0,15348.0,108.0,1886.0] belongs to cluster 3
The data [1.0,3.0,13134.0,9347.0,14316.0,3141.0,5079.0,1894.0] belongs to cluster 4
The data [1.0,3.0,31012.0,16687.0,5429.0,15082.0,439.0,1163.0] belongs to cluster 0
The data [1.0,3.0,3047.0,5970.0,4910.0,2198.0,850.0,317.0] belongs to cluster 5
The data [1.0,3.0,8607.0,1750.0,3580.0,47.0,84.0,2501.0] belongs to cluster 5
The data [1.0,3.0,3097.0,4230.0,16483.0,575.0,241.0,2080.0] belongs to cluster 4
The data [1.0,3.0,8533.0,5506.0,5160.0,13486.0,1377.0,1498.0] belongs to cluster 5
The data [1.0,3.0,21117.0,1162.0,4754.0,269.0,1328.0,395.0] belongs to cluster 3
The data [1.0,3.0,1982.0,3218.0,1493.0,1541.0,356.0,1449.0] belongs to cluster 5
The data [1.0,3.0,16731.0,3922.0,7994.0,688.0,2371.0,838.0] belongs to cluster 3
The data [1.0,3.0,29703.0,12051.0,16027.0,13135.0,182.0,2204.0] belongs to cluster 0
The data [1.0,3.0,39228.0,1431.0,764.0,4510.0,93.0,2346.0] belongs to cluster 0
The data [2.0,3.0,14531.0,15488.0,30243.0,437.0,14841.0,1867.0] belongs to cluster 1
The data [1.0,3.0,10290.0,1981.0,2232.0,1038.0,168.0,2125.0] belongs to cluster 5
The data [1.0,3.0,2787.0,1698.0,2510.0,65.0,477.0,52.0] belongs to cluster 5
...

此部分內容為測試集的聚類結果。因為我們選了40個樣本作為測試集,所以此部分輸出的內容一共有40行。

5.後續工作

本次測試是在單機上做的demo測試,資料集比較小,運算過程也比較快。其實當資料量增大以後,基本過程跟這是類似的,只需要將input改為叢集的資料路徑,然後再寫個簡單的shell指令碼,呼叫spark-submit,將任務提交到叢集即可。