生成對抗網路GANs工作原理
生成對抗網路是一種基於可微生成器的生成式建模方法。
生成對抗網路基於博弈模型,其中生成器網路(generator network)必須與其對手判別器(discriminator)競爭。生成器直接生成樣本 x = g(z ; θ(g)) 。其對手,判別器網路會嘗試區分生成器生成的樣本和訓練資料中抽取的樣本。生成器由 d(x ; θ(g))生成一個概率值來判別樣本 x 是從訓練資料中抽取的樣本還是由生成器生成的 ‘贗品’ 。
生成對抗網路是一種生成模型,GANs的結構和我們之前見到的神經網路略為不同。大體上來說,GANs有生成器Generator和辨別器Discriminator組成,基本的結構圖如下:
GANs結構示意圖
工作原理
我們通常使用兩個優化演算法來訓練GANs。判別器是一個普通的神經網路分類器,訓練的過程中,我們使用辨別器 (discriminator) 學習引導生成器。
判別器:
在訓練的過程中,我們向辨別器discriminator輸入的資料一半來自於真實的訓練資料,另一半來自於生成器生成的假影象。在訓練的過程中,對於真實資料,判別器嘗試向其分配一個接近1的概率(為更好泛化,一般會使用smooth引數將labels設為略小於1的值,如0.9);而對於生成器生成的‘贗品’,判別器嘗試向其分配一個接近0的概率。也就是說,對於真實資料,我們使用label=1計算代價函式來訓練判別器,其代價函式的計算方法為:
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, labels=tf.ones_like(d_logits_real) * (1 - smooth)))
對於生成器,我們使用label=0計算代價函式來訓練判別器,其代價函式的計算方法為:
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
所以判別器的代價函式為:d_loss = d_loss_real + d_loss_fake
生成器:
與此同時,生成器嘗試做相反的事情,它經訓練嘗試輸出能使辨別器分配接近概率1的樣本。生成器的代價函式為
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)))
隨著以上訓練的進行,判別器‘被迫’增強自身的判別能力,而生成器‘被迫’生成越來越逼真的輸出,以欺騙判別器。理論上,最終生成器和判別器會達到一種均衡“納什均衡”。
Discriminator和Generator損失計算
GANs和很多其他模型不同,GANs在訓練時需要同時執行兩個優化演算法,我們需要為discriminator和generator分別定義一個優化器,一個用來來最小化discriminator的損失,另一個用來最小化generator的損失。即loss = d_loss + g_loss
d_loss計算方法:
對於辨別器discriminator,其損失等於真實圖片和生成圖片的損失之和,即 d_loss = d_loss_real + d_loss_fake , losses 均由交叉熵計算而得。在 tensorflow 中可使用以下函式:
tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
在計算真實資料產生的損失d_loss_real時,我們希望辨別器discriminator輸出1;而在計算生成器生成的 ‘假’ 資料所產生的損失d_loss_fake時,我們希望discriminator輸出0.
因此,對於真實資料,在計算其損失時,將上式中的labels全部都設為1,因為它們都是真實的。為了是增強辨別器discriminator的泛化能力,可以將labels設為0.9,而不是1.0。
對於生成器生成的‘假’資料,在計算其損失d_loss_fake時,將上式中的labels全部設為0。
g_loss計算方法:
最後,生成器generator的損失用 '假' 資料的logits(即d_logits_fake),但是,現在所有的labels全部設為1(即我們希望生成器generator輸出1)。這樣,通過訓練,生成器generator試圖 ‘騙過’ 辨別器discriminator。
例項
一個使用GANs的來生成MNIST資料的程式碼示例。