
TensorFlow 2.0 Learning Notes: Data Augmentation and Resuming Training from Checkpoints

Tags: notes

Data Augmentation

On small datasets, data augmentation can make a clear difference. This example uses the MNIST dataset, so accuracy alone is not enough to demonstrate the effect of augmentation; you will see its real value when you apply it in practice.

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # add a channel dimension: reshape from (60000, 28, 28) to (60000, 28, 28, 1)

image_gen_train = ImageDataGenerator(
    rescale=1. / 1.,  # every pixel is multiplied by this value; the data is already scaled to 0~1, so 1./1. here (use 1./255 for raw images)
    rotation_range=45,  # random rotation range, here up to 45 degrees
    width_shift_range=.15,  # random horizontal shift (fraction of the width)
    height_shift_range=.15,  # random vertical shift (fraction of the height)
    horizontal_flip=False,  # whether to apply random horizontal flips
    zoom_range=0.5  # random zoom range, here up to 50%
)
image_gen_train.fit(x_train)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(image_gen_train.flow(x_train, y_train, batch_size=32),
          # flow() shuffles x_train and y_train together; each image stays paired with its label
          epochs=5, validation_data=(x_test, y_test),
          validation_freq=1)
model.summary()
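
To see what the generator actually yields, here is a minimal sketch (my own addition, not part of the original notes; the variable names x_batch and y_batch are made up) that pulls one augmented batch from flow() and inspects its shape:

# Sketch: draw a single augmented batch from the generator and check the shapes.
x_batch, y_batch = next(image_gen_train.flow(x_train, y_train, batch_size=32))
print(x_batch.shape)  # (32, 28, 28, 1): augmented images, channel dimension preserved
print(y_batch.shape)  # (32,): labels remain aligned with their augmented images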


Resuming Training from a Checkpoint

Resuming from a checkpoint lets training continue from a previously saved model. The code is as follows:

import tensorflow as tf
import os

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    # load the previously saved weights
    model.load_weights(checkpoint_save_path)
# callback that saves the model during training
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,  # path the checkpoint is written to
                                                 save_weights_only=True,  # save only the weights, not the full model
                                                 save_best_only=True)     # keep only the checkpoint that performs best on the validation data

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()
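
To confirm that load_weights really restored the earlier training state, a quick check (my own addition, not part of the original notes) is to evaluate the model right after loading, before calling fit; a restored model should already score well above chance on the test set:

# Sanity check (assumed addition): run immediately after model.load_weights(...).
# A freshly initialized model scores around 0.10 on MNIST; a restored one should not.
loss, acc = model.evaluate(x_test, y_test, verbose=0)
print('restored test accuracy: {:.4f}'.format(acc))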

After the program runs successfully, a checkpoint folder is generated in the project directory.
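The folder contains a small checkpoint index file plus the .index and .data files that hold the weights, which is why the code above tests for checkpoint_save_path + '.index'. A minimal way to look at it (a sketch, assuming the default path used above):

import os
print(os.listdir('./checkpoint'))
# typically something like: ['checkpoint', 'mnist.ckpt.data-00000-of-00001', 'mnist.ckpt.index']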
Running the program again produces output like the following.

G:\anaconda\envs\tensorflow-2.0\python.exe G:/Pycharmprojects/tf2_notes-master/class4/MNIST_FC/p16_mnist_train_ex3.py
2021-01-07 17:16:46.259132: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll
-------------load the model-----------------
......
#  Because the model was already trained, the accuracy first dips slightly and then climbs back up
   32/60000 [..............................] - ETA: 47:16 - loss: 0.0069 - sparse_categorical_accuracy: 1.0000
  352/60000 [..............................] - ETA: 4:25 - loss: 0.0603 - sparse_categorical_accuracy: 0.9830 
  672/60000 [..............................] - ETA: 2:23 - loss: 0.0491 - sparse_categorical_accuracy: 0.9851
  992/60000 [..............................] - ETA: 1:39 - loss: 0.0422 - sparse_categorical_accuracy: 0.9879
 1280/60000 [..............................] - ETA: 1:19 - loss: 0.0395 - sparse_categorical_accuracy: 0.9883
 1568/60000 [..............................] - ETA: 1:06 - loss: 0.0434 - sparse_categorical_accuracy: 0.9872
 1888/60000 [..............................] - ETA: 56s - loss: 0.0443 - sparse_categorical_accuracy: 0.9857