1. 程式人生 > >生成對抗網路(GAN)應用於影象分類

生成對抗網路(GAN)應用於影象分類

  近年來,深度學習技術被廣泛應用於各類資料處理任務中,比如影象、語音和文字。而生成對抗網路(GAN)和強化學習(RL)已經成為了深度學習框架中的兩顆“明珠”。強化學習主要用於決策問題,主要的應用就是遊戲,比如deepmind團隊的AlphaGo。因為我的研究方向是影象的有監督分類問題,故本文主要講解生成對抗網路及其在分類問題方面的應用。

生成對抗網路框架

  生成對抗網路(Generative adversarial networks,簡稱為GAN)是2014年由Ian J. Goodfellow首先提出來的一種學習框架,說起Ian J. Goodfellow本人,可能大家印象不深刻,但他的老師正是“深度學習三巨頭”之一的Yoshua Bengio(另外兩位分別是Hinton和LeCun),值得一提的是,Theano深度學習框架也是由他們團隊開發的,開啟了符號計算的先河。關於GAN在機器學習領域的地位,在這裡引用一段Lecun的評價,

  “There are many interesting recent development in deep learning…The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.”

  傳統的生成模型都需要先定義一個概率分佈的引數表示式,然後通過最大化似然函式來訓練模型,比如深度玻爾茲曼機(RBM)。這些模型的梯度表示式展開式中通常含有期望項,導致很難得到準確解,一般需要近似,比如在RBM中,利用Markov chain 的收斂性,可以得到符合給定分佈下的隨機樣本。為了克服求解準確性和計算複雜性的困難,J牛創造性的提出來了生成對抗網路。GAN模型不需要直接表示資料的似然函式,卻可以生成與原始資料有相同分佈的樣本。
  與常規的深度學習模型(比如cnn、dbn、rnn)不同,GAN模型採用了兩個獨立的神經網路,分別稱為“generator”和“discriminator”,生成器用於根據輸入噪聲訊號生成‘看上去和真實樣本差不多’的高維樣本,判別器用於區分生成器產生的樣本和真實的訓練樣本(屬於一個二分類問題)。其模型結構框架如下,


GANs是基於一個minimax機制而不是通常的優化問題,它所定義的損失函式是關於判別器的最大化和生成器的最小化,作者也證明了GAN模型最終能夠收斂,此時判別器模型和生成器模型分別取得最優解。記 表示樣本資料, 表示生成器的輸入噪聲分佈, 表示噪聲到樣本空間的對映, 表示 屬於真實樣本而不是生成樣本的概率,那麼GAN模型可以定義為如下的優化問題,
從以上公式可以看出,在模型的訓練過程中,一方面需要修正判別器D,使值函式V最大化,也即使得 最大化和 最小化,其數學意義即最大化判別器分類訓練樣本和生成樣本的正確率,另一方面需要修正生成器G,使值函式V最小化,也即使得 最大化,其數學意義即生成器要儘量生成和訓練樣本非常相似的樣本,這也正是GAN名字中Adversarial的由來。J牛提出了交替優化D和G(對D進行k步優化,對G進行1步優化),具體的訓練過程如下,

GAN在分類問題方面的應用

  早期的GAN模型主要應用於無監督學習任務,即生成和訓練樣本有相同分佈的資料,可以為1維訊號或者二維影象。將GAN應用於分類問題時,需要對網路做改動,這裡簡單講解一下已有的兩篇文章中提出的方案,“Improved Techniques for Training GANs”和“Semantic Segmentation using Adversarial Networks”,前者可以歸類於半監督分類演算法,而後者則屬於有監督分類演算法。

半監督分類方法

  將GAN應用於半監督分類任務時,只需要對最初的GAN的結構做稍微改動,即把discriminator模型的輸出層替換成softmax分類器。假設訓練資料有c類,那麼在訓練GAN模型的時候,可以把generator模擬出來的樣本歸為第c+1類,而softmax分類器也增加一個輸出神經元,用於表示discriminator模型的輸入為“假資料”的概率,這裡的“假資料”具體指generator生成的樣本。因為該模型可以利用有標籤的訓練樣本,也可以從無標籤的生成資料中學習,所以稱之為“半監督”分類。定義損失函式如下,其中是一個標準的GAN優化問題,關於該模型的具體訓練方法可以參見原文。

有監督分類方法

  可想而知,在應用於基於畫素的有監督分類問題時(文章中的訓練資料集類似於人臉識別資料集,區別在於單幅影象的標籤y和輸入人臉影象大小相同),GAN中的生成器模型是沒有什麼作用的。原作者所提出的網路框架包含了兩個分類器模型,其中一個用於對單幅影象進行基於畫素的分類,另外一個分類器也稱作對抗網路,用於區分標籤圖和預測出來的概率圖,引入對抗網路的目的是使得得到的概率預測圖更符合真實的標籤圖,具體的網路結構如下,


  記訓練影象為 表示預測出來的概率圖, 表示對抗網路預測y是x的真實標籤圖的概率, 分別表示segmentation模型和adversarial模型的引數,那麼損失函式可以定義如下,
其中, 表示預測的概率圖 和真實標籤圖y之間的multi-class cross entropy損失,而 ,即表示binary cross entropy 損失。與GAN的訓練方法類似,這裡的模型訓練也是通過迭代訓練adversarial模型和segmentation模型來完成的。在訓練adversarial模型時,等價於優化如下表達式,其物理意義是使得adversarial模型對概率圖和真實標籤圖的區分能力更強。

在訓練segmentation模型時,等價於優化如下表達式,其物理意義是使得生成的概率圖不僅和對應標籤圖相似,而且adversarial模型很難區分的開。

參考資料:Generative Adversarial Nets, Ian J. Goodfellow.
     Improved Techniques for Training GANs. Tim Salimans, Ian Goodfellow.
     Semantic Segmentation using Adversarial Networks. Pauline Luc.