1. 程式人生 > >生成對抗網路GANs工作原理

生成對抗網路GANs工作原理

  生成對抗網路是一種基於可微生成器的生成式建模方法。

  生成對抗網路基於博弈模型,其中生成器網路(generator network)必須與其對手判別器(discriminator)競爭。生成器直接生成樣本 x = g(z ; θ(g)) 。其對手,判別器網路會嘗試區分生成器生成的樣本和訓練資料中抽取的樣本。生成器由 d(x  ; θ(g))生成一個概率值來判別樣本 x 是從訓練資料中抽取的樣本還是由生成器生成的 ‘贗品’ 。

  生成對抗網路是一種生成模型,GANs的結構和我們之前見到的神經網路略為不同。大體上來說,GANs有生成器Generator和辨別器Discriminator組成,基本的結構圖如下:

GANs結構示意圖

工作原理 

  我們通常使用兩個優化演算法來訓練GANs。判別器是一個普通的神經網路分類器,訓練的過程中,我們使用辨別器 (discriminator) 學習引導生成器。

  判別器:

  在訓練的過程中,我們向辨別器discriminator輸入的資料一半來自於真實的訓練資料,另一半來自於生成器生成的假影象。在訓練的過程中,對於真實資料,判別器嘗試向其分配一個接近1的概率(為更好泛化,一般會使用smooth引數將labels設為略小於1的值,如0.9);而對於生成器生成的‘贗品’,判別器嘗試向其分配一個接近0的概率。也就是說,對於真實資料,我們使用label=1計算代價函式來訓練判別器,其代價函式的計算方法為:

d_loss_real tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_reallabels=tf.ones_like(d_logits_real(smooth)))

對於生成器,我們使用label=0計算代價函式來訓練判別器,其代價函式的計算方法為:

d_loss_fake tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fakelabels=tf.zeros_like(d_logits_fake)))

所以判別器的代價函式為:d_loss d_loss_real d_loss_fake

  生成器:

  與此同時,生成器嘗試做相反的事情,它經訓練嘗試輸出能使辨別器分配接近概率1的樣本。生成器的代價函式為

g_loss tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fakelabels=tf.ones_like(d_logits_fake)))

  隨著以上訓練的進行,判別器‘被迫’增強自身的判別能力,而生成器‘被迫’生成越來越逼真的輸出,以欺騙判別器。理論上,最終生成器和判別器會達到一種均衡“納什均衡”。

 Discriminator和Generator損失計算

  GANs和很多其他模型不同,GANs在訓練時需要同時執行兩個優化演算法,我們需要為discriminator和generator分別定義一個優化器,一個用來來最小化discriminator的損失,另一個用來最小化generator的損失。即loss = d_loss + g_loss

  d_loss計算方法:

  對於辨別器discriminator,其損失等於真實圖片和生成圖片的損失之和,即 d_loss = d_loss_real + d_loss_fake , losses 均由交叉熵計算而得。在 tensorflow 中可使用以下函式:

tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)

  在計算真實資料產生的損失d_loss_real時,我們希望辨別器discriminator輸出1;而在計算生成器生成的 ‘假’ 資料所產生的損失d_loss_fake時,我們希望discriminator輸出0.

  因此,對於真實資料,在計算其損失時,將上式中的labels全部都設為1,因為它們都是真實的。為了是增強辨別器discriminator的泛化能力,可以將labels設為0.9,而不是1.0。

  對於生成器生成的‘假’資料,在計算其損失d_loss_fake時,將上式中的labels全部設為0。

  g_loss計算方法:

  最後,生成器generator的損失用 '假' 資料的logits(即d_logits_fake),但是,現在所有的labels全部設為1(即我們希望生成器generator輸出1)。這樣,通過訓練,生成器generator試圖 ‘騙過’ 辨別器discriminator。

例項

  一個使用GANs的來生成MNIST資料的程式碼示例