Load MNIST Data Set
載入60000筆訓練數據與10000筆測試數據。1
(train_feature, train_label), (test_feature, test_label) = mnist.load_data()
Data Preprocessing
Reshape
將28x28特徵值Raw Data(圖片)轉換為32位元浮點數一維數據。1
2train_feature_vector = train_feature.reshape(len(train_feature), 784).astype('float32')
test_feature_vector = test_feature.reshape(len(test_feature), 784).astype('float32')
Feature Normalization
對特徵值進行正規化處理,也就是將數據按比例縮放至[0, 1]區間,且不改變其原始分佈,以收斂速度與預測精準度。1
2train_feature_normal = train_feature_vector / 255
test_feature_normal = test_feature_vector / 255
One-Hot Encoding
對離散型資料標籤進行獨熱編碼處理轉換為布林陣列,便於進行矩陣運算。1
2train_label_onehot = np_utils.to_categorical(train_label)
test_label_onehot = np_utils.to_categorical(test_label)
Model Definition
定義循序模型之結構、訓練方法、準確率評估1
model = Sequential()
Layer Definition
定義輸入層、隱藏層、輸出層 :
- Units : 784 -> 256 -> 10
- 常態分佈亂數初始化weight&bias
- 隱藏層活化函數使用ReLU
- 輸出層活化函數使用Softmax
1 | model.add( Dense(units=256, input_dim=784, init='normal', activation='relu') ) |
Training Definition
定義訓練方法 :
- 損失函數為 CrossEntropy Loss
- 優化器使用 Adam
- 驗證數據分割比例為0.2(將6萬筆訓練數據進一步分割為4.8萬筆訓練數據和1.2萬筆驗證數據)
- 訓練週期(epoch)為10
- 每批次樣本數為200(因此一個訓練週期為4.8萬/200=240批次)
1 | model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) |
Accuracy Evaluation
1 | accuracy = model.evaluate(test_feature_normal, test_label_onehot) |
Save & Load Model
1 | # save |
Full Code
1 | #!/usr/bin/env python3 |
Test Your Own Handwritten Numbers Image
為了讓訓練好的模型預測看看資料集以外的圖片,我用FireAlpaca「手寫」了10張28x28的數字圖片😆,並將圖片命名為「真實數字_圖片順序編碼.jpg」這樣的格式,例如「8_image2.jpg」代表這張圖片為我製作的第2張圖片,內容為數字8,這樣的命名規則是為了方便讀取圖片時能從檔名擷取其label。
import blob & opencv
1 | from glob import glob |
P.S. 在VS Code中若只寫「import cv2」的話會報錯…
data preprocessing
1 | def get_test_process(files): |
Prediction
1 | model = load_model("mdl_mlp_mnist.h5") |
Result
哎呀,其中一張數字8的圖片預測錯誤😂
和數據集的圖片比較起來,我的手寫圖片經過影像處理完筆跡變得超細,或許特徵相對不那麼明顯吧,把原圖多點幾個像素上去再預測一次就過了呢。
*測試程式指定隨機亂數種子是為了再現性
*下載MNIST數據集時若發生 ssl.SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed,在Terminal 執行以下命令 :1
/Applications/Python\ 3.6/Install\ Certificates.command