end0tknr's kipple - web写経開発

太宰府天満宮の狛犬って、妙にカワイイ

MNISTデータによる手書き数字「0~9」の文字認識 (deep learning & python)

で、先程のエントリに関連して、MNISTデータによる手書き数字「0~9」の文字認識。 というより、これまでと同様の写経。

#!python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

def main():
    np.random.seed(20170409)

    # MNSIST dataのdownload
    mnist = input_data.read_data_sets("tmp/data/", one_hot=True)


x = tf.placeholder(tf.float32, [None, 784])
    w = tf.Variable(tf.zeros([784, 10]))
    w0 = tf.Variable(tf.zeros([10]))
    f = tf.matmul(x, w) + w0
    p = tf.nn.softmax(f)


    t = tf.placeholder(tf.float32, [None, 10])
    # loss: 誤差関数
    loss = -tf.reduce_sum(t * tf.log(p))
    # train_step: トレーニングアルゴリズム
    train_step = tf.train.AdamOptimizer().minimize(loss)
    # correct_prediction: 予測値と正解値を比較し、正解or notを格納した配列
    # ※1
    correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
    # 配列である correct_prediction より、正解率を算出
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    
    sess = tf.InteractiveSession()
#    sess.run(tf.initialize_all_variables()) # for tensorflow ver0.1
    sess.run( tf.global_variables_initializer() )

    i = 0
    for _ in range(2000):
        i += 1
        batch_xs, batch_ts = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, t: batch_ts})
        if i % 100 == 0:
            loss_val, acc_val = sess.run([loss, accuracy],
                feed_dict={x:mnist.test.images, t: mnist.test.labels})
            print ('Step: %d, Loss: %f, Accuracy: %f'
                   % (i, loss_val, acc_val))



if __name__ == '__main__':
    main()

↑こう書くと↓こう表示されます

$ python foo_2_4.py 
Extracting tmp/data/train-images-idx3-ubyte.gz
Extracting tmp/data/train-labels-idx1-ubyte.gz
Extracting tmp/data/t10k-images-idx3-ubyte.gz
Extracting tmp/data/t10k-labels-idx1-ubyte.gz
Step: 100, Loss: 7747.077637, Accuracy: 0.848400
Step: 200, Loss: 5439.363281, Accuracy: 0.879900
Step: 300, Loss: 4556.467773, Accuracy: 0.890900
  :
Step: 2000, Loss: 2848.940674, Accuracy: 0.922500
$ 

tf.equal(tf.argmax(p, 1), tf.argmax(t, 1)) の考え方

前回のエントリにもあるように、 正解データであるTのn行目データは、l(エル)番目のみ"1"が登録されています。 (例:"7"の画像である場合、7番目に"1"が登録)

 \displaystyle \large
T_n = 
    \begin{pmatrix}
      t_{1n}  &t_{2n} &\ldots &t_{Kn}
    \end{pmatrix}

予測関数であるPのn行目データは、P1~PKが確率である0~1の値を取ります。 例えば、"7"の画像である可能性が高い場合、P7が1に最も近い値となります。

 \displaystyle \large
P_n = 
    \begin{pmatrix}
      P_1 (x_n)  &P_2 (x_n) &\ldots  &P_K (x_n)
    \end{pmatrix}

tf.argmax()は与えられた配列の中で、最も大きな値を持つ インデックス(配列番号)を返す関数ですので、 tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))とすることで正解 or not を 評価しています。

このように、全データのうち一部を取出しながら、最適化するトレーニングを「ミニバッチ」と呼ぶようです。