Author: Samaneh Saadat, Amit Srivastava
Date created: 2025/11/20
Last modified: 2026/03/04
Description: Save and load Orbax checkpoints with distributed resharding.
Orbax is the recommended checkpointing library for the JAX ecosystem. It provides high-level functionality for checkpoint management, composable serialization, and multi-host coordination.
Starting with Keras 3.14, the built-in keras.callbacks.OrbaxCheckpoint callback makes it easy to add Orbax checkpointing to model.fit(), including save_best_only monitoring (just like ModelCheckpoint).
Install the Orbax checkpointing library:
!pip install -q -U orbax-checkpoint
Set the Keras backend to JAX, configure virtual devices for the distributed demo, and import the required libraries.
import os
os.environ["KERAS_BACKEND"] = "jax"
import shutil
import jax
import keras
import numpy as np
# Simulate 4 CPU devices for the distributed demo. This must run before
# JAX initializes its backends (that is, before any JAX computation).
# Remove this line if using real multi-device hardware.
jax.config.update("jax_num_cpu_devices", 4)
OrbaxCheckpoint works like ModelCheckpoint — pass it as a callback to
model.fit(). No boilerplate classes or wrappers are needed.
def get_model():
    inputs = keras.Input(shape=(32,))
    x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
    outputs = keras.layers.Dense(1, name="dense_2")(x)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mse")
    return model
model = get_model()
x_train = np.random.random((256, 32))
y_train = np.random.random((256, 1))
checkpoint_dir = "/tmp/orbax_ckpt_basic"
shutil.rmtree(checkpoint_dir, ignore_errors=True)
callback = keras.callbacks.OrbaxCheckpoint(
    directory=checkpoint_dir,
    max_to_keep=3,
)
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=3,
    verbose=1,
    validation_split=0.2,
    callbacks=[callback],
)
Epoch 1/3
7/7 ━━━━━━━━━━━━━━━━━━━━ 3s 518ms/step - loss: 0.1514 - val_loss: 0.1539
Epoch 2/3
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.1314 - val_loss: 0.1384
Epoch 3/3
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.1183 - val_loss: 0.1264
The checkpoint directory now contains one step directory for each saved epoch:
!ls /tmp/orbax_ckpt_basic
0 1 2
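With save_freq="epoch", the step directories are numbered by epoch (0, 1, 2 here), and a save that is still in progress is written to a temporary directory whose name ends in .orbax-checkpoint-tmp. If you want to find the newest finished step yourself, for example to log it or to copy that directory elsewhere, a small standard-library helper is enough. latest_step is a hypothetical helper, not a Keras or Orbax API, and it assumes finished steps are plain integer directory names:

```python
import os
import tempfile


def latest_step(directory):
    """Return the largest finished step number in `directory`, or None.

    Hypothetical helper (not a Keras or Orbax API). It assumes finished
    steps are subdirectories named by a plain integer and skips in-progress
    directories such as "3.orbax-checkpoint-tmp".
    """
    steps = [
        int(name)
        for name in os.listdir(directory)
        if name.isdigit() and os.path.isdir(os.path.join(directory, name))
    ]
    return max(steps) if steps else None


# Demo on a throwaway directory that mimics a checkpoint root with one
# save still in progress.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ["0", "1", "2", "3.orbax-checkpoint-tmp"]:
        os.mkdir(os.path.join(demo_dir, name))
    print(latest_step(demo_dir))  # 2
```

Pointed at the directory listed above, latest_step("/tmp/orbax_ckpt_basic") would return 2.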
Use keras.saving.load_model() to reload a full model (config + weights +
optimizer state) from an Orbax checkpoint directory.
loaded_model = keras.saving.load_model(checkpoint_dir)
loaded_model.summary()
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (None, 32)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 64)             │         2,112 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 1)              │            65 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 6,533 (25.52 KB)
Trainable params: 2,177 (8.50 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 4,356 (17.02 KB)
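The counts in this summary can be checked by hand. A Dense layer with n inputs and u units has n*u kernel weights plus u biases. The optimizer count fits Adam keeping two slots per trainable parameter (first and second moment estimates) plus two scalars, which we assume are the iteration counter and the learning rate. The KB figures match 4-byte float32 values divided by 1024. A plain-Python sketch of that arithmetic:

```python
def dense_params(in_features, units):
    # A Dense layer has an (in_features, units) kernel plus a (units,) bias.
    return in_features * units + units


trainable = dense_params(32, 64) + dense_params(64, 1)  # 2,112 + 65

# Assumption: Adam stores two slots (first and second moment) per trainable
# parameter, plus two scalar variables (iteration count and learning rate).
optimizer = 2 * trainable + 2
total = trainable + optimizer


def kb(n_params, bytes_per_param=4):
    # float32 values take 4 bytes each; the summary reports units of 1024 bytes.
    return round(n_params * bytes_per_param / 1024, 2)


print(trainable, optimizer, total)  # 2177 4356 6533
print(kb(trainable), kb(optimizer), kb(total))  # 8.5 17.02 25.52
```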
If you already have a model instance and just want to load the weights, use
load_weights():
fresh_model = get_model()
fresh_model.load_weights(checkpoint_dir)
# Verify both loaded_model and fresh_model match the original.
for m in [loaded_model, fresh_model]:
    for orig, restored in zip(model.weights, m.weights):
        np.testing.assert_allclose(orig.numpy(), restored.numpy(), atol=1e-6)
print("Weights match!")
Weights match!
When you pass OrbaxCheckpoint to a new fit() call, training resumes from
the correct step number. Pass initial_epoch to continue the epoch
count from where you left off.
resumed_model = keras.saving.load_model(checkpoint_dir)
# Create a new callback pointing to the same directory.
resume_callback = keras.callbacks.OrbaxCheckpoint(
    directory=checkpoint_dir,
    max_to_keep=5,
)
# Continue training for two more epochs: 0-indexed epochs 3 and 4
# (shown as "Epoch 4/5" and "Epoch 5/5").
resumed_model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    initial_epoch=3,
    verbose=1,
    validation_split=0.2,
    callbacks=[resume_callback],
)
Epoch 4/5
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 0.1094 - val_loss: 0.1176
Epoch 5/5
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0996 - val_loss: 0.1141
<keras.src.callbacks.history.History at 0x1483120d0>
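The progress bar reads Epoch 4/5 and Epoch 5/5 because Keras treats epochs as the index of the final epoch, not as a number of additional epochs: with initial_epoch=3, fit() runs the 0-indexed epochs 3 and 4. The helper names below are illustrative only:

```python
def epochs_run(initial_epoch, epochs):
    """0-indexed epochs that fit() runs: `epochs` is the index to stop at,
    not a count of additional epochs."""
    return list(range(initial_epoch, epochs))


def progress_labels(initial_epoch, epochs):
    # The progress bar uses 1-indexed labels such as "Epoch 4/5".
    return [f"Epoch {e + 1}/{epochs}" for e in epochs_run(initial_epoch, epochs)]


print(epochs_run(3, 5))  # [3, 4]
print(progress_labels(3, 5))  # ['Epoch 4/5', 'Epoch 5/5']
```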
Monitor a metric and keep only the best checkpoint:
best_dir = "/tmp/orbax_ckpt_best"
shutil.rmtree(best_dir, ignore_errors=True)
best_callback = keras.callbacks.OrbaxCheckpoint(
    directory=best_dir,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    max_to_keep=1,
)
model_best = get_model()
model_best.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    verbose=1,
    validation_split=0.2,
    callbacks=[best_callback],
)
Epoch 1/5
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 50ms/step - loss: 0.1190 - val_loss: 0.1004
Epoch 2/5
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.1061 - val_loss: 0.0984
Epoch 3/5
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0967 - val_loss: 0.0876
Epoch 4/5
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0893 - val_loss: 0.0881
Epoch 5/5
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0861 - val_loss: 0.0797
<keras.src.callbacks.history.History at 0x14856aad0>
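Only improvements trigger a save. In this run the fourth epoch's val_loss (0.0881) is worse than the third's (0.0876), so no checkpoint is written for it, and with max_to_keep=1 only the newest best checkpoint should remain on disk. The sketch below replays the printed val_loss values through a ModelCheckpoint-style rule. best_only_saves is a hypothetical helper; it assumes a strict improvement is required and that initial_value_threshold, when set, seeds the best value:

```python
import math


def best_only_saves(values, mode="min", initial_value_threshold=None):
    """Return the 0-indexed epochs at which a save_best_only callback saves.

    Sketch of a ModelCheckpoint-style rule: save only when the monitored
    value strictly improves on the best value seen so far.
    """
    if mode not in ("min", "max"):
        raise ValueError("this sketch only handles mode='min' or mode='max'")
    if initial_value_threshold is not None:
        best = initial_value_threshold
    else:
        best = math.inf if mode == "min" else -math.inf
    saved = []
    for epoch, value in enumerate(values):
        improved = value < best if mode == "min" else value > best
        if improved:
            best = value
            saved.append(epoch)
    return saved


# val_loss values printed by the run above.
val_losses = [0.1004, 0.0984, 0.0876, 0.0881, 0.0797]
print(best_only_saves(val_losses))  # [0, 1, 2, 4]
```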
Save every N batches instead of every epoch by setting save_freq to an
integer:
batch_dir = "/tmp/orbax_ckpt_batch"
shutil.rmtree(batch_dir, ignore_errors=True)
batch_callback = keras.callbacks.OrbaxCheckpoint(
    directory=batch_dir,
    save_freq=4,  # Save every 4 batches.
    max_to_keep=3,
)
model_batch = get_model()
model_batch.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=2,
    verbose=1,
    callbacks=[batch_callback],
)
Epoch 1/2
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.6971
Epoch 2/2
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2006
<keras.src.callbacks.history.History at 0x13b892060>
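With 256 samples and batch_size=32 there are 8 batches per epoch (the 8/8 in the progress bar), so 2 epochs give 16 steps and save_freq=4 should trigger saves at steps 4, 8, 12, and 16. The sketch below reproduces that arithmetic. Both helpers are hypothetical; they assume the step counter runs continuously across epochs and that max_to_keep=3 evicts the oldest checkpoints first, leaving steps 8, 12, and 16:

```python
import math


def batch_save_steps(num_samples, batch_size, epochs, save_freq):
    # Batches per epoch, counting a final partial batch if there is one.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * epochs
    # Assumption: a save happens every `save_freq` steps, counted across epochs.
    return list(range(save_freq, total_steps + 1, save_freq))


def retained(steps, max_to_keep):
    # Assumption: the oldest checkpoints are deleted first once the limit is hit.
    return steps[-max_to_keep:]


steps = batch_save_steps(256, 32, epochs=2, save_freq=4)
print(steps)  # [4, 8, 12, 16]
print(retained(steps, 3))  # [8, 12, 16]
```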
OrbaxCheckpoint works seamlessly with the Keras Distribution API. When you
set a distribution, model variables are automatically sharded according to
your LayoutMap and checkpoints capture the distributed state.
Note: The example below uses virtual devices to simulate a multi-device environment. In production, use your actual accelerators (GPUs/TPUs).
devices = jax.devices()
print(f"Available devices: {len(devices)}")
Available devices: 4
mesh = keras.distribution.DeviceMesh(
    shape=(2, 2),
    axis_names=["data", "model"],
    devices=devices,
)
layout_map = keras.distribution.LayoutMap(mesh)
layout_map["dense_1/kernel"] = (None, "model")
distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="data"
)
keras.distribution.set_distribution(distribution)
dist_dir = "/tmp/orbax_ckpt_dist"
shutil.rmtree(dist_dir, ignore_errors=True)
dist_model = get_model()
dist_callback = keras.callbacks.OrbaxCheckpoint(
    directory=dist_dir,
    max_to_keep=2,
)
dist_model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=3,
    verbose=1,
    validation_split=0.2,
    callbacks=[dist_callback],
)
Epoch 1/3
7/7 ━━━━━━━━━━━━━━━━━━━━ 1s 83ms/step - loss: 0.3819 - val_loss: 0.1660
Epoch 2/3
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.1216 - val_loss: 0.1554
Epoch 3/3
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.1520 - val_loss: 0.1450
<keras.src.callbacks.history.History at 0x13b893ce0>
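After training, reload the checkpoint and inspect the sharding of the restored variables. Because the distribution is still set, they are laid out according to the same LayoutMap:
loaded_dist = keras.saving.load_model(dist_dir)
for v in loaded_dist.trainable_variables:
    print(f"{v.path}: shape={v.shape}, sharding={v.value.sharding}")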
dense_1/kernel: shape=(32, 64), sharding=NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device)
dense_1/bias: shape=(64,), sharding=NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None,), memory_kind=device)
dense_2/kernel: shape=(64, 1), sharding=NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None, None), memory_kind=device)
dense_2/bias: shape=(1,), sharding=NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None,), memory_kind=device)
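Each PartitionSpec entry names the mesh axis that a dimension is split over, and None means that dimension is not split. On the 2x2 mesh, dense_1/kernel with PartitionSpec(None, 'model') is therefore stored as (32, 32) blocks, and each block is replicated on the two devices along the "data" axis. shard_shape is a hypothetical helper that does this arithmetic in plain Python, without JAX:

```python
def shard_shape(global_shape, spec, mesh_shape):
    """Per-device shape of an array laid out with `spec` on a device mesh.

    `spec` mirrors a LayoutMap entry: one mesh-axis name (or None) per
    array dimension. `mesh_shape` maps axis names to their sizes.
    """
    if len(spec) != len(global_shape):
        raise ValueError("spec needs one entry per array dimension")
    shape = []
    for dim, axis in zip(global_shape, spec):
        if axis is None:
            shape.append(dim)  # Not split: every device holds the full extent.
        else:
            size = mesh_shape[axis]
            if dim % size:
                raise ValueError(f"{dim} is not divisible by mesh axis {axis!r}")
            shape.append(dim // size)
    return tuple(shape)


mesh_shape = {"data": 2, "model": 2}
# dense_1/kernel under (None, "model").
print(shard_shape((32, 64), (None, "model"), mesh_shape))  # (32, 32)
# A 32-example input batch split along the batch ("data") dimension.
print(shard_shape((32, 32), ("data", None), mesh_shape))  # (16, 32)
```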
One of the most powerful features of Orbax checkpointing in Keras is the ability to load a checkpoint saved with one sharding layout and restore it under a different layout. Keras automatically provides the target shardings to Orbax so arrays are resharded on load.
For example, a checkpoint saved with dense_1/kernel sharded on the
"model" axis can be loaded with that same kernel sharded on the "data"
axis instead:
# Clear the current distribution first.
keras.distribution.set_distribution(None)
# Define a new layout that shards dense_1/kernel on a different axis.
new_layout_map = keras.distribution.LayoutMap(mesh)
new_layout_map["dense_1/kernel"] = ("data", None)
new_distribution = keras.distribution.ModelParallel(
    layout_map=new_layout_map, batch_dim_name="data"
)
keras.distribution.set_distribution(new_distribution)
# Load the checkpoint saved with the original layout.
resharded_model = keras.saving.load_model(dist_dir)
print("\nResharded variable shardings:")
for v in resharded_model.trainable_variables:
    print(f"  {v.path}: sharding={v.value.sharding}")
Resharded variable shardings:
dense_1/kernel: sharding=NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('data', None), memory_kind=device)
dense_1/bias: sharding=NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None,), memory_kind=device)
dense_2/kernel: sharding=NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None, None), memory_kind=device)
dense_2/bias: sharding=NamedSharding(mesh=Mesh('data': 2, 'model': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None,), memory_kind=device)
The dense_1/kernel variable was originally sharded as (None, "model")
but is now sharded as ("data", None). Orbax handles the data movement
automatically during loading.
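To see why data has to move, compare which block of dense_1/kernel each device holds under the two layouts. Under (None, "model") each device holds a 32x32 column block; under ("data", None) each device needs a 16x64 row block, which overlaps both saved column blocks. The sketch below only models that bookkeeping with hypothetical helpers; it is not how Orbax implements restoration:

```python
from itertools import product


def device_blocks(global_shape, spec, mesh_shape, axis_order):
    """Map each device coordinate on the mesh to the index ranges it holds.

    Hypothetical helper: a device is identified by its coordinates, given as
    a tuple ordered like `axis_order`, e.g. (data_index, model_index).
    """
    blocks = {}
    for coords in product(*(range(mesh_shape[a]) for a in axis_order)):
        position = dict(zip(axis_order, coords))
        ranges = []
        for dim, axis in zip(global_shape, spec):
            if axis is None:
                ranges.append((0, dim))  # Full extent along this dimension.
            else:
                step = dim // mesh_shape[axis]
                start = position[axis] * step
                ranges.append((start, start + step))
        blocks[coords] = tuple(ranges)
    return blocks


def overlap(a, b):
    # Intersection of two blocks, or None if they share no elements.
    ranges = tuple((max(s1, s2), min(e1, e2)) for (s1, e1), (s2, e2) in zip(a, b))
    return ranges if all(s < e for s, e in ranges) else None


mesh_shape = {"data": 2, "model": 2}
axes = ("data", "model")
old = device_blocks((32, 64), (None, "model"), mesh_shape, axes)
new = device_blocks((32, 64), ("data", None), mesh_shape, axes)

# Which saved blocks does device (data=0, model=0) need under the new layout?
for coords, block in old.items():
    print(coords, overlap(new[(0, 0)], block))
# (0, 0) ((0, 16), (0, 32))
# (0, 1) ((0, 16), (32, 64))
# (1, 0) ((0, 16), (0, 32))
# (1, 1) ((0, 16), (32, 64))
```

Devices (0, 0) and (1, 0) hold replicas of the same saved block, so the new row block needs the top half of each of the two distinct column blocks.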
| Parameter | Default | Description |
|---|---|---|
| directory | (required) | Path to the checkpoint directory. |
| monitor | "val_loss" | Metric to monitor for save_best_only. |
| mode | "auto" | "min", "max", or "auto". |
| save_best_only | False | Only save when the monitored metric improves. |
| save_freq | "epoch" | "epoch" or an integer (save every N batches). |
| max_to_keep | 1 | Maximum number of recent checkpoints to retain. |
| save_on_background | True | Save asynchronously to avoid blocking training. |
| save_weights_only | False | Save only weights (no model config/assets). |
| initial_value_threshold | None | Initial "best" value for the monitored metric. |
| verbose | 0 | Verbosity (0 = silent, 1 = messages). |
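- keras.callbacks.OrbaxCheckpoint is the built-in callback for Orbax checkpointing; no wrapper classes are needed.
- Use keras.saving.load_model() or model.load_weights() to restore from an Orbax checkpoint directory.
- Distributed checkpoints capture the sharded state defined by your LayoutMap and can be resharded on load under a different LayoutMap.
- The callback follows the familiar ModelCheckpoint API: save_best_only, monitor, mode, save_freq, max_to_keep.
# Clean up the distribution for any subsequent cells.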
keras.distribution.set_distribution(None)