1. 程式人生 > >在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網路 (GAN

在瀏覽器中進行深度學習:TensorFlow.js (八)生成對抗網路 (GAN

Generative Adversarial Network 是深度學習中非常有趣的一種方法。GAN最早源自Ian Goodfellow的這篇論文。LeCun對GAN給出了極高的評價:

“There are many interesting recent development in deep learning…The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.” – Yann LeCun

那麼我們就看看GAN究竟是怎麼回事吧:

 

如上圖所示,GAN包含兩個互相對抗的網路:G(Generator)和D(Discriminator)。正如它的名字所暗示的那樣,它們的功能分別是:

Generator是一個生成器的網路,它接收一個隨機的噪聲,通過這個噪聲生成圖片,記做G(z)。
Discriminator是一個鑑別器網路,判別一張圖片或者一個輸入是不是“真實的”。它的輸入x是資料或者圖片,輸出D(x)代表x為真實圖片的概率,如果為1,就代表100%是真實的圖片,而輸出為0,就代表不可能是真實的圖片。
在訓練過程中,生成網路G的目標就是儘量生成真實的圖片去欺騙判別網路D。而D的目標就是儘量把G生成的圖片和真實的圖片分別開來。這樣,G和D構成了一個動態的“博弈過程”。在最理想的狀態下,G可以生成足以“以假亂真”的圖片G(z)。對於D來說,它難以判定G生成的圖片究竟是不是真實的,因此D(G(z)) = 0.5。

最後,我們就可以使用生成器和隨機輸入來生成不同的資料或者圖片了。

上面的描述大家可能都能理解,但是把它變成數學語言,可能你就蒙B了。

“GANçš„æ ¸å¿ƒåŽŸç†â€çš„å›¾ç‰‡æœç´¢ç»“æžœ

如上圖所示,x是輸入,z是隨機噪聲。D(x)是鑑別器的判定資料為真的概率,D(G(z))是判定生成資料為真的概率。生成器希望這個D(G(z))越大越好,這個時候整個表示式的值應該變小。而鑑別器的目的是能夠有效區分真實資料和假資料,所以D(x)應該趨向於變大,D(G(z))趨向於變小,整個表示式就變大。也就是說訓練過程,生成器和辨別器互相對抗,一個使上述表示式變小,另一個使其變大,最後訓練趨向於平衡,而生成器這時候應該生成真假難辨的資料,這就是我們的最終目的。

 

上圖是GAN演算法訓練的具體過程,這裡我們不做過多的解釋,直接執行一個例子。

“GAN”的图片搜索结果

我們用MINST資料集來看看如何使用TensorflowJS來訓練一個GAN,模擬生成手寫數字。

程式碼見我的codepen

function gen(xs) {
const l1 = tf.leakyRelu(xs.matMul(G1w).add(G1b));
const l2 = tf.leakyRelu(l1.matMul(G2w).add(G2b));
const l3 = tf.tanh(l2.matMul(G3w).add(G3b));
return l3;
}

function disReal(xs) {
const l1 = tf.leakyRelu(xs.matMul(D1w).add(D1b));
const l2 = tf.leakyRelu(l1.matMul(D2w).add(D2b));
const logits = l2.matMul(D3w).add(D3b);
const output = tf.sigmoid(logits);
return [logits, output];
}

function disFake(xs) {
return disReal(gen(xs));
}
GAN的兩個網路分別用gen和disReal建立。gen是生成器網路,disReal是辨別器的網路。disFake是把生成資料用辨別器來辨別。這裡的網路使用leakyrelu。使得輸出在-inf到+inf,利用sigmoid對映到【0,1】,這是辨別器模型輸出一個0-1之間的概率。

“leaky relu”的图片搜索结果

 

通常我們會建立一個比生成器更復雜的鑑別器網路使得鑑別器有足夠的分辨能力。但在這個例子裡,兩個網路的複雜程度類似。

計算損失的函式使用 tf.sigmoidCrossEntropyWithLogits,值得注意的是,在最新的0.13版本中,這個交叉熵被移除了,你需要自己實現該方法。

訓練過程如下:

async function trainBatch(realBatch, fakeBatch) {
const dcost = dOptimizer.minimize(
() => {
const [logitsReal, outputReal] = disReal(realBatch);
const [logitsFake, outputFake] = disFake(fakeBatch);

const lossReal = sigmoidCrossEntropyWithLogits(ONES_PRIME, logitsReal);
const lossFake = sigmoidCrossEntropyWithLogits(ZEROS, logitsFake);
return lossReal.add(lossFake).mean();
},
true,
[D1w, D1b, D2w, D2b, D3w, D3b]
);
await tf.nextFrame();
const gcost = gOptimizer.minimize(
() => {
const [logitsFake, outputFake] = disFake(fakeBatch);

const lossFake = sigmoidCrossEntropyWithLogits(ONES, logitsFake);
return lossFake.mean();
},
true,
[G1w, G1b, G2w, G2b, G3w, G3b]
);
await tf.nextFrame();

return [dcost, gcost];
}
訓練使用了兩個optimizer,

第一步,計算實際資料的辨別結果和1的交叉熵,以及生成器生成資料的辨別結果和0的交叉熵。也就是說,我們希望辨別器儘可能的判斷出生成資料都是假的而實際資料都是真的。使得這兩個交叉熵的均值最小。
第二步開始對抗,要讓生成資料儘可能被判別為真。
下圖是某個訓練過程的損失:

 

這個是經過1000個迭代後的生成圖:

 

大家可以嘗試調整學習率,增加網路複雜度,加大迭代次數來獲得更好的生成模型。

GAN的學習其實還是比較複雜的,引數和損失選擇都不容易,好在有一些現成的工具可以使用,另外推薦大家去https://poloclub.github.io/ganlab/,提供了很直觀的GAN學習的過程。這個也是用TensorflowJS來實現的。

參考:

https://www.msra.cn/zh-cn/news/features/gan-20170511
https://zhuanlan .www.michenggw.com zhihu.com/p/24767059
http://blog.aylien.com/introduction-www.mhylpt.com generative-adversarial-networks-code-tensorflow/
https://github.com/carpedm20/DCGAN-tensorflow
https://blog.openai.www.gcyl152.com com/generative-models/
https://zhuanlan.zhihu.com/p/45200767
https://blog.csdn.net/heyc861221/article/details/80127148