Author: nkovela1
Date created: 2022/09/19
Last modified: 2022/09/26
Description: Guide on how to share a custom training step across multiple Keras models.
This example shows how to create a custom training step using the "Trainer pattern", which can then be shared across multiple Keras models. This pattern overrides the train_step() method of the keras.Model class, allowing for training loops beyond plain supervised learning.
The Trainer pattern also adapts easily to more complex models with larger custom training steps, such as an end-to-end GAN model: the custom training step simply goes in the Trainer class definition.
import tensorflow as tf
from tensorflow import keras
# Load the MNIST dataset and rescale pixel values from [0, 255] to [0, 1]
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
A custom training step can be created by overriding the train_step() method of a Model subclass:
class MyTrainer(keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)  # Forward pass
            # Compute loss value
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics configured in `compile()`
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        # Equivalent to `call()` of the wrapped keras.Model
        x = self.model(x)
        return x
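Stripped of Keras specifics, the pattern is plain composition: the trainer holds a reference to a model, owns the training step, and forwards calls to the wrapped model. The framework-free sketch below illustrates that idea only and is not how Keras implements training. The classes Linear, Affine and ToyTrainer are invented for this illustration, and gradients are estimated with central finite differences rather than a GradientTape.

```python
class Linear:
    """Toy model (hypothetical): y = w * x."""

    def __init__(self):
        self.params = [0.0]

    def __call__(self, x):
        return self.params[0] * x


class Affine:
    """Toy model (hypothetical): y = w * x + b."""

    def __init__(self):
        self.params = [0.0, 0.0]

    def __call__(self, x):
        return self.params[0] * x + self.params[1]


class ToyTrainer:
    """Wraps any model exposing `params` and `__call__`; owns the training step."""

    def __init__(self, model, lr=0.05, eps=1e-6):
        self.model = model
        self.lr = lr
        self.eps = eps

    def loss(self, xs, ys):
        # Mean squared error over the batch
        return sum((self.model(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

    def train_step(self, xs, ys):
        params = self.model.params
        grads = []
        # Estimate each partial derivative with a central finite difference
        for i in range(len(params)):
            original = params[i]
            params[i] = original + self.eps
            loss_plus = self.loss(xs, ys)
            params[i] = original - self.eps
            loss_minus = self.loss(xs, ys)
            params[i] = original
            grads.append((loss_plus - loss_minus) / (2 * self.eps))
        # Plain gradient-descent update, applied in place to the wrapped model
        for i, g in enumerate(grads):
            params[i] -= self.lr * g
        return self.loss(xs, ys)

    def __call__(self, x):
        # Forward to the wrapped model, like `call()` in the Keras trainer
        return self.model(x)


# Data generated by y = 2x; both toy models should learn a slope close to 2
xs = [0.0, 1.0, 2.0, 3.0]
ys = [2.0 * x for x in xs]

linear, affine = Linear(), Affine()
for toy_trainer in (ToyTrainer(linear), ToyTrainer(affine)):
    for _ in range(500):
        toy_trainer.train_step(xs, ys)

print(linear.params, affine.params)
```

Because each trainer updates the parameters of the model it wraps in place, the trained values are available on linear and affine directly, without going through the trainer.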
Let's define two different models that can share our Trainer class and its custom train_step():
# A model defined using Sequential API
model_a = keras.models.Sequential(
    [
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation="softmax"),
    ]
)
# A model defined using Functional API
func_input = keras.Input(shape=(28, 28, 1))
x = keras.layers.Flatten()(func_input)
x = keras.layers.Dense(512, activation="relu")(x)
x = keras.layers.Dropout(0.4)(x)
func_output = keras.layers.Dense(10, activation="softmax")(x)
model_b = keras.Model(func_input, func_output)
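To get a feel for how different the two wrapped models are, their trainable parameter counts can be worked out by hand. This is a rough sketch that assumes the usual Dense layout of one weight per input-output pair plus one bias per unit. Flatten and Dropout contribute no parameters, and the helper name dense_params is made up for illustration.

```python
def dense_params(n_inputs, n_units):
    """Weights plus biases of a fully connected layer (hypothetical helper)."""
    return n_inputs * n_units + n_units


flat = 28 * 28  # Flatten turns each 28x28 image into 784 features

# model_a: Flatten -> Dense(256) -> Dropout -> Dense(10)
params_a = dense_params(flat, 256) + dense_params(256, 10)

# model_b: Flatten -> Dense(512) -> Dropout -> Dense(10)
params_b = dense_params(flat, 512) + dense_params(512, 10)

print(params_a)  # 203530
print(params_b)  # 407050
```

model_b has roughly twice as many parameters as model_a, yet both can be trained through the same train_step().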
trainer_1 = MyTrainer(model_a)
trainer_2 = MyTrainer(model_b)
trainer_1.compile(
    keras.optimizers.SGD(), loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
trainer_1.fit(
    x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
)
trainer_2.compile(
    keras.optimizers.Adam(),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
trainer_2.fit(
    x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
)
Epoch 1/5
938/938 [==============================] - 1s 1ms/step - loss: 0.9075 - accuracy: 0.7650 - val_loss: 0.4641 - val_accuracy: 0.8845
Epoch 2/5
938/938 [==============================] - 1s 1ms/step - loss: 0.4558 - accuracy: 0.8724 - val_loss: 0.3592 - val_accuracy: 0.9044
Epoch 3/5
938/938 [==============================] - 1s 1ms/step - loss: 0.3855 - accuracy: 0.8913 - val_loss: 0.3178 - val_accuracy: 0.9136
Epoch 4/5
938/938 [==============================] - 1s 1ms/step - loss: 0.3465 - accuracy: 0.9014 - val_loss: 0.2908 - val_accuracy: 0.9194
Epoch 5/5
938/938 [==============================] - 1s 1ms/step - loss: 0.3200 - accuracy: 0.9086 - val_loss: 0.2711 - val_accuracy: 0.9252
Epoch 1/5
938/938 [==============================] - 2s 2ms/step - loss: 0.2716 - accuracy: 0.9204 - val_loss: 0.1237 - val_accuracy: 0.9626
Epoch 2/5
938/938 [==============================] - 2s 2ms/step - loss: 0.1270 - accuracy: 0.9625 - val_loss: 0.0869 - val_accuracy: 0.9738
Epoch 3/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0951 - accuracy: 0.9718 - val_loss: 0.0792 - val_accuracy: 0.9747
Epoch 4/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0760 - accuracy: 0.9767 - val_loss: 0.0680 - val_accuracy: 0.9780
Epoch 5/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0647 - accuracy: 0.9798 - val_loss: 0.0698 - val_accuracy: 0.9782
<keras.callbacks.History at 0x168dcb460>
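Two numbers in the logs above can be checked with a little arithmetic. The step count per epoch is the number of batches needed to cover the training set. Since sparse categorical cross-entropy is the mean negative log-probability of the correct class, exp(-loss) gives a rough geometric-mean probability that the model assigned to the right digit. The sketch assumes the standard 60,000-image MNIST training split, and the variable names are illustrative only.

```python
import math

train_samples = 60_000  # assumed size of the standard MNIST training split
batch_size = 64         # value passed to fit() above

# The last, partially filled batch still counts as a step
steps_per_epoch = math.ceil(train_samples / batch_size)
print(steps_per_epoch)  # 938

# Final training losses copied from the epoch 5 lines of the two runs
final_loss_sgd = 0.3200   # trainer_1 (SGD)
final_loss_adam = 0.0647  # trainer_2 (Adam)

# exp(-loss): approximate geometric-mean probability of the correct class
p_sgd = math.exp(-final_loss_sgd)
p_adam = math.exp(-final_loss_adam)
print(round(p_sgd, 2), round(p_adam, 2))  # 0.73 0.94
```

These are training-time figures, so dropout is active when they are measured. The validation losses can be read the same way.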