OrbaxCheckpoint

OrbaxCheckpoint class

keras.callbacks.OrbaxCheckpoint(
    directory,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    mode="auto",
    save_freq="epoch",
    initial_value_threshold=None,
    max_to_keep=1,
    save_on_background=True,
    save_weights_only=False,
)

Callback to save and load model state using Orbax with a similar API to ModelCheckpoint.

By default, this callback saves the model's weights and optimizer state asynchronously using Orbax, so training continues without blocking on I/O. Set save_on_background=False to save synchronously instead.
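The difference between background and blocking saves can be pictured with a small standard-library sketch. It is only an illustration of the idea, not how Orbax or this callback is implemented: `write_checkpoint` and `train` are hypothetical helpers, the sleep stands in for disk I/O, and a single worker thread is assumed.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def write_checkpoint(step, written):
    """Stand-in for slow checkpoint I/O (hypothetical helper)."""
    time.sleep(0.05)
    written.append(step)


def train(steps, save_every, save_on_background=True):
    """Run a dummy loop and return the steps that were checkpointed."""
    written = []
    pending = []
    # One worker keeps the writes in submission order.
    with ThreadPoolExecutor(max_workers=1) as pool:
        for step in range(1, steps + 1):
            # ... one training step would run here ...
            if step % save_every == 0:
                if save_on_background:
                    # Hand the write to the worker and keep training.
                    pending.append(pool.submit(write_checkpoint, step, written))
                else:
                    # Block the loop until the write finishes.
                    write_checkpoint(step, written)
        # Wait for outstanding writes before returning.
        for future in pending:
            future.result()
    return written


print(train(steps=6, save_every=2))
```

Both settings checkpoint the same steps; only the time the loop spends waiting differs. A real asynchronous saver also has to snapshot the state before handing it off, which this sketch leaves out.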

Multi-host Support: When running in a multi-host distributed training environment with the JAX backend, this callback automatically coordinates checkpointing across all hosts so that checkpoints stay consistent and properly synchronized. Multi-host checkpointing is supported only on JAX.

Example

model.compile(loss=..., optimizer=..., metrics=['accuracy'])

EPOCHS = 10
checkpoint_dir = '/tmp/ckpt'
orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
    directory=checkpoint_dir,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

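# Model is saved at the end of every epoch, if it's the best seen so far.
# Monitoring 'val_accuracy' requires validation data; x_train, y_train,
# x_val and y_val are placeholders for your own data.
model.fit(x_train, y_train, validation_data=(x_val, y_val),
          epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])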

# Alternatively, save a checkpoint every N batches.
orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
    directory=checkpoint_dir,
    save_freq=100)  # Save every 100 batches

# x_train and y_train are placeholders for your own training data.
model.fit(x_train, y_train, epochs=EPOCHS,
          callbacks=[orbax_checkpoint_callback])
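To make the two `save_freq` settings concrete, the pure-Python sketch below lists the points at which a save would be triggered. It is an illustration, not the callback's implementation: the helper `save_points` is hypothetical, and it assumes that an integer `save_freq` counts batches cumulatively across epoch boundaries.

```python
def save_points(epochs, steps_per_epoch, save_freq="epoch"):
    """Return (epoch, batch) pairs after which a checkpoint would be written.

    Hypothetical helper for illustration; epochs and batches are 0-indexed.
    """
    points = []
    batches_since_save = 0
    for epoch in range(epochs):
        if save_freq == "epoch":
            # batch is None: the save happens at the end of the epoch.
            points.append((epoch, None))
            continue
        for batch in range(steps_per_epoch):
            batches_since_save += 1
            if batches_since_save >= save_freq:
                points.append((epoch, batch))
                # Assumption: the counter is not reset at epoch boundaries.
                batches_since_save = 0
    return points


print(save_points(epochs=2, steps_per_epoch=150, save_freq=100))
# [(0, 99), (1, 49), (1, 149)]
print(save_points(epochs=2, steps_per_epoch=150))
# [(0, None), (1, None)]
```

Under this assumption, 150 batches per epoch and `save_freq=100` give saves that drift relative to epoch boundaries, so the saved checkpoints do not generally line up with the end of an epoch.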

Arguments

  • directory: Path to the directory in which checkpoints are saved.
  • monitor: The metric name to monitor (e.g., 'val_loss').
  • verbose: Verbosity mode, 0 or 1.
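  • save_best_only: If True, a checkpoint is saved only when the model is considered the "best" so far based on the monitored quantity.
  • save_weights_only: Boolean. As the name and the analogous ModelCheckpoint argument suggest, if True only the model's weights are saved, without the optimizer state. Defaults to False.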
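  • mode: one of {'auto', 'min', 'max'}. Used with save_best_only to decide whether the monitored quantity has improved: in 'min' mode lower values are better, and in 'max' mode higher values are better. In 'auto' mode the direction is inferred from the name of the monitored quantity, as in ModelCheckpoint.
  • save_freq: 'epoch' or integer. With 'epoch', a checkpoint is saved at the end of each epoch; with an integer N, a checkpoint is saved every N batches.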
  • max_to_keep: Integer, maximum number of recent checkpoints to keep. If None, keeps all. Defaults to 1.
  • save_on_background: Boolean, whether to save asynchronously in the background. Defaults to True.
  • initial_value_threshold: Floating point initial "best" value for the monitor, used with save_best_only.
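The interaction of save_best_only, mode, initial_value_threshold and max_to_keep can be sketched as a toy policy object. This is a minimal illustration under stated assumptions, not the callback's code: `CheckpointPolicy` is a hypothetical class that tracks step numbers without writing to disk, it handles only 'min' and 'max', it treats "best" as a strict improvement, and it assumes max_to_keep retains the most recent saves.

```python
from collections import deque


class CheckpointPolicy:
    """Toy model of the save decision and retention (hypothetical class)."""

    def __init__(self, mode="min", save_best_only=False,
                 initial_value_threshold=None, max_to_keep=1):
        if mode not in ("min", "max"):
            raise ValueError("this sketch only handles 'min' and 'max'")
        self.mode = mode
        self.save_best_only = save_best_only
        # The threshold acts as the starting "best" value.
        self.best = initial_value_threshold
        # deque(maxlen=None) is unbounded, mirroring max_to_keep=None.
        self.kept = deque(maxlen=max_to_keep)

    def _improved(self, value):
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best
        return value > self.best

    def maybe_save(self, step, value):
        """Record a save for `step` if the policy allows it."""
        improved = self._improved(value)
        if improved:
            self.best = value
        if self.save_best_only and not improved:
            return False
        # Appending to a full deque drops the oldest kept step.
        self.kept.append(step)
        return True


policy = CheckpointPolicy(mode="max", save_best_only=True,
                          initial_value_threshold=0.80, max_to_keep=2)
for epoch, val_accuracy in enumerate([0.75, 0.82, 0.81, 0.85, 0.90]):
    policy.maybe_save(epoch, val_accuracy)

print(list(policy.kept))  # [3, 4]
```

Epoch 0 is skipped because 0.75 does not beat the 0.80 threshold, epoch 2 is skipped because it does not beat 0.82, and of the three saves (epochs 1, 3 and 4) only the last two are retained.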
