RematScope
class keras.RematScope(mode="full", output_size_threshold=1024, layer_names=None)
A context manager for enabling rematerialization in Keras.
Rematerialization (gradient checkpointing) reduces memory use at the cost of extra computation: intermediate activations are recomputed during the backward pass instead of being stored. This is particularly useful for training large models or using large batch sizes when memory is limited.
This should be used when initializing the layer (e.g., layer(input)).
Rematerialization applies at execution time, not at creation time.
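The memory/compute trade-off described above can be illustrated without Keras. The following is a minimal pure-Python sketch, not how Keras implements rematerialization: the toy scalar "layers" and the helper names (layer, layer_grad, grad_store_all, grad_remat) are hypothetical and exist only for this illustration. It compares ordinary backpropagation, which keeps every activation, with a version that keeps only segment-boundary activations and recomputes the rest during the backward pass.

```python
import math


def layer(w, x):
    # Toy "layer" (an assumption for illustration): scalar weight, then tanh.
    return math.tanh(w * x)


def layer_grad(w, y):
    # Derivative of layer() w.r.t. its input, written via the output y.
    return w * (1.0 - y * y)


def grad_store_all(weights, x):
    """Plain backprop: keep every activation from the forward pass."""
    acts = [x]
    for w in weights:
        acts.append(layer(w, acts[-1]))
    grad = 1.0
    for w, y in zip(reversed(weights), reversed(acts[1:])):
        grad *= layer_grad(w, y)
    # gradient, number of stored values, number of layer evaluations
    return grad, len(acts), len(weights)


def grad_remat(weights, x, segment=4):
    """Keep only segment-boundary activations; recompute the rest later."""
    checkpoints = [x]
    evals = 0
    h = x
    for i, w in enumerate(weights, start=1):
        h = layer(w, h)
        evals += 1
        if i % segment == 0 and i < len(weights):
            checkpoints.append(h)

    grad = 1.0
    peak = len(checkpoints)
    starts = range(0, len(weights), segment)
    # Walk the segments from last to first, recomputing each one on demand.
    for start, ckpt in zip(reversed(starts), reversed(checkpoints)):
        seg_weights = weights[start:start + segment]
        seg_acts = [ckpt]
        for w in seg_weights:
            seg_acts.append(layer(w, seg_acts[-1]))
            evals += 1
        # Rough peak count: all checkpoints plus one segment's activations.
        peak = max(peak, len(checkpoints) + len(seg_acts) - 1)
        for w, y in zip(reversed(seg_weights), reversed(seg_acts[1:])):
            grad *= layer_grad(w, y)
    return grad, peak, evals
```

Both functions return the same gradient. The rematerialized version holds fewer values at once but evaluates each layer twice, which is the trade-off a RematScope opts into.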
Arguments
mode: Rematerialization mode to apply. Options:
    "full": Apply rematerialization globally to all supported operations.
    "activations": Apply rematerialization to activations on any layers that contain keras.activations (e.g., Dense(..., activation=relu)).
    "larger_than": Apply rematerialization to layers with output sizes larger than output_size_threshold.
    "list_of_layers": Apply rematerialization to a specific list of layer names.
    None: Disable rematerialization.
output_size_threshold: Output size threshold for the "larger_than" mode. Layers producing outputs larger than this threshold will be rematerialized. Default is 1024.
layer_names: List of layer names for the "list_of_layers" mode. Default is an empty list.
Examples
Using "list_of_layers" mode:
import keras
import tensorflow as tf
from keras import RematScope

input_tensor = tf.random.normal((1, 32, 32, 3))
with RematScope(mode="list_of_layers", layer_names=["dense_1", "conv2d_1"]):
    layer1 = keras.layers.Dense(128, name="dense_1")
    layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1")
    layer3 = keras.layers.Dense(64, name="dense_2")
    # Only layer1 and layer2 will apply rematerialization
    output1 = layer1(input_tensor)
    output2 = layer2(output1)
    output3 = layer3(output2)
Using "larger_than" mode with a specific output size threshold:
with RematScope(mode="larger_than", output_size_threshold=2048):
    layer = keras.layers.Conv2D(64, (3, 3))
    output = layer(input_tensor)  # Conv2D outputs larger than 2048
Nested scopes for fine-grained control:
with RematScope(mode="full"):
    # Create layers
    layer1 = keras.layers.Dense(128, activation='relu')
    output1 = layer1(input_tensor)  # layer1 is fully rematerialized
    with RematScope(mode="larger_than", output_size_threshold=512):
        layer2 = keras.layers.Conv2D(32, (3, 3))
        # layer2 is rematerialized only if its output size exceeds 512
        output2 = layer2(output1)
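As a rough mental model of how the modes described under Arguments select layers, the decision can be written as a small pure-Python function. This is only a sketch under stated assumptions, not Keras's actual logic: should_remat and its has_activation and output_size parameters are hypothetical names, output_size is assumed to be the number of elements in the layer output, and the "activations" branch only reports whether a layer would be affected at all.

```python
def should_remat(mode, layer_name=None, output_size=0, has_activation=False,
                 output_size_threshold=1024, layer_names=None):
    """Hypothetical helper: would a layer be affected under this mode?"""
    if mode is None:
        # Rematerialization disabled.
        return False
    if mode == "full":
        # Applies to all supported operations.
        return True
    if mode == "activations":
        # Only layers that carry an activation are affected.
        return has_activation
    if mode == "larger_than":
        # Strictly larger than the threshold, per the argument description.
        return output_size > output_size_threshold
    if mode == "list_of_layers":
        return layer_name in (layer_names or [])
    raise ValueError(f"Unknown mode: {mode!r}")
```

For instance, with the shapes used in the examples above, a Conv2D(64, (3, 3)) applied to a (1, 32, 32, 3) input produces 1 * 30 * 30 * 64 = 57600 output elements, so should_remat("larger_than", output_size=57600, output_size_threshold=2048) evaluates to True under these assumptions.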