単純なロジスティック回帰モデルは、MNISTで92%の分類精度をどのように実現しますか?


68

MNISTデータセット内のすべての画像は、同じスケールで中央に配置され、回転せずに表向きになっていますが、それらには大きな手書きのばらつきがあり、線形モデルがこのような高い分類精度をどのように実現するのか困惑しています。

私が視覚化できる限り、手書きの大きな変動を考えると、数字は784次元空間で線形に分離できないはずです。つまり、異なる数字を分離する少し複雑な(それほど複雑ではない)非線形境界があるはずです。 、正のクラスと負のクラスを線形分類器で分離できないというよく引用されたXOR例に似ています。マルチクラスロジスティック回帰が、完全に線形の特徴(多項式の特徴はない)でどのように高い精度を実現するのか、私には戸惑うようです。

例として、画像内の任意のピクセルが与えられた場合、数字23異なる手書きのバリエーションにより、そのピクセルを照らしたり、しなかったりすることができます。したがって、学習された重みのセットを使用して、各ピクセルは数字を2および3ように見せることができます。ピクセル値の組み合わせによってのみ、数字が23あるかを判断できます。これは、ほとんどの桁ペアに当てはまります。そのため、ロジスティック回帰は、ピクセル間の依存関係をまったく考慮せずに、盲目的にすべてのピクセル値に依存せずに決定を下し、そのような高い精度を達成できます。

どこか間違っているか、画像のばらつきを過大評価しているだけです。ただし、数字がどのように「ほぼ」直線的に分離できるかについての直感で誰かが私を助けることができれば素晴らしいことです。



私は興味がありました:ペナルティ付き線形モデル(つまり、glmnet)のようなものが問題に対してどれほどうまくやっているのでしょうか?思い出すと、あなたが報告しているのは、罰せられないサンプル外の精度です。
クリフAB

回答:


86

tl; drこれは画像分類データセットですが、入力から予測への直接マッピングを簡単に見つけることができる非常に簡単なタスクのままです。


回答:

これは非常に興味深い質問であり、ロジスティック回帰の単純さのおかげで、実際に答えを見つけることができます。

78478428×28

繰り返しますが、これらは重みです。

次に、上の画像を見て、最初の2桁(つまり0と1)に焦点を合わせます。青の重みは、このピクセルの強度がそのクラスに大きく寄与することを意味し、赤の値は負に寄与することを意味します。

0

1

2378

これにより、ロジスティック回帰は多くの画像を正しく取得できる可能性が非常に高いため、スコアが非常に高いことがわかります。


上記の図を再現するためのコードは少し古いですが、ここで説明します。

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

12
2378

13
もちろん、MNISTのサンプルは、分類器がサンプルを見る前に中央揃え、スケーリング、およびコントラスト正規化するのに役立ちます。「ゼロのエッジが実際にボックスの真ん中を通過したらどうなるのか」といった質問に対処する必要はありません。プリプロセッサはすでにすべてのゼロを同じに見えるようにするために長い道のりを進んでいるからです。
ホッブズ

1
@EricDuminilあなたの提案でスクリプトに称賛を追加しました。入力いただきありがとうございます!:D
Djib2011

1
@NitishAgarwal、この回答があなたの質問に対する回答であると思われる場合は、そのようにマークすることを検討してください。
シンタックス

11
この種の処理に興味はあるが特に慣れていない人にとっては、この答えはメカニズムの素晴らしい直感的な例を提供します。
クリリス
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.