KerasでLSTMまたはGRUをトレーニングすると、突然精度が低下する


8

私のリカレントニューラルネットワーク(LSTM、またはGRU)は、私が説明できない方法で動作します。トレーニングが開始され、突然精度が低下する(そして、損失が急速に増加する)ときに、トレーニングとテストの両方のメトリックが適切にトレーニングされます(結果はかなり良く見えます)。時々、ネットはおかしくなり、ランダムな出力を返し、時々(与えられた3つの例の最後のように)、すべての入力に同じ出力を返し始めます。

画像

あなたが持っています。この動作のための任意の説明を?どんな意見でも大歓迎です。以下のタスクの説明と図を参照してください。

タスク:単語からword2vecベクトルを予測する 入力:独自のword2vecモデル(正規化済み)があり、ネットワークに単語(文字で文字)を入力します。単語にパディングします(下の例を参照)。 例:フットボールという単語があり、100次元幅のword2vecベクトルを予測したいとします。次に、入力は$football$$$$$$$$$$です。

動作の3つの例:

単層LSTM

model = Sequential([
    LSTM(1024, input_shape=encoder.shape, return_sequences=False),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

画像

単層GRU

model = Sequential([
    GRU(1024, input_shape=encoder.shape, return_sequences=False),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

画像

二重層LSTM

model = Sequential([
    LSTM(512, input_shape=encoder.shape, return_sequences=True),
    TimeDistributed(Dense(512, activation="sigmoid")),
    LSTM(512, return_sequences=False),
    Dense(256, activation="tanh"),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

画像

また、以前に同様のアーキテクチャを使用した別のプロジェクトでこの種の動作を経験しましたが、その目的とデータは異なりました。したがって、理由はデータや特定の目的に隠すべきではなく、アーキテクチャに隠すべきです。


問題の原因を見つけましたか?
アントワーヌ

残念ながらそうではありません。別のアーキテクチャに変更したところ、これに戻る機会がありませんでした。しかし、いくつかの手掛かりがあります。私たちの推測では、何らかの原因で1つ以上のパラメータがに変更されたと考えられnanます。
Marek、

nanパラメータはナン以外の損失にはなりません。私の推測では、あなたの勾配はたまたま爆発し、非バッチ正規化ネットワークでも同様のことが起こりました。
ルギ

これは、TensorBoardを使用して調査しようとしたことの1つでもありますが、私たちの場合、勾配爆発は証明されていません。アイデアのnan1つは計算の1つに現れ、それからデフォルトで別の値になり、ネットワークがおかしくなりました。しかし、それは単なる乱暴な推測です。ご意見ありがとうございます。
Marek

回答:


2

問題を特定するための私の提案は次のとおりです。

1)トレーニングの学習曲線を見てください:列車の学習曲線はどのようになっていますか?トレーニングセットを学習しますか?そうでない場合は、まずトレーニングセットに適合できることを確認するために作業します。

2)データをチェックして、NaNがないことを確認します(トレーニング、検証、テスト)。

3)勾配と重みをチェックして、NaNがないことを確認します。

4)学習率を下げて、急な大きな更新が急激な最小値に詰まったためではないことを確認してください。

5)すべてが正しいことを確認するには、ネットワークの予測をチェックして、ネットワークが一定の予測や繰り返しの予測を行わないようにします。

6)バッチ内のデータがすべてのクラスに関してバランスが取れているかどうかを確認します。

7)データを平均単位分散がゼロになるように正規化します。同様に重みを初期化します。トレーニングを支援します。

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.