tf.nn.dynamic_rnn()の出力は何ですか?


8

私は公式文書から私が何を理解しているかについて確信がありません、それは言う:

戻り値:ペア(出力、状態)ここで:

outputs:RNN出力テンソル。

time_major == False(デフォルト)の場合、これはTensorシェイプになります: [batch_size, max_time, cell.output_size]

の場合time_major == True、これはTensorシェイプになります[max_time, batch_size, cell.output_size]

場合注は、cell.output_size整数またはTensorShapeオブジェクトの(おそらくネスト)タプルは、次に、出力タプルが、cell.output_sizeと同じ構造を有するにおける形状データに対応する形状を有するテンソルを含むであろうcell.output_size

state:最終状態。cell.state_sizeがintの場合、これはShapedになります[batch_size, cell.state_size]。TensorShapeの場合、これは整形され[batch_size] + cell.state_sizeます。それが(おそらくネストされた)intまたはTensorShapeのタプルである場合、これは対応する形状を持つタプルになります。セルがLSTMCellsの場合、状態は各セルのLSTMStateTupleを含むタプルになります。

であるoutput[-1]は常に(RNN、GRU、LSTMすなわち3つのすべての細胞型において)状態に(リターンタプルの2番目の要素)を等しく?どこにでもある文献は、隠された状態という用語の使用においては自由主義的すぎると思います。3つすべてのセルの非表示状態がスコアになりますか?

回答:


10

はい、セル出力は非表示状態と同じです。LSTMの場合、次の図に示すように、これはタプル(の2番目の要素LSTMStateTuple)の短期間の部分です。

LSTM

ただしtf.nn.dynamic_rnn、の場合、シーケンスが短い(引数)と、返される状態が異なる場合がありますsequence_length。この例を見てください:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print(outputs_val)
  print()
  print(states_val)

ここで、入力バッチには4つのシーケンスが含まれており、そのうちの1つは短く、ゼロが埋め込まれています。実行すると、次のようになります。

[[[ 0.2315362  -0.37939444 -0.625332   -0.80235624  0.2288385 ]
  [ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]]

 [[ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.9994331   0.9929737  -0.8311569  -0.99928087  0.9990415 ]
  [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]]

 [[ 0.9962312   0.99659646  0.98880637  0.99548346  0.9997809 ]
  [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]]

[[ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]
 [ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
 [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]
 [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]

...確かにstate == output[1]、完全なシーケンスとstate == output[0]短いシーケンスの両方を示しています。またoutput[1]、このシーケンスのゼロベクトルです。LSTMおよびGRUセルについても同様です。

したがって、これはゼロを無視しstateて、最後の実際の RNN状態を保持する便利なテンソルです。outputテンソルは、出力の保持している全ての細胞を、それがゼロを無視しません。それが両方を返す理由です。


2

/programming/36817596/get-last-output-of-dynamic-rnn-in-tensorflow/49705930#49705930の可能なコピー

とにかく答えを先に進めましょう。

このコードの切り取りは、dynamic_rnnレイヤーによって実際に返されているものを理解するのに役立つ場合があります

=> (outputs、final_output_state)のタプル。

したがって 、T シーケンスの最大シーケンス長の入力の場合、出力[Batch_size, T, num_inputs](与えられたtime_major= False;既定値)の形状であり、各タイムステップでの出力状態が含まれますh1, h2.....hT

そして、final_output_stateは形状[Batch_size,num_inputs]であり、各バッチシーケンスの最終セル状態cTと出力状態hTを持っています。

しかし、dynamic_rnnが使用されているので、私の推測では、シーケンスの長さはバッチごとに異なります。

    import tensorflow as tf
    import numpy as np
    from tensorflow.contrib import rnn
    tf.reset_default_graph()

    # Create input data
    X = np.random.randn(2, 10, 8)

    # The second example is of length 6 
    X[1,6:] = 0
    X_lengths = [10, 6]

    cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)

    outputs, states  = tf.nn.dynamic_rnn(cell=cell,
                                         dtype=tf.float64,
                                         sequence_length=X_lengths,
                                         inputs=X)

    result = tf.contrib.learn.run_n({"outputs": outputs, "states":states},
                                    n=1,
                                    feed_dict=None)
    assert result[0]["outputs"].shape == (2, 10, 64)
    print result[0]["outputs"].shape
    print result[0]["states"].h.shape
    # the final outputs state and states returned must be equal for each      
    # sequence
    assert(result[0]["outputs"][0][-1]==result[0]["states"].h[0]).all()
    assert(result[0]["outputs"][-1][5]==result[0]["states"].h[-1]).all()
    assert(result[0]["outputs"][-1][-1]==result[0]["states"].h[-1]).all()

2番目のシーケンスの最終状態は6番目のタイムステップにあるため、最終的なアサーションは失敗します。インデックス5と[6:9]からの残りの出力は、2番目のタイムステップですべて0です。

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.