で、先程のエントリに関連して、MNISTデータによる手書き数字「0~9」の文字認識。 というより、これまでと同様の写経。
#!python # -*- coding: utf-8 -*- import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data def main(): np.random.seed(20170409) # MNSIST dataのdownload mnist = input_data.read_data_sets("tmp/data/", one_hot=True) x = tf.placeholder(tf.float32, [None, 784]) w = tf.Variable(tf.zeros([784, 10])) w0 = tf.Variable(tf.zeros([10])) f = tf.matmul(x, w) + w0 p = tf.nn.softmax(f) t = tf.placeholder(tf.float32, [None, 10]) # loss: 誤差関数 loss = -tf.reduce_sum(t * tf.log(p)) # train_step: トレーニングアルゴリズム train_step = tf.train.AdamOptimizer().minimize(loss) # correct_prediction: 予測値と正解値を比較し、正解or notを格納した配列 # ※1 correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1)) # 配列である correct_prediction より、正解率を算出 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) sess = tf.InteractiveSession() # sess.run(tf.initialize_all_variables()) # for tensorflow ver0.1 sess.run( tf.global_variables_initializer() ) i = 0 for _ in range(2000): i += 1 batch_xs, batch_ts = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, t: batch_ts}) if i % 100 == 0: loss_val, acc_val = sess.run([loss, accuracy], feed_dict={x:mnist.test.images, t: mnist.test.labels}) print ('Step: %d, Loss: %f, Accuracy: %f' % (i, loss_val, acc_val)) if __name__ == '__main__': main()
↑こう書くと↓こう表示されます
$ python foo_2_4.py Extracting tmp/data/train-images-idx3-ubyte.gz Extracting tmp/data/train-labels-idx1-ubyte.gz Extracting tmp/data/t10k-images-idx3-ubyte.gz Extracting tmp/data/t10k-labels-idx1-ubyte.gz Step: 100, Loss: 7747.077637, Accuracy: 0.848400 Step: 200, Loss: 5439.363281, Accuracy: 0.879900 Step: 300, Loss: 4556.467773, Accuracy: 0.890900 : Step: 2000, Loss: 2848.940674, Accuracy: 0.922500 $
tf.equal(tf.argmax(p, 1), tf.argmax(t, 1)) の考え方
前回のエントリにもあるように、 正解データであるTのn行目データは、l(エル)番目のみ"1"が登録されています。 (例:"7"の画像である場合、7番目に"1"が登録)
予測関数であるPのn行目データは、P1~PKが確率である0~1の値を取ります。 例えば、"7"の画像である可能性が高い場合、P7が1に最も近い値となります。
tf.argmax()は与えられた配列の中で、最も大きな値を持つ インデックス(配列番号)を返す関数ですので、 tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))とすることで正解 or not を 評価しています。
このように、全データのうち一部を取出しながら、最適化するトレーニングを「ミニバッチ」と呼ぶようです。