Tensorflow Strides引数


115

tf.nn.avg_pool、tf.nn.max_pool、tf.nn.conv2dのストライド引数を理解しようとしています。

ドキュメントは繰り返し言います

strides:長さが4以上の整数のリスト。入力テンソルの各次元のスライディングウィンドウのストライド。

私の質問は:

  1. 4以上の整数のそれぞれは何を表していますか?
  2. なぜconvnetに対してstrides [0] = strides [3] = 1にする必要があるのですか?
  3. 、この例で、私たちは見ますtf.reshape(_X,shape=[-1, 28, 28, 1])。なぜ-1?

悲しいことに、-1を使用して再形成するためのドキュメントの例は、このシナリオにうまく変換できません。

回答:


224

プーリングと畳み込み演算は、入力テンソル全体に「ウィンドウ」をスライドさせます。tf.nn.conv2d例として使用:入力テンソルに4つの次元がある場合: [batch, height, width, channels]たたみ込みは、次元の2Dウィンドウで動作しますheight, width

strides各次元でウィンドウがどれだけシフトするかを決定します。通常の使用では、最初(バッチ)と最後(深度)のストライドを1に設定します。

非常に具体的な例を使用してみましょう:32x32グレースケール入力画像に対して2次元畳み込みを実行します。入力画像はdepth = 1なので、グレースケールと言います。これは単純さを保つのに役立ちます。その画像を次のようにします。

00 01 02 03 04 ...
10 11 12 13 14 ...
20 21 22 23 24 ...
30 31 32 33 34 ...
...

1つの例(バッチサイズ= 1)に対して2x2のたたみ込みウィンドウを実行してみましょう。コンボリューションの出力チャネル深度を8にします。

畳み込みへの入力にはがありshape=[1, 32, 32, 1]ます。

で指定strides=[1,1,1,1]した場合padding=SAME、フィルターの出力は[1、32、32、8]になります。

フィルターは最初に次の出力を作成します。

F(00 01
  10 11)

そして次に:

F(01 02
  11 12)

等々。次に、2行目に移動して計算します。

F(10, 11
  20, 21)

その後

F(11, 12
  21, 22)

ストライドを[1、2、2、1]に指定すると、ウィンドウのオーバーラップは行われません。それは計算します:

F(00, 01
  10, 11)

その後

F(02, 03
  12, 13)

ストライドは、プーリングオペレーターと同様に動作します。

質問2:なぜconvnetsの[1、x、y、1]をストライドするか

最初の1つはバッチです。通常、バッチ内の例をスキップしたくないか、最初にそれらを含めるべきではありませんでした。:)

最後の1つは畳み込みの深さです。同じ理由で、通常は入力をスキップしたくありません。

conv2d演算子はより一般的です。そのため、ウィンドウを他の次元に沿ってスライドさせる畳み込みを作成できます、これはconvnetsの一般的な使用法ではありません。典型的な用途は、それらを空間的に使用することです。

-1に再形成する理由 -1は、「完全なテンソルに必要なサイズに一致するように必要に応じて調整する」というプレースホルダーです。これは、コードを入力バッチサイズから独立させるための方法です。これにより、パイプラインを変更でき、コード内のすべての場所でバッチサイズを調整する必要がなくなります。


5
@derek(テキストから)「コンボリューションに8の出力チャネル深度を与える」ため。これは、畳み込みを設定するときに選ぶことができるものだ、と回答は8を選んだ
etarion

17

入力は4次元であり、形式は次のとおりです。 [batch_size, image_rows, image_cols, number_of_colors]

ストライドは、一般に、操作を適用する間の重複を定義します。conv2dの場合、これは、たたみ込みフィルターの連続するアプリケーション間の距離を指定します。特定の次元の値1は、すべての行/列に演算子を適用することを意味し、値2は毎秒を意味する、というようになります。

Re 1)たたみ込みで重要な値は2番目と3番目で、行と列に沿ったたたみ込みフィルターの適用における重複を表します。[1、2、2、1]の値は、2番目の行と列ごとにフィルターを適用することを示しています。

Re 2)技術的な制限は知りません(CuDNNの要件である可能性があります)が、通常、行または列の次元に沿ってストライドを使用します。バッチサイズを超えることは必ずしも意味がありません。最後の次元がわからない。

Re 3)いずれかの次元に-1を設定することは、「テンソルの要素の総数が変更されないように最初の次元の値を設定する」ことを意味します。この場合、-1はbatch_sizeと等しくなります。


11

1次元の場合にストライドが行うことから始めましょう。

あなたinput = [1, 0, 2, 3, 0, 1, 1]kernel = [2, 1, 3]、たたみ込みの結果がであると仮定しましょう[8, 11, 7, 9, 4]。これは、カーネルを入力にスライドさせ、要素ごとの乗算を実行し、すべてを合計することによって計算されます。このように

  • 8 = 1 * 2 + 0 * 1 + 2 * 3
  • 11 = 0 * 2 + 2 * 1 + 3 * 3
  • 7 = 2 * 2 + 3 * 1 + 0 * 3
  • 9 = 3 * 2 + 0 * 1 + 1 * 3
  • 4 = 0 * 2 + 1 * 1 + 1 * 3

ここでは1要素ずつスライドしますが、他の数値を使用しても何も問題はありません。この数はあなたの歩幅です。s番目の結果をすべて取得するだけで、1ストリーデッドコンボリューションの結果をダウンサンプリングすると考えることができます。

入力サイズi、カーネルサイズk、ストライドsおよびパディングpがわかれば、たたみ込みの出力サイズは次のように簡単に計算できます。

ここに画像の説明を入力してください

ここ|| オペレーターは天井操作を意味します。プーリング層の場合、s = 1。


N-dimケース。

1 dimケースの数学を知ると、各dimが独立していることがわかると、n dimケースは簡単です。つまり、各次元を個別にスライドさせるだけです。以下は2-dの例です。すべての次元で同じストライドである必要はないことに注意してください。したがって、N-dim入力/カーネルの場合、Nストライドを提供する必要があります。


これで、すべての質問に簡単に答えることができます。

  1. 4以上の整数のそれぞれは何を表していますか?conv2dpoolは、このリストが各次元間のストライドを表すことを示しています。ストライドリストの長さがカーネルテンソルのランクと同じであることに注意してください。
  2. なぜconnetのストライド[0] =ストライド3 = 1が必要なのですか?。最初のディメンションはバッチサイズ、最後のディメンションはチャネルです。バッチもチャネルもスキップする意味はありません。だからあなたはそれらを1にします。幅/高さについては何かをスキップすることができ、それが1ではないかもしれない理由です。
  3. tf.reshape(_X、shape = [-1、28、28、1])。なぜ-1? tf.reshapeはそれをカバーしています:

    形状の1つのコンポーネントが特別な値-1の場合、そのサイズは合計サイズが一定になるように計算されます。特に、[-1]の形状は1-Dに平坦化されます。シェイプの最大1つのコンポーネントは-1です。


2

@dgaは説明する素晴らしい仕事をしてくれました、そしてそれがどれほど役に立ったかについて私は十分に感謝することができません。同様に、私はどのように私の発見を共有したいと思いますstride 3D畳み込み機能するます。

conv3dのTensorFlowドキュメントによると、入力の形状は次の順序である必要があります。

[batch, in_depth, in_height, in_width, in_channels]

右端から左端までの変数を例を使って説明しましょう。入力形状が input_shape = [1000,16,112,112,3]

input_shape[4] is the number of colour channels (RGB or whichever format it is extracted in)
input_shape[3] is the width of the image
input_shape[2] is the height of the image
input_shape[1] is the number of frames that have been lumped into 1 complete data
input_shape[0] is the number of lumped frames of images we have.

以下は、ストライドの使用方法に関する要約ドキュメントです。

strides:長さが5以上の整数のリスト。長さ5の1次元テンソル。入力の各次元のスライディングウィンドウのストライド。持つ必要がありますstrides[0] = strides[4] = 1

多くの作品で示されているように、ストライドは、ウィンドウまたはカーネルがデータフレームまたはピクセルであっても、最も近い要素からジャンプする距離を意味します(これは、言い方を変えます)。

上記のドキュメントから、3Dのストライドは次のようになります=(1、XYZ、1)のようになります。

ドキュメントはそれを強調していstrides[0] = strides[4] = 1ます。

strides[0]=1 means that we do not want to skip any data in the batch 
strides[4]=1 means that we do not want to skip in the channel 

strides [X]は、集中フレームで行うべきスキップの数を意味します。たとえば、16個のフレームがある場合、X = 1はすべてのフレームを使用することを意味します。X = 2は、1秒ごとのフレームを使用することを意味し、それが継続します

strides [y]とstrides [z]は、@ dgaの説明に従いますので、その部分はやり直しません。

ただし、kerasでは、3つの整数のタプル/リストを指定するだけでよく、各空間次元に沿った畳み込みのストライドを指定します。空間次元はstride [x]、strides [y]およびstrides [z]です。strides [0]とstrides [4]はすでにデフォルトで1になっています。

私は誰かがこれが役に立ったと思います!

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.