アテンションメカニズムとは何ですか?


23

ここ数年、さまざまなディープラーニングの論文で注意メカニズムが使用されてきました。Open AIの研究責任者であるIlya Sutskever氏は、熱心に称賛しています:https ://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

パデュー大学のEugenio Culurcielloは、純粋に注意ベースのニューラルネットワークを優先して、RNNとLSTMを放棄すべきだと主張しています。

https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

これは誇張のように見えますが、純粋に注意に基づくモデルがシーケンスモデリングタスクで非常にうまく機能していることは否定できません。

ただし、注意ベースのモデルとは正確には何ですか?そのようなモデルの明確な説明をまだ見つけていません。履歴値を与えられた多変量時系列の新しい値を予測したいとします。LSTMセルを持つRNNでそれを行う方法は非常に明確です。アテンションベースのモデルで同じことをどのように行うのでしょうか?

回答:


20

多くの場合、ルックアップベクトルuを使用して、ベクトルviセットを1つのベクトルに集約する方法に注意してください。通常、v iはモデルへの入力または前のタイムステップの非表示状態、または1レベル下の非表示状態(スタックLSTMの場合)です。uvi

結果は、現在のタイムステップに関連するコンテキストを含むため、コンテキストベクトルcと呼ばれることがよくあります。

この追加のコンテキストベクトルcは、RNN / LSTMにも供給されます(元の入力と単純に連結できます)。したがって、コンテキストを使用して予測を支援できます。

これを行う最も簡単な方法は、確率ベクトルp=softmax(VTu)およびc=ipiviを計算することです。ここで、Vは以前のすべてのvi連結です。一般的なルックアップベクトルuは、現在の非表示状態htです。

viTuf(vi,u)f

シーケンス間モデルの一般的なアテンションメカニズムは、p=softmax(qTtanh(W1vi+W2ht))。ここで、vはエンコーダーの非表示状態、htは現在の非表示状態です。デコーダーの。qおよび両方のWはパラメーターです。

注意のアイデアに関するさまざまなバリエーションを示すいくつかの論文:

ポインタネットワークは、組み合わせの最適化問題を解決するために、参照入力に注意を払います。

リカレントエンティティネットワークは、テキストの読み取り中に異なるエンティティ(人/オブジェクト)の個別のメモリ状態を維持し、注意を使用して正しいメモリ状態を更新します。

また、変圧器モデルは注意を大いに活用します。それらのアテンションの定式化はやや一般的であり、キーベクトルkiも含みます。アテンションの重みpは実際にキーとルックアップの間で計算され、コンテキストはvi構築されます。


簡単なテストに合格したという事実以外に正確性を保証することはできませんが、ここに注意の1つの形式の簡単な実装があります。

基本的なRNN:

def rnn(inputs_split):
    bias = tf.get_variable('bias', shape = [hidden_dim, 1])
    weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
    weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])

    hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
    for i, input in enumerate(inputs_split):
        input = tf.reshape(input, (batch, in_dim, 1))
        last_state = hidden_states[-1]
        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
        hidden_states.append(hidden)
    return hidden_states[-1]

注意を払って、新しい非表示状態が計算される前に数行だけを追加します。

        if len(hidden_states) > 1:
            logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
            probs = tf.nn.softmax(logits)
            probs = tf.reshape(probs, (batch, -1, 1, 1))
            context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
        else:
            context = tf.zeros_like(last_state)

        last_state = tf.concat([last_state, context], axis = 1)

        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )

完全なコード


p=softmax(VTu)ic=ipivipiVTvVTv

1
zi=viTup=softmax(z)pi=eizjejz

ppi

1
はい、私は何を意味するのかということ
島尾

@shimao チャットルームを作成しました。(この質問についてではなく)話したいと思うかどうかをお知らせください
DeltaIV
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.