Deep Learning

【やってみた】ゼロから作るDeep Learning① (3.6)

はじめに

前回の記事

【やってみた】ゼロから作るDeep Learning① (3.5)はじめに 前回の記事 https://shotslog.com/try-zerotsuku1-3-1-3-4 この記事は、...

この記事は、ゼロから作るDeep Learning(以下ゼロつく)のアウトプット学習用に書いています。
今回は3章の6を扱います。
テーマは手書き数字認識についてです。

同じように機械学習・ディープラーニングを学習している方にもわかりやすいように書きたいと思います。
理解を深めるため、ぜひ書籍と併せて読んでいただければと思います。

手書き数字認識

これまでニューラルネットワークの仕組みについて触れてきたので、具体的な問題に取り組みます。
手書き数字の画像に対して分類し、それが正しいか推論します。

MNISTデータセット

今回使うのはMNISTというデータセットです。
MNISTは有名な手書き数字の画像セットで、多くの研究や論文で使われています。
0から9までの数字画像から構成されていて、
訓練画像が60,000枚、テスト画像が10,000枚入っています。

それでは、ゼロつくが提供しているソースコード
https://github.com/oreilly-japan/deep-learning-from-scratch/tree/master/ch03
からデータを読み込んでいきます。

import sys, os
sys.path.append('../deep-learning-from-scratch-master')
sys.path.append(os.pardir)
from dataset.mnist import load_mnist

↑のコードでデータを読み込むことができます。
sysとosモジュールをimportすることで、ファイルパスを設定することができます。
sys.path.appendで他の階層にあるデータを持ってくることができます。

次に、試しに以下のコードでデータを読み込めているか確認します。

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

初回ダウンロード時にはかなり時間がかかります。
私の環境では5分ほど要しました。
HTTP Error 503: Service Unavailableというエラーコードが出ることがありますが、こちらはコードが間違っているわけではなく、サーバー上の問題です。
エラーが出てしまった場合は、時間をおいてから再度試すか、直接URLからダウンロードすることになります。

それではデータの確認も兼ねて、以下のコードで画像を表示してみます。

import numpy as np
from PIL import Image

def img_show(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
# flatten=Trueで1次元に画像を格納し、
   normalize=Falseで正規化しない設定をします。Trueの場合、0.0〜1.0の値に正規化します。

img = x_train[0]
label = t_train[0]
print(label) 
# 出力
5

print(img.shape) 
# 出力
(784,)

img = img.reshape(28, 28)  # 形状を元の画像サイズに変形

print(img.shape) 
# 出力
(28, 28)

img_show(img)

すると以下のような画像が表示されます。

ニューラルネットワークの推論処理

MNISTデータセットに対して、推論処理を行うニューラルネットワークを実装します。
入力層は784個(28×28)、出力層は10個(数字のクラス数)で構成されます。
隠れ層は2つで、1つ目は50個、2つ目は100個です。隠れ層の値は任意の数字に設定できます。

3つの関数を定義します。

import pickle
from common.functions import sigmoid, softmax

# データの取得
def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test

# pickleファイルに保存された学習済み重みパラメータの読み込み
def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network

# 手書き数字から正解を予測
def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y

最後に、正解率の評価をします。

x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
    y = predict(network, x[i])
    p= np.argmax(y) # 最も確率の高い要素のインデックスを取得
    if p == t[i]:
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

# 出力
Accuracy:0.9352

正解率は0.9352となりました。

バッチ処理

データ全体を一定数ずつに分割し、それぞれ順番に処理する方法をバッチ処理と呼びます。
バッチとは、まとまった入力データのことです。
バッチ処理にすることで、処理時間を大幅に短縮し、効率良く処理することができます。
実装してみます。

def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


def predict(network, x):
    w1, w2, w3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, w1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, w2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, w3) + b3
    y = softmax(a3)

    return y

x, t = get_data()
network = init_network()

batch_size = 100 # バッチの数
accuracy_cnt = 0

# for文で画像100枚ずつバッチとして取り出す
for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]
    y_batch = predict(network, x_batch)
    p = np.argmax(y_batch, axis=1)
    accuracy_cnt += np.sum(p == t[i:i+batch_size])

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

# 出力
Accuracy:0.9352

argmax関数の引数に1を指定し、1次元目の最大値を取得しています。
また、先ほどの実装をバッチ処理にしているだけなので、正解率も同様となっています。
それほど重い処理をしているわけではありませんが、
バッチ処理なしの場合は4.9秒だった処理が、2.0秒に短縮されています。