回答:
早期停止とは、基本的には、損失が増加し始めたら(つまり、検証の精度が低下し始めたら)トレーニングを停止することです。文書によると、それは次のように使用されます。
keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=0,
verbose=0, mode='auto')
値は実装(問題、バッチサイズなど)によって異なりますが、一般的には、過剰適合を防ぐために使用します。
monitor
引数をに設定して、検証の損失を監視します(相互検証または少なくともトレーニング/テストセットを使用する必要があります)'val_loss'
。min_delta
あるエポックでの損失を改善として定量化するかどうかのしきい値です。損失の差が未満の場合、min_delta
改善なしとして定量化されます。損失が悪化する時期に関心があるため、0のままにしておくことをお勧めします。patience
引数は、損失が増加し始めた(改善が止まった)時点で停止するまでのエポック数を表します。これは、実装に依存します。非常に小さなバッチ
または大きな学習率を使用する場合、損失ジグザグ(精度はよりノイズが大きくなります)なので、大きなpatience
引数を設定することをお勧めします。大きなバッチを使用し、学習率が低い場合、損失はより滑らかになるため、より小さなpatience
引数を使用できます。どちらの方法でも、2のままにしておくと、モデルにより多くのチャンスが与えられます。verbose
何を印刷するかを決定し、デフォルト(0)のままにします。mode
引数は監視する量の方向に依存します(減少するか増加するか)、損失を監視するため、を使用できますmin
。しかし、kerasに処理を任せておき、それをauto
したがって、私はこのようなものを使用し、早期に停止した場合と停止していない場合のエラー損失をプロットして実験します。
keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=2,
verbose=0, mode='auto')
コールバックがどのように機能するかについて曖昧さをなくすために、もっと詳しく説明しようと思います。fit(... callbacks=[es])
モデルを呼び出すと、Kerasは所定のコールバックオブジェクトに所定の関数を呼び出します。これらの関数を呼び出すことができon_train_begin
、on_train_end
、on_epoch_begin
、on_epoch_end
とon_batch_begin
、on_batch_end
。早期停止コールバックはすべてのエポックエンドで呼び出され、最適な監視値を現在の値と比較し、条件が満たされた場合に停止します(最適な監視値の観測以降に経過したエポックの数、および忍耐引数よりも多く、最後の値がmin_deltaよりも大きいなど)。
コメントの@BrentFaustで指摘されているように、モデルのトレーニングは、早期停止条件が満たされるか、epochs
パラメーター(デフォルト= 10)fit()
が満たされるまで続行されます。Early Stoppingコールバックを設定しても、モデルはそのepochs
パラメーターを超えてトレーニングしません。したがってfit()
、より大きなepochs
値で関数を呼び出すと、早期停止コールバックのメリットが大きくなります。
callbacks=[EarlyStopping(patience=2)]
エポックが与えられない限り、それは効果がないことに注意してくださいmodel.fit(..., epochs=max_epochs)
。
epoch=1
forループ(さまざまな使用例)でfitを呼び出して、このコールバックが失敗する場合があることに気づきました。私の答えに曖昧さがある場合、私はそれをより良い方法で表現しようとします。
restore_best_weights
、トレーニング後に最適な重みでモデルをロードする引数(まだドキュメントにはありません)を使用できます。しかし、あなたの目的のためにModelCheckpoint
、save_best_only
引数付きのコールバックを使用します。ドキュメントを確認できます。使用するのは簡単ですが、トレーニング後に最適なウェイトを手動でロードする必要があります。
min_delta
は、監視値の変化を改善として定量化するかどうかのしきい値です。したがって、そうです、私たちが与えるmonitor = 'val_loss'
場合、それは現在の検証の損失と以前の検証の損失の違いを指します。実際には、min_delta=0.1
検証損失の減少(現在-以前)を0.1未満にしても定量化できず、トレーニングが停止します(ある場合patience = 0
)。