Model export for inference

export method

Model.export(
    filepath, format="tf_saved_model", verbose=True, input_signature=None, **kwargs
)

Export the model as an artifact for inference.

Arguments

  • filepath: str or pathlib.Path object. The path to save the artifact.
  • format: str. The export format. Supported values: "tf_saved_model" and "onnx". Defaults to "tf_saved_model".
  • verbose: bool. Whether to print a message during export. Defaults to True.
  • input_signature: Optional. Specifies the shape and dtype of the model inputs. Can be a structure of keras.InputSpec, tf.TensorSpec, backend.KerasTensor, or backend tensor. If not provided, it will be automatically computed. Defaults to None.
  • **kwargs: Additional keyword arguments:
    • Specific to the JAX backend and format="tf_saved_model":
      • is_static: Optional bool. Indicates whether fn is static. Set to False if fn involves state updates (e.g., RNG seeds and counters).
      • jax2tf_kwargs: Optional dict. Arguments for jax2tf.convert. See the documentation for jax2tf.convert. If native_serialization and polymorphic_shapes are not provided, they will be automatically computed.

Note: This feature is currently supported only with the TensorFlow, JAX, and Torch backends.

Examples

Here's how to export a TensorFlow SavedModel for inference.

# Export the model as a TensorFlow SavedModel artifact
model.export("path/to/location", format="tf_saved_model")

# Load the artifact in a different process/environment
reloaded_artifact = tf.saved_model.load("path/to/location")
predictions = reloaded_artifact.serve(input_data)

Here's how to export an ONNX model for inference.

# Export the model as an ONNX artifact
model.export("path/to/location", format="onnx")

# Load the artifact in a different process/environment
ort_session = onnxruntime.InferenceSession("path/to/location")
ort_inputs = {
    k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
}
predictions = ort_session.run(None, ort_inputs)
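
The input_signature argument can also be provided explicitly instead of being computed automatically. A minimal sketch, assuming a model that takes a single batch of (None, 3) float32 inputs:

# The shape and dtype below are assumptions for this sketch; match them
# to your model's actual inputs.
model.export(
    "path/to/location",
    format="tf_saved_model",
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
    # On the JAX backend, is_static and jax2tf_kwargs could be passed
    # here as additional keyword arguments.
)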

ExportArchive class

keras.export.ExportArchive()

ExportArchive is used to write SavedModel artifacts (e.g. for inference).

If you have a Keras model or layer that you want to export as SavedModel for serving (e.g. via TensorFlow-Serving), you can use ExportArchive to configure the different serving endpoints you need to make available, as well as their signatures. Simply instantiate an ExportArchive, use track() to register the layer(s) or model(s) to be used, then use the add_endpoint() method to register a new serving endpoint. When done, use the write_out() method to save the artifact.

The resulting artifact is a SavedModel and can be reloaded via tf.saved_model.load.

Examples

Here's how to export a model for inference.

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.write_out("path/to/location")

# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)

Here's how to export a model with one endpoint for inference and one endpoint for a training-mode forward pass (e.g. with dropout on).

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="call_inference",
    fn=lambda x: model.call(x, training=False),
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.add_endpoint(
    name="call_training",
    fn=lambda x: model.call(x, training=True),
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.write_out("path/to/location")
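
As in the previous example, the reloaded artifact exposes each endpoint as a method. A short sketch, assuming inputs matching the (None, 3) float32 signature:

serving_model = tf.saved_model.load("path/to/location")
# Inference-mode forward pass (e.g. dropout off):
outputs = serving_model.call_inference(inputs)
# Training-mode forward pass (e.g. dropout on):
outputs = serving_model.call_training(inputs)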

Note on resource tracking:

ExportArchive is able to automatically track all keras.Variables used by its endpoints, so most of the time calling .track(model) is not strictly required. However, if your model uses lookup layers such as IntegerLookup, StringLookup, or TextVectorization, it will need to be tracked explicitly via .track(model).

Explicit tracking is also required if you need to be able to access the properties variables, trainable_variables, or non_trainable_variables on the revived archive.
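
For instance, a model containing a StringLookup layer must be tracked before its endpoint is registered. A minimal sketch (the vocabulary and shapes below are hypothetical):

lookup = keras.layers.StringLookup(vocabulary=["a", "b", "c"])
model = keras.Sequential([lookup])

export_archive = ExportArchive()
export_archive.track(model)  # Required: the lookup table is a non-variable asset
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None,), dtype="string")],
)
export_archive.write_out("path/to/location")

# Explicit tracking also makes variable properties available on the
# revived archive:
revived = tf.saved_model.load("path/to/location")
print(revived.variables)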


add_endpoint method

ExportArchive.add_endpoint(name, fn, input_signature=None, **kwargs)

Register a new serving endpoint.

Arguments

  • name: str. The name of the endpoint.
  • fn: A callable. It should only leverage resources (e.g. keras.Variable objects or tf.lookup.StaticHashTable objects) that are available on the models/layers tracked by the ExportArchive (you can call .track(model) to track a new model). The shape and dtype of the inputs to the function must be known. For that purpose, you can either 1) make sure that fn is a tf.function that has been called at least once, or 2) provide an input_signature argument that specifies the shape and dtype of the inputs (see below).
  • input_signature: Optional. Specifies the shape and dtype of the inputs to fn. Can be a structure of keras.InputSpec, tf.TensorSpec, backend.KerasTensor, or backend tensor (see below for an example showing a Functional model with 2 input arguments). If not provided, fn must be a tf.function that has been called at least once. Defaults to None.
  • **kwargs: Additional keyword arguments:
    • Specific to the JAX backend:
      • is_static: Optional bool. Indicates whether fn is static. Set to False if fn involves state updates (e.g., RNG seeds).
      • jax2tf_kwargs: Optional dict. Arguments for jax2tf.convert. See jax2tf.convert. If native_serialization and polymorphic_shapes are not provided, they are automatically computed.

Returns

The tf.function wrapping fn that was added to the archive.

Example

Adding an endpoint using the input_signature argument when the model has a single input argument:

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
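
Since add_endpoint() returns the wrapped tf.function, the registration above could also capture it for a quick sanity check. A brief sketch, assuming the same (None, 3) float32 input:

serve_fn = export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
# Call the traced function directly before writing the archive out:
print(serve_fn(tf.random.normal(shape=(2, 3))))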

Adding an endpoint using the input_signature argument when the model has two positional input arguments:

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        keras.InputSpec(shape=(None, 3), dtype="float32"),
        keras.InputSpec(shape=(None, 4), dtype="float32"),
    ],
)

Adding an endpoint using the input_signature argument when the model has one input argument that is a list of 2 tensors (e.g. a Functional model with 2 inputs):

model = keras.Model(inputs=[x1, x2], outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        [
            keras.InputSpec(shape=(None, 3), dtype="float32"),
            keras.InputSpec(shape=(None, 4), dtype="float32"),
        ],
    ],
)

This also works with dictionary inputs:

model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        {
            "x1": keras.InputSpec(shape=(None, 3), dtype="float32"),
            "x2": keras.InputSpec(shape=(None, 4), dtype="float32"),
        },
    ],
)
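
On the reloaded artifact, this endpoint is then invoked with a matching dict. A short sketch (the tensor values are placeholders):

serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(
    {
        "x1": tf.random.normal(shape=(2, 3)),
        "x2": tf.random.normal(shape=(2, 4)),
    }
)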

Adding an endpoint that is a tf.function:

@tf.function()
def serving_fn(x):
    return model(x)

# The function must be traced, i.e. it must be called at least once.
serving_fn(tf.random.normal(shape=(2, 3)))

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(name="serve", fn=serving_fn)

add_variable_collection method

ExportArchive.add_variable_collection(name, variables)

Register a set of variables to be retrieved after reloading.

Arguments

  • name: The string name for the collection.
  • variables: A tuple/list/set of keras.Variable instances.

Example

export_archive = ExportArchive()
export_archive.track(model)
# Register an endpoint
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
# Save a variable collection
export_archive.add_variable_collection(
    name="optimizer_variables", variables=model.optimizer.variables
)
export_archive.write_out("path/to/location")

# Reload the object
revived_object = tf.saved_model.load("path/to/location")
# Retrieve the variables
optimizer_variables = revived_object.optimizer_variables

track method

ExportArchive.track(resource)

Track the variables (and other assets) of a layer or model.

By default, all variables used by an endpoint function are automatically tracked when you call add_endpoint(). However, non-variable assets such as lookup tables need to be tracked manually. Note that lookup tables used by built-in Keras layers (TextVectorization, IntegerLookup, StringLookup) are automatically tracked in add_endpoint().

Arguments

  • resource: A trackable Keras resource, such as a layer or model.
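
track() also works on an individual layer. A minimal sketch, assuming a hypothetical TextVectorization preprocessor:

preprocessor = keras.layers.TextVectorization(max_tokens=1000)
preprocessor.adapt(["hello world", "keras export"])

export_archive = ExportArchive()
export_archive.track(preprocessor)  # The layer is a trackable Keras resource
export_archive.add_endpoint(
    name="preprocess",
    fn=preprocessor.call,
    input_signature=[keras.InputSpec(shape=(None,), dtype="string")],
)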

write_out method

ExportArchive.write_out(filepath, options=None, verbose=True)

Write the corresponding SavedModel to disk.

Arguments

  • filepath: str or pathlib.Path object. Path where to save the artifact.
  • options: tf.saved_model.SaveOptions object that specifies SavedModel saving options.
  • verbose: bool. Whether to print all the variables of the exported SavedModel. Defaults to True.

Note on TF-Serving: all endpoints registered via add_endpoint() are made visible for TF-Serving in the SavedModel artifact. In addition, the first endpoint registered is made visible under the alias "serving_default" (unless an endpoint with the name "serving_default" was already registered manually), since TF-Serving requires this endpoint to be set.
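
For example, the signatures exposed to TF-Serving can be inspected on the reloaded artifact; a quick sketch:

reloaded = tf.saved_model.load("path/to/location")
print(list(reloaded.signatures.keys()))
# The first registered endpoint also backs the "serving_default" signature:
serving_default = reloaded.signatures["serving_default"]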