If you're using a custom training loop, you can use collections.deque, which is a "rolling" list that can be appended to. When the list grows longer than its maxlen, items are popped out of the left side. Here is the key part:
loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())

    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
        break
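The condition works like this: the deque holds the last early_stopping + 1 test losses, popleft() removes the oldest of them, and if that oldest loss is still smaller than every loss recorded after it, nothing has improved for early_stopping epochs. A quick standalone demo of the rolling behaviour (plain Python, made-up numbers):

from collections import deque

d = deque(maxlen=3)
for i in [5, 4, 3, 2]:
    d.append(i)
print(d)  # deque([4, 3, 2], maxlen=3) -- the 5 was pushed out on the left

losses = deque([0.30, 0.31, 0.32, 0.33])
oldest = losses.popleft()    # 0.30, the oldest loss on record
print(oldest < min(losses))  # True: no later loss beat 0.30, so we'd stop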
Here's the full example:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TF's C++ logging

from collections import deque

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Dense

# 120 samples for training, the remaining 30 for testing
data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)
data = data.map(lambda x, y: (tf.cast(x, tf.int32), y))

train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes)])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = loss_object(labels, logits)
    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)


@tf.function
def test_step(inputs, labels):
    logits = model(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)


def fit(epoch):
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    # reset the metric accumulators at the start of every epoch
    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))


def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())

        # once the deque is full, pop the oldest loss; if it is still the
        # smallest, nothing has improved for `early_stopping` epochs
        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break


if __name__ == '__main__':
    main(epochs=250, early_stopping=10)
Running it prints output like this:

Epoch 1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch 2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch 3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch 4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch 5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87
Early stopping. No validation loss improvement in 10 epochs.
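Two possible refinements, neither of which is part of the answer above. First, the loop stops holding the weights of the last epoch, which is not necessarily the best one. If you want the equivalent of Keras' restore_best_weights, you can snapshot the weights whenever the test loss improves; here is a minimal sketch of main() with that change, reusing the model, fit and test_loss objects defined above (best_loss and best_weights are my own additions):

def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)
    best_loss, best_weights = float('inf'), None

    for epoch in range(epochs):
        fit(epoch)
        current = test_loss.result().numpy()
        loss_history.append(current)

        if current < best_loss:  # snapshot on every improvement
            best_loss, best_weights = current, model.get_weights()

        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break

    if best_weights is not None:
        model.set_weights(best_weights)  # roll back to the best epoch

Second, if you train with model.fit instead of a custom loop, the built-in tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True) gives you both behaviours without any of this bookkeeping.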