1. 程式人生 > >本地匯入Mnist的資料集的方法

本地匯入Mnist的資料集的方法

 完整程式碼的下載路徑:https://download.csdn.net/download/lxiao428/10714886

很多人在介紹Mnist資料集的時候都是通過庫在網上下載,我以前也是這麼做的,但是今天發現遠端伺服器關閉連線了,而我本地又有這個Mnist資料集,我就想怎麼講訓練資料和測試資料匯入到我的程式碼訓練中,網上找了好久都沒有辦法,so,搜腸刮肚找到的這個辦法。

#載入Mnist資料集
from keras.datasets import mnist
import gzip
import os
import numpy

local_file = "F:\python\DeepLearning"

#(train_images, train_labels),(test_images, test_labels) = mnist.load_data()
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'  #訓練集影象的檔名
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'  #訓練集label的檔名
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'    #測試集影象的檔名
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'    #測試集label的檔名

#主要是下面的兩個函式實現的:

def extract_images(filename):

def extract_labels(filename, one_hot=False):


train_images = extract_images(os.path.join(local_file,TRAIN_IMAGES))
train_labels = extract_labels(os.path.join(local_file,TRAIN_LABELS))
test_images = extract_images(os.path.join(local_file,TEST_IMAGES))
test_labels = extract_labels(os.path.join(local_file,TEST_LABELS))

#網路架構
'''
神經網路的核心元件是layer,它是一種資料處理模組,可以看成是資料過濾器。
'''
from keras import models
from keras import layers
network = models.Sequential()
network.add(layers.Dense(512, activation='relu',input_shape=(28*28,)))
network.add(layers.Dense(10, activation='softmax'))

#編譯步驟
'''
要想訓練網路,需要選擇變非同步驟的三個引數:
(1)損失函式(loss):衡量網路在訓練資料集上的效能;
(2)優化器(optimizer):基於訓練資料和損失函式更新網路的機制;
(3)訓練和測試中的監控指標(metric):如精度
'''
network.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])

#資料預處理
train_images = train_images.reshape((60000, 28*28))
train_images = train_images.astype('float32')/255

test_images = test_images.reshape((10000, 28*28))
test_images = test_images.astype('float32')/255

#準備標籤
from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

#訓練網路
network.fit(train_images, train_labels, epochs = 5, batch_size = 256)

#效能評估
train_loss, train_acc = network.evaluate(test_images, test_labels)
print('test_acc:', train_acc)
print('test_error:', train_loss)