1. 程式人生 > >Spark機器學習(8):LDA主題模型算法

Spark機器學習(8):LDA主題模型算法

算法 ets 思想 dir 骰子 cati em算法 第一個 不同

1. LDA基礎知識

LDA(Latent Dirichlet Allocation)是一種主題模型。LDA一個三層貝葉斯概率模型,包含詞、主題和文檔三層結構。

LDA是一個生成模型,可以用來生成一篇文檔,生成時,首先根據一定的概率選擇一個主題,然後在根據概率選擇主題裏面的一個單詞,這樣反復進行,就可以生成一篇文檔;反過來,LDA又是一種非監督機器學習技術,可以識別出大規模文檔集或語料庫中的主題。

LDA原始論文給出了一個很簡單的例子。Arts、Budgets、Children、Education是4個主題,下面是每一個主題包含的單詞。

技術分享

然後就可以隨機選擇主題,以及每個主題裏面的單詞,重復多次後就生成了一篇文檔,其中不同的顏色表示單詞來自不同的主題。

技術分享

可見,文檔和單詞是可見的,而主題是隱藏的。

文檔裏某個單詞出現的概率可以用公式表示:

技術分享

其中d是文檔,w是單詞,z是主題,k是主題數量。可以想象成三個矩陣:

技術分享

第一個矩陣表示每個文檔裏面每個單詞出現的概率,第二個矩陣表示每個文檔裏面每個主題出現的概率,第三個矩陣表示每個主題裏面每個詞語出現的概率。在機器學習時,根據文檔集,我們可以計算出第一個矩陣,要求的是第二個矩陣和第三個矩陣。

2. 極大似然估計

極大似然估計的基本思想是,從總體抽取n個樣本之後,最合理的參數估計量應該是使得這批樣本出現的概率最大的參數估計量。比如說你在一個小城市,很少看見美國人,偶然看見了幾個美國人身材都很高,這時就可以估計美國人普遍身材很高,因為只有這樣你看到幾個美國人身材都很高這件事出現的概率才最大。

3. EM方法

EM即Exception Maximization,是機器學習的重要算法之一,在機器學習中有著重要的作用。簡單的說,EM方法就是解決這樣的問題:想估計兩個參數A和B,這兩個參數都是未知的,知道了參數A就能得到參數B,反過來知道了參數B就能得到參數A,這時我們就可以先給A一個初始值,然後計算出B,然後再根據計算出的B再計算A,這樣反復叠代下去,一直到收斂為止。在數學上可以證明這種方法時有效的。

4. Beta分布和Dirichlet分布

Beta分布是二項分布的共軛先驗分布:

技術分享

比如拋硬幣,3次出現正面,2次出現背面,a=3,b=2,就可以得到一個概率分布圖,從概率分布圖上可以看出,x=0.6時函數取得最大值,於是就可以認為x的值很可能接近於0.6,又扔了5次,2次正面,3次背面,a=5,b=5,又可以得到一個新的概率分布圖,x=0.5時函數取得最大值,此時可以認為x的值很可能接近於0.5。

Dirichlet分布和Beta分布類似,是Beta分布在高維度的推廣:

技術分享

比如扔骰子,扔了60次,6個面,各出現10次,可以得到一個概率分布圖,x=(1/6,1/6,1/6,1/6,1/6,1/6)時函數取得最大值,x的值很可能接近於(1/6,1/6,1/6,1/6,1/6,1/6)。

5. LDA的EM算法

具體到LDA,采用EM方法的步驟如下:

(1) 給矩陣wk和kj隨機賦值,其中wk是每個主題中每個單詞出現的次數,kj是每個文檔中每個主題出現的次數,雖然這些次數還只是隨機數,我們還是可以根據這些次數,利用Dirichlet分布計算出每個主題中每個單詞最可能出現的概率,以及每個文檔中每個主題最可能出現的概率,也就相當於給上面的第二個和第三個矩陣初始值;

(2) 對於文檔中的一個單詞,計算出是由哪個主題產生的,因為可能有多個主題都會產生這個單詞,那麽它到底是屬於哪個主題呢?這時就要用到極大似然估計了。計算出每個主題產生這個單詞的概率:

技術分享

然後找出概率最大的那個主題,認為這個單詞就是這個主題產生的,這在EM方法中屬於E-STEP;

(3) 由於確定了這個單詞是哪個主題產生的,相當於Dirichlet分布中a的值發生了改變,於是計算出新的概率矩陣(即上面的第二個和第三個矩陣),這在EM方法中屬於M-STEP。

重復步驟(2)和(3),就可以得到最終的概率矩陣(即上面的第二個和第三個矩陣),機器學習結束。

6. MLlib中LDA的實現

MLlib使用GraphX實現LDA。有兩類節點:詞節點和文檔節點。每個詞節點上存儲一個單詞,以及這個單詞屬於每一個主題的概率;每個文檔節點上存儲一個文檔,以及這個文檔屬於每個主題的概率。例如下圖,存儲了3個單詞和兩個文檔,hockey和system在Article1中出現,launch和system在Article2中出現。

技術分享

叠代過程中,文檔節點通過收集鄰居節點(即詞節點)的數據來更新自己的主題概率,如下圖所示。

技術分享

Spark機器學習(8):LDA主題模型算法