TensorFlow 2.0 Study Notes: Data Augmentation and Resuming Training from Checkpoints
By 阿新 • Published 2021-01-08
Tags: notes
Data Augmentation
On small datasets, data augmentation can noticeably improve a model. This example uses MNIST, where accuracy alone is not enough to demonstrate the benefit of augmentation; its real value is something you come to appreciate in practical applications.
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# Add a channel dimension: reshape (60000, 28, 28) to (60000, 28, 28, 1)
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)

image_gen_train = ImageDataGenerator(
    rescale=1. / 1.,        # every pixel is multiplied by this value; for images, 1./255. normalizes to 0~1
    rotation_range=45,      # random rotation range: up to 45 degrees
    width_shift_range=.15,  # random horizontal shift, as a fraction of width
    height_shift_range=.15, # random vertical shift, as a fraction of height
    horizontal_flip=False,  # whether to randomly flip horizontally
    zoom_range=0.5          # random zoom range: scale images by up to 50%
)
image_gen_train.fit(x_train)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# flow() shuffles x_train and y_train together, so the pairing between
# each image and its label is preserved
model.fit(image_gen_train.flow(x_train, y_train, batch_size=32),
          epochs=5,
          validation_data=(x_test, y_test),
          validation_freq=1)
model.summary()
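To build intuition for what a single augmentation does, here is a minimal NumPy-only sketch of a random horizontal shift, in the spirit of `width_shift_range` above. The helper `random_width_shift` is my own illustration, not part of Keras or the original post:

```python
import numpy as np

def random_width_shift(img, max_frac=0.15, rng=None):
    """Shift a (H, W) image horizontally by a random fraction of its width,
    filling the exposed columns with zeros. Hypothetical helper that mimics
    the spirit of ImageDataGenerator's width_shift_range."""
    rng = rng if rng is not None else np.random.default_rng(0)
    h, w = img.shape
    max_px = int(max_frac * w)
    shift = int(rng.integers(-max_px, max_px + 1))
    out = np.zeros_like(img)
    if shift >= 0:
        out[:, shift:] = img[:, :w - shift]
    else:
        out[:, :w + shift] = img[:, -shift:]
    return out

img = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)
shifted = random_width_shift(img)
print(shifted.shape)  # shape is preserved: (28, 28)
```

The label is untouched by such transforms, which is why `flow()` can shuffle images and labels together without breaking their pairing.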
Resuming Training from Checkpoints
Checkpointing lets training pick up where a previous run left off. The code is as follows:
import tensorflow as tf
import os

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)  # restore previously saved weights

# Callback that saves the model during training
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,  # path to save to
    save_weights_only=True,         # save only the weights, not the full model
    save_best_only=True)            # keep only the best result

history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()
After a successful run, a checkpoint folder is created in the project directory.
Running the program again produces output like the following.
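The existence test in the code above relies on the files TensorFlow writes for a weights-only checkpoint: `model.save_weights(path)` produces `path + '.index'` alongside the data shards, so checking for the `.index` file is the usual way to detect a previous run. A minimal sketch (the helper name `checkpoint_exists` is my own, not from the post):

```python
import os
import tempfile

def checkpoint_exists(ckpt_path):
    """model.save_weights(ckpt_path) writes ckpt_path + '.index' next to the
    weight data shards, so the presence of the .index file signals that a
    checkpoint can be restored with model.load_weights(ckpt_path)."""
    return os.path.exists(ckpt_path + ".index")

# Demonstration with an empty stand-in file instead of a real checkpoint.
with tempfile.TemporaryDirectory() as d:
    ckpt = os.path.join(d, "mnist.ckpt")
    print(checkpoint_exists(ckpt))   # False: nothing saved yet
    open(ckpt + ".index", "w").close()
    print(checkpoint_exists(ckpt))   # True: index file present
```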
G:\anaconda\envs\tensorflow-2.0\python.exe G:/Pycharmprojects/tf2_notes-master/class4/MNIST_FC/p16_mnist_train_ex3.py
2021-01-07 17:16:46.259132: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll
-------------load the model-----------------
......
# Because training resumes from an already-trained model, accuracy first dips slightly and then climbs back up
   32/60000 [..............................] - ETA: 47:16 - loss: 0.0069 - sparse_categorical_accuracy: 1.0000
  352/60000 [..............................] - ETA: 4:25 - loss: 0.0603 - sparse_categorical_accuracy: 0.9830
  672/60000 [..............................] - ETA: 2:23 - loss: 0.0491 - sparse_categorical_accuracy: 0.9851
  992/60000 [..............................] - ETA: 1:39 - loss: 0.0422 - sparse_categorical_accuracy: 0.9879
 1280/60000 [..............................] - ETA: 1:19 - loss: 0.0395 - sparse_categorical_accuracy: 0.9883
 1568/60000 [..............................] - ETA: 1:06 - loss: 0.0434 - sparse_categorical_accuracy: 0.9872
 1888/60000 [..............................] - ETA: 56s - loss: 0.0443 - sparse_categorical_accuracy: 0.9857