OrbaxCheckpoint class

keras.callbacks.OrbaxCheckpoint(
directory,
monitor="val_loss",
verbose=0,
save_best_only=False,
mode="auto",
save_freq="epoch",
initial_value_threshold=None,
max_to_keep=1,
save_on_background=True,
save_weights_only=False,
)
Callback to save and load model state using Orbax, with an API similar to ModelCheckpoint.
This callback saves the model's weights and optimizer state asynchronously using Orbax, allowing training to continue without blocking on I/O.
Multi-host Support: When running multi-host distributed training on the JAX backend, this callback automatically coordinates checkpointing across all hosts to ensure consistency and proper synchronization. Multi-host checkpointing is only supported with JAX.
Example
model.compile(loss=..., optimizer=..., metrics=['accuracy'])
EPOCHS = 10
checkpoint_dir = '/tmp/ckpt'
orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
directory=checkpoint_dir,
monitor='val_accuracy',
mode='max',
save_best_only=True)
# Model is saved at the end of every epoch, if it's the best seen so far.
model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
# Alternatively, save checkpoints every N batches:
orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
directory=checkpoint_dir,
save_freq=100) # Save every 100 batches
model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
Arguments
directory: string, path to the directory where the checkpoints will be saved.
monitor: the metric name to monitor, e.g. 'val_loss'.
verbose: verbosity mode, 0 or 1. Mode 0 is silent; mode 1 displays messages when the callback takes an action.
save_best_only: if save_best_only=True, it only saves when the model is considered the "best" based on the monitored quantity.
mode: one of {'auto', 'min', 'max'}. If save_best_only=True, the decision to overwrite the current save is made based on either the maximization or the minimization of the monitored quantity. In 'auto' mode, the direction is inferred from the name of the monitored quantity.
save_freq: 'epoch' or integer. Frequency to save checkpoints. When 'epoch', the callback saves at the end of each epoch; when an integer N, it saves after every N batches.
initial_value_threshold: floating point initial "best" value of the metric to monitor. Only applies when save_best_only=True.
max_to_keep: integer, the maximum number of checkpoints to keep; older checkpoints are deleted.
save_on_background: whether to save checkpoints asynchronously in the background without blocking training.
save_weights_only: if True, only the model's weights are saved; otherwise the full state, including the optimizer state, is saved.