KerasのEarly Stoppingコールバックで使用されるメトリックを変更する方法はありますか?


12

KerasトレーニングでEarly Stoppingコールバックを使用すると、一部のメトリック(通常は検証の損失)が増加しないときに停止します。検証損失の代わりに別のメトリック(精度、再現率、fメジャーなど)を使用する方法はありますか?これまでに見たすべての例は、次の例に似ています:callbacks.EarlyStopping(monitor = 'val_loss'、patience = 5、verbose = 0、mode = 'auto')

回答:


10

モデルのコンパイル時に指定した任意のメトリック関数を使用できます。

次のメトリック関数があるとします。

def my_metric(y_true, y_pred):
     return some_metric_computation(y_true, y_pred)

この関数の唯一の要件は、真のyと予測されたyを受け入れることです。

モデルをコンパイルするときは、「精度」などの組み込みメトリックを指定する方法と同様に、このメトリックを指定します。

model.compile(metrics=['accuracy', my_metric], ...)

関数名my_metricを ''なしで使用していることに注意してください( 'accuracy'のビルドとは対照的です)。

次に、EarlyStoppingを定義する場合は、関数の名前を使用します(今回は ''を付けます)。

EarlyStopping(monitor='my_metric', mode='min')

必ずモードを指定してください(低いほど良い場合は最小、高いほど良い場合は最大)。

他の組み込み指標と同じように使用できます。これは、ModelCheckpointなどの他のコールバックでも機能します(ただし、テストしていません)。内部的には、Kerasは関数名を使用して、このモデルで使用可能なメトリックのリストに新しいメトリックを追加するだけです。

model.fit(...)で検証用のデータを指定する場合、 'val_my_metric'を使用してEarlyStoppingにそれを使用することもできます。


3

もちろん、自分で作成してください!

class EarlyStopByF1(keras.callbacks.Callback):
    def __init__(self, value = 0, verbose = 0):
        super(keras.callbacks.Callback, self).__init__()
        self.value = value
        self.verbose = verbose


    def on_epoch_end(self, epoch, logs={}):
         predict = np.asarray(self.model.predict(self.validation_data[0]))
         target = self.validation_data[1]
         score = f1_score(target, prediction)
         if score > self.value:
            if self.verbose >0:
                print("Epoch %05d: early stopping Threshold" % epoch)
            self.model.stop_training = True


callbacks = [EarlyStopByF1(value = .90, verbose =1)]
model.fit(X, y, batch_size = 32, nb_epoch=nb_epoch, verbose = 1, 
validation_data(X_val,y_val), callbacks=callbacks)

私はこれをテストしていませんが、それはあなたがそれをどのように進めるかについての一般的な味であるべきです。うまくいかない場合はお知らせください。週末に再試行します。また、独自のf1スコアがすでに実装されていることも前提としています。ない場合は、sklearnにインポートします。


+1 2020
オースティン
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.