優化版本對生成對抗網路生成手寫數字集(附程式碼詳解)
阿新 • • 發佈:2018-11-05
# 先匯入必要的庫 import os import cv2 import tensorflow as tf import numpy as np # 把結果儲存到本地的一個庫 import pickle import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data # 讀取mnist資料集 mnist = input_data.read_data_sets("MNIST_DATA") # 圖片的大小為28x28, 即784 img_size = 784 # 輸入的噪聲,也可以設定成別的值 noise_size = 100 # 生成網路的隱藏層神經元個數 g_units = 128 # 判別網路的隱藏層神經元個數 d_units = 128 # 學習率 learning_rate = 0.001 # 每個batch的大小 batch_size = 64 # 迭代的輪數, 這裡每個epoch會遍歷一次訓練的資料集 epochs = 300 # 對生成的圖片取樣儲存 n_sample = 25 samples = [] # 獲取輸入的函式 def get_input(real_size, noise_size): """ :param real_size: 真實圖片的大小 :param noise_size: 噪聲的長度 :return: 返回兩個佔位符,其實就是判別網路和生成網路的輸入 """ real_img = tf.placeholder(tf.float32, [None, real_size]) noise_img = tf.placeholder(tf.float32, [None, noise_size]) return real_img, noise_img # 生成器共有兩層結構,noise_img---->n_units----->out_dim def get_generator(noise, n_units, out_dim=img_size, reuse=False, alpha=0.01): """ 實現生成網路 :param noise: 生成網路的輸入 :param n_units: 生成網路的隱藏層神經元個數 :param out_dim: 生成網路的輸出 [None, 784] :param reuse: 是否重複使用網路的各種引數 :param alpha: LeakRelu的引數 :return: 生成模型未經啟用的輸出,和tanh啟用之後的輸出 """ # 建立一個名稱空間, 名稱為generator with tf.variable_scope("generator", reuse=reuse): # 第一層隱藏層 hidden1 = tf.layers.dense(noise, n_units) # 啟用函式和dropout hidden1 = tf.maximum(alpha * hidden1, hidden1) hidden1 = tf.layers.dropout(hidden1, rate=0.2) # 網路未經過啟用函式之前輸出的結果 logits = tf.layers.dense(hidden1, out_dim) out_puts = tf.tanh(logits) return logits, out_puts # 判別器的結構: img---->n_units---1 def get_discriminator(img, n_units, reuse=False, alpha=0.01): """ :param img: 輸入影象的大小 :param n_units: 判別網路隱藏層神經元的個數 :param reuse: 是否重用模型的引數 :param alpha: LeakRelu的引數 :return: 判別模型未經啟用的輸出,和tanh啟用之後的輸出 """ # 建立一個名稱空間, 名稱為discriminator with tf.variable_scope("discriminator", reuse=reuse): # 第一層結構 hidden1 = tf.layers.dense(img, n_units) # 啟用函式 hidden1 = tf.maximum(alpha * hidden1, hidden1) # 輸出層 logits = tf.layers.dense(hidden1, 1) # 使用sigmoid啟用函式 outputs = tf.sigmoid(logits) return logits, outputs # tf.reset_default_graph函式用於清除預設圖形堆疊並重置全域性預設圖形 tf.reset_default_graph() # 接受兩個placeholder real_img, noise_img = get_input(img_size, noise_size) # 呼叫生成網路 g_logits, g_outputs = get_generator(noise_img, g_units) # 判別網路對真實圖片的判別結果 d_logits_real, d_outputs_real = get_discriminator(real_img, d_units) # 判別網路對生成圖片的判別結果, resue表示使用和上面相同的結構和引數 d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True) # 計算損失,判別網路對真實圖片的損失,tf.ones_like(x), 會生成形狀如x, 數值為1的向量, tf.zeros_like(x) 同理 # 從判別器的角度我們希望判別網路能把真實的圖片預測為1,把生成的圖片預測為零 d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, labels=tf.ones_like(d_logits_real))) d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake))) # 將兩部分的損失合起來就是判別網路的損失值了 d_loss = tf.add(d_loss_real, d_loss_fake) # 從生成器的角度來看,我們又希望生成器能生成接近真實的圖片,也就是讓判別器儘可能的把生成的圖片也預測為1,這就是對抗的思想 g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_logits_fake))) # 優化過程, 通過名稱空間的特性取出生成器和判別器中的所有的引數,以便於優化其損失 train_vars = tf.trainable_variables() g_vars = [var for var in train_vars if var.name.startswith("generator")] d_vars = [var for var in train_vars if var.name.startswith("discriminator")] # 優化操作,使用Adam函式進行優化,注意後面的變數列表要與正在優化的損失對應 d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars) g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars) # 用於儲存生成的圖片,便於直觀的看到生成模型的效果 if not os.path.exists('gen_pictures/'): os.makedirs('gen_pictures/') # 儲存模型,只需要儲存生成模型即可 saver = tf.train.Saver(var_list=g_vars) # 開啟一個會話, 開始訓練過程 with tf.Session() as sess: # 初始化所有的變數 sess.run(tf.global_variables_initializer()) for epoch in range(epochs): # 每個epoch會把訓練樣本過一遍 for batch_i in range(mnist.train.num_examples // batch_size): # 從真實樣本資料中取出一個batch, 表示真實的圖片 batch = mnist.train.next_batch(batch_size) batch_image = batch[0].reshape((batch_size, 784)) # 同樣的構造一個batch_size 的噪聲 batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size)) #開始進行迭代訓練 sess.run(d_train_opt, feed_dict={real_img: batch_image, noise_img: batch_noise}) sess.run(g_train_opt, feed_dict={noise_img: batch_noise}) # 列印每個epoch的生成網路的損失,和判別網路的損失 train_loss_d = sess.run(d_loss, feed_dict={real_img: batch_image, noise_img: batch_noise}) train_loss_g = sess.run(g_loss, feed_dict={noise_img: batch_noise}) print("Iterations " + str(epoch) + ", the discrimator loss is: %.4f, generator loss is: %.4f" %(train_loss_d, train_loss_g)) # 使用更新後的引數生成一定數量的樣本 sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size)) _, gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True), feed_dict={noise_img: sample_noise}) # 從生成的圖片中隨機的選取一張儲存下來 single_picture = gen_samples[np.random.randint(0, n_sample)] # 生成圖片的啟用函式是tanh(-1, 1) --->(0, 2) ---->(0, 255) single_picture = (np.reshape(single_picture, (28, 28)) + 1) * 177.5 # 儲存圖片 cv2.imwrite("gen_pictures/A{}.jpg".format(str(epoch)), single_picture) samples.append(gen_samples) # 儲存模型 saver.save(sess, "./checkpoints/generator.ckpt") # 將生成的圖片結果寫入檔案 with open("train_samples.pkl", "wb") as f: pickle.dump(samples, f)