save method
Model.save(
filepath,
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None,
options=None,
save_traces=True,
)
Saves the model to TensorFlow SavedModel or a single HDF5 file.
Please see tf.keras.models.save_model or the Serialization and Saving guide for details.
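Arguments
- filepath: String, PathLike, path to SavedModel or H5 file to save the model.
- overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.
- include_optimizer: If True, save optimizer's state together.
- save_format: Either 'tf' or 'h5', indicating whether to save the model to TensorFlow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and 'h5' in TF 1.X.
- signatures: Signatures to save with the SavedModel. Applicable to the 'tf' format only. Please see the signatures argument in tf.saved_model.save for details.
- options: (only applies to SavedModel format) tf.saved_model.SaveOptions object that specifies options for saving to SavedModel.
- save_traces: (only applies to SavedModel format) When enabled, the SavedModel will store the function traces for each layer. This can be disabled, so that only the configs of each layer are stored. Defaults to True. Disabling this will decrease serialization time and reduce file size, but it requires that all custom layers/models implement a get_config() method.
Example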
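import keras
from keras.models import load_model
model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')
model.save('my_model.h5')  # creates an HDF5 file 'my_model.h5'
del model  # deletes the existing model
# returns a compiled model
# identical to the previous one
model = load_model('my_model.h5')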
save_model function
tf.keras.models.save_model(
model,
filepath,
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None,
options=None,
save_traces=True,
)
Saves a model as a TensorFlow SavedModel or HDF5 file.
See the Serialization and Saving guide for details.
Usage:
>>> import numpy as np
>>> import tensorflow as tf
>>> model = tf.keras.Sequential([
... tf.keras.layers.Dense(5, input_shape=(3,)),
... tf.keras.layers.Softmax()])
>>> model.save('/tmp/model')
>>> loaded_model = tf.keras.models.load_model('/tmp/model')
>>> x = tf.random.uniform((10, 3))
>>> assert np.allclose(model.predict(x), loaded_model.predict(x))
Note that model.save() is an alias for tf.keras.models.save_model().
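The SavedModel and HDF5 file contains:
- the model's configuration (topology)
- the model's weights
- the model's optimizer's state (if any)
Thus models can be reinstantiated in the exact same state, without any of the code used for model definition or training.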
Note that the model weights may have different scoped names after being loaded. Scoped names include the model/layer names, such as "dense_1/kernel:0". It is recommended that you use the layer properties to access specific variables, e.g. model.get_layer("dense_1").kernel.
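A minimal sketch of the recommended access pattern, assuming TensorFlow 2.x with tf.keras. The layer name "dense_1" is set explicitly here (an assumption of the example) so the lookup does not depend on automatic layer naming.

```python
import tensorflow as tf

# Name the layer explicitly so get_layer() finds it regardless of auto-naming.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(5, name="dense_1"),
])

# Reach variables through layer properties instead of scoped variable names.
kernel = model.get_layer("dense_1").kernel
bias = model.get_layer("dense_1").bias
print(tuple(kernel.shape), tuple(bias.shape))  # (3, 5) (5,)
```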
SavedModel serialization format
Keras SavedModel uses tf.saved_model.save to save the model and all trackable objects attached to the model (e.g. layers and variables). The model config, weights, and optimizer are saved in the SavedModel. Additionally, for every Keras layer attached to the model, the SavedModel stores:
* the config and metadata -- e.g. name, dtype, trainable status
* traced call and loss functions, which are stored as TensorFlow subgraphs.
The traced functions allow the SavedModel format to save and load custom layers without the original class definition.
You can choose to not save the traced functions by disabling the save_traces option. This will decrease the time it takes to save the model and the amount of disk space occupied by the output SavedModel. If you disable this option, then you must provide all custom class definitions when loading the model. See the custom_objects argument in tf.keras.models.load_model.
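A sketch of supplying a custom class through custom_objects, assuming TensorFlow 2.x with tf.keras. It uses HDF5 rather than a SavedModel with save_traces=False because HDF5 never stores traced functions, so it needs the class definition in the same way. The Scale layer and the file name are hypothetical, chosen only for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by a constant."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Start from the base config (name, trainable, dtype, ...) and add our argument.
        config = super().get_config()
        config.update({"factor": self.factor})
        return config


inputs = tf.keras.Input(shape=(3,))
outputs = Scale(factor=3.0)(inputs)
model = tf.keras.Model(inputs, outputs)

path = os.path.join(tempfile.mkdtemp(), "scale_model.h5")
model.save(path)  # '.h5' suffix: HDF5 stores configs, not traced functions

# Without traces, Keras needs the class definition to rebuild the layer.
restored = tf.keras.models.load_model(path, custom_objects={"Scale": Scale})

x = np.random.uniform(size=(4, 3)).astype("float32")
print(np.allclose(model.predict(x, verbose=0), restored.predict(x, verbose=0)))
```

Leaving out custom_objects here would fail to deserialize the Scale layer, since nothing in the file lets Keras reconstruct call() on its own.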
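Arguments
- model: Keras model instance to be saved.
- filepath: One of the following:
  - String or pathlib.Path object, path where to save the model
  - h5py.File object where to save the model
- overwrite: Whether we should overwrite any existing model at the target location, or instead ask the user with a manual prompt.
- include_optimizer: If True, save optimizer's state together.
- save_format: Either 'tf' or 'h5', indicating whether to save the model to TensorFlow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and 'h5' in TF 1.X.
- signatures: Signatures to save with the SavedModel. Applicable to the 'tf' format only. Please see the signatures argument in tf.saved_model.save for details.
- options: (only applies to SavedModel format) tf.saved_model.SaveOptions object that specifies options for saving to SavedModel.
- save_traces: (only applies to SavedModel format) When enabled, the SavedModel will store the function traces for each layer. This can be disabled, so that only the configs of each layer are stored. Defaults to True. Disabling this will decrease serialization time and reduce file size, but it requires that all custom layers/models implement a get_config() method.
Raises
- ImportError: If save format is hdf5, and h5py is not available.
load_model function
tf.keras.models.load_model(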
filepath, custom_objects=None, compile=True, options=None
)
Loads a model saved via model.save().
Usage:
>>> import numpy as np
>>> import tensorflow as tf
>>> model = tf.keras.Sequential([
... tf.keras.layers.Dense(5, input_shape=(3,)),
... tf.keras.layers.Softmax()])
>>> model.save('/tmp/model')
>>> loaded_model = tf.keras.models.load_model('/tmp/model')
>>> x = tf.random.uniform((10, 3))
>>> assert np.allclose(model.predict(x), loaded_model.predict(x))
Note that the model weights may have different scoped names after being loaded. Scoped names include the model/layer names, such as "dense_1/kernel:0". It is recommended that you use the layer properties to access specific variables, e.g. model.get_layer("dense_1").kernel.
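Arguments
- filepath: One of the following:
  - String or pathlib.Path object, path to the saved model
  - h5py.File object from which to load the model
- custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization.
- compile: Boolean, whether to compile the model after loading.
- options: Optional tf.saved_model.LoadOptions object that specifies options for loading from SavedModel.
Returns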
A Keras model instance. If the original model was compiled, and saved with the optimizer, then the returned model will be compiled. Otherwise, the model will be left uncompiled. In the case that an uncompiled model is returned, a warning is displayed if the compile argument is set to True.
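Raises
- ImportError: If loading from an HDF5 file and h5py is not available.
- IOError: In case of an invalid savefile.
get_weights method
Model.get_weights()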
Retrieves the weights of the model.
Returns
A flat list of Numpy arrays.
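A short illustration of the flat ordering, assuming TensorFlow 2.x with tf.keras: the list spans all layers, and each Dense layer contributes its kernel followed by its bias.

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(4),
    tf.keras.layers.Dense(2),
])

# One NumPy array per variable, layer by layer: kernel, then bias.
weights = model.get_weights()
print([w.shape for w in weights])  # [(3, 4), (4,), (4, 2), (2,)]
```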
set_weights method
Model.set_weights(weights)
Sets the weights of the layer, from NumPy arrays.
The weights of a layer represent the state of the layer. This function sets the weight values from numpy arrays. The weight values should be passed in the order they are created by the layer. Note that the layer's weights must be instantiated before calling this function, by calling the layer.
For example, a Dense layer returns a list of two values: the kernel matrix and the bias vector. These can be used to set the weights of another Dense layer:
>>> layer_a = tf.keras.layers.Dense(1,
...   kernel_initializer=tf.constant_initializer(1.))
>>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
>>> layer_a.get_weights()
[array([[1.],
       [1.],
       [1.]], dtype=float32), array([0.], dtype=float32)]
>>> layer_b = tf.keras.layers.Dense(1,
...   kernel_initializer=tf.constant_initializer(2.))
>>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
>>> layer_b.get_weights()
[array([[2.],
       [2.],
       [2.]], dtype=float32), array([0.], dtype=float32)]
>>> layer_b.set_weights(layer_a.get_weights())
>>> layer_b.get_weights()
[array([[1.],
       [1.],
       [1.]], dtype=float32), array([0.], dtype=float32)]
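Arguments
- weights: a list of NumPy arrays. The number of arrays and their shapes must match the number and shapes of the layer's weights (i.e. they should match the output of get_weights).
Raises
- ValueError: If the provided weights list does not match the layer's specifications.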
save_weights method
Model.save_weights(filepath, overwrite=True, save_format=None, options=None)
Saves all layer weights.
Either saves in HDF5 or in TensorFlow format based on the save_format
argument.
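The default choice of format can be restated as a small pure-Python function. resolve_weights_format is a hypothetical helper that encodes only the documented rule (a filepath ending in '.h5' or '.keras' means HDF5 when save_format is None, otherwise 'tf'); it is not Keras's own implementation and does not model every edge case Keras checks.

```python
import os
from pathlib import Path


def resolve_weights_format(filepath, save_format=None):
    """Hypothetical helper restating the documented save_weights format rule."""
    if save_format is not None:
        # An explicit, valid choice is used as given. This sketch does not model
        # how Keras treats an explicit 'tf' combined with an HDF5-looking name.
        if save_format not in ("tf", "h5"):
            raise ValueError(f"Unknown save_format: {save_format!r}")
        return save_format
    # os.fspath accepts both str and pathlib.Path objects.
    if os.fspath(filepath).endswith((".h5", ".keras")):
        return "h5"
    return "tf"


print(resolve_weights_format("weights.h5"))          # h5
print(resolve_weights_format(Path("run/w.keras")))   # h5
print(resolve_weights_format("ckpt/weights"))        # tf
```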
When saving in HDF5 format, the weight file has:
- layer_names (attribute), a list of strings (ordered names of model layers).
- For every layer, a group named layer.name
- For every such layer group, a group attribute weight_names, a list of strings (ordered names of the layer's weight tensors).
- For every weight in the layer, a dataset storing the weight value, named after the weight tensor.
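To make the nesting concrete, the layout above can be mirrored with plain Python dictionaries. hdf5_weights_layout is a hypothetical helper, and the keys "attrs", "groups" and "datasets" are stand-ins for HDF5 attributes, groups and datasets; the sketch neither reads nor writes a real file (that would need h5py).

```python
import numpy as np


def hdf5_weights_layout(layers):
    """Hypothetical mirror of the documented HDF5 weight-file layout.

    `layers` is an ordered list of (layer_name, [(weight_name, value), ...]).
    """
    root = {
        # Root attribute: ordered names of the model's layers.
        "attrs": {"layer_names": [name for name, _ in layers]},
        "groups": {},
    }
    for layer_name, weights in layers:
        root["groups"][layer_name] = {
            # Group attribute: ordered names of this layer's weight tensors.
            "attrs": {"weight_names": [w_name for w_name, _ in weights]},
            # One dataset per weight, named after the weight tensor.
            "datasets": {w_name: np.asarray(value) for w_name, value in weights},
        }
    return root


layout = hdf5_weights_layout([
    ("dense", [("dense/kernel:0", np.ones((3, 1))), ("dense/bias:0", np.zeros(1))]),
    ("softmax", []),  # in this mirror, a weightless layer gets an empty group
])
print(layout["attrs"]["layer_names"])  # ['dense', 'softmax']
```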
When saving in TensorFlow format, all objects referenced by the network are saved in the same format as tf.train.Checkpoint, including any Layer instances or Optimizer instances assigned to object attributes. For networks constructed from inputs and outputs using tf.keras.Model(inputs, outputs), Layer instances used by the network are tracked/saved automatically. For user-defined classes which inherit from tf.keras.Model, Layer instances must be assigned to object attributes, typically in the constructor. See the documentation of tf.train.Checkpoint and tf.keras.Model for details.
While the formats are the same, do not mix save_weights and tf.train.Checkpoint. Checkpoints saved by Model.save_weights should be loaded using Model.load_weights. Checkpoints saved using tf.train.Checkpoint.save should be restored using the corresponding tf.train.Checkpoint.restore. Prefer tf.train.Checkpoint over save_weights for training checkpoints.
The TensorFlow format matches objects and variables by starting at a root object, self for save_weights, and greedily matching attribute names. For Model.save this is the Model, and for Checkpoint.save this is the Checkpoint even if the Checkpoint has a model attached. This means saving a tf.keras.Model using save_weights and loading into a tf.train.Checkpoint with a Model attached (or vice versa) will not match the Model's variables. See the guide to training checkpoints for details on the TensorFlow format.
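Arguments
- filepath: String or PathLike, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.
- overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.
- save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.
- options: Optional tf.train.CheckpointOptions object that specifies options for saving weights.
Raises
- ImportError: If h5py is not available when attempting to save in HDF5 format.
load_weights method
Model.load_weights(filepath, by_name=False, skip_mismatch=False, options=None)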
Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
If by_name is False, weights are loaded based on the network's topology. This means the architecture should be the same as when the weights were saved. Note that layers that don't have weights are not taken into account in the topological ordering, so adding or removing layers is fine as long as they don't have weights.
If by_name is True, weights are loaded into layers only if they share the same name. This is useful for fine-tuning or transfer-learning models where some of the layers have changed.
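The two matching strategies can be sketched in pure Python. match_layers is a hypothetical helper that models each layer only as a (name, number of weights) pair; it illustrates the documented rules (weightless layers are ignored topologically, by_name pairs equal names) rather than Keras's actual loader.

```python
def match_layers(saved, target, by_name=False):
    """Hypothetical sketch of which saved layer feeds which target layer.

    `saved` and `target` are ordered lists of (layer_name, num_weights).
    Returns (saved_layer, target_layer) pairs.
    """
    if by_name:
        # Only layers whose names appear in both lists receive weights.
        saved_names = {name for name, n in saved if n > 0}
        return [(name, name) for name, n in target if n > 0 and name in saved_names]
    # Topological: skip weightless layers, then pair the rest in order.
    saved_order = [name for name, n in saved if n > 0]
    target_order = [name for name, n in target if n > 0]
    if len(saved_order) != len(target_order):
        raise ValueError("different number of layers with weights")
    return list(zip(saved_order, target_order))


saved = [("dense_a", 2), ("dropout", 0), ("dense_b", 2)]
# Removing the weightless Dropout layer does not disturb topological loading.
print(match_layers(saved, [("dense_a", 2), ("dense_b", 2)]))
# By name, a renamed head is simply left untouched.
print(match_layers(saved, [("dense_a", 2), ("new_head", 2)], by_name=True))
```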
Only topological loading (by_name=False) is supported when loading weights from the TensorFlow format. Note that topological loading differs slightly between TensorFlow and HDF5 formats for user-defined classes inheriting from tf.keras.Model: HDF5 loads based on a flattened list of weights, while the TensorFlow format loads based on the object-local names of attributes to which layers are assigned in the Model's constructor.
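Arguments
- filepath: String, path to the weights file to load. For weight files in TensorFlow format, this is the file prefix (the same as was passed to save_weights). This can also be a path to a SavedModel saved from model.save.
- by_name: Boolean, whether to load weights by name or by topological order. Only topological loading is supported for weight files in TensorFlow format.
- skip_mismatch: Boolean, whether to skip loading of layers where there is a mismatch in the number of weights, or a mismatch in the shape of the weight (only valid when by_name=True).
- options: Optional tf.train.CheckpointOptions object that specifies options for loading weights.
Returns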
When loading a weight file in TensorFlow format, returns the same status object as tf.train.Checkpoint.restore. When graph building, restore ops are run automatically as soon as the network is built (on first call for user-defined classes inheriting from Model, immediately if it is already built).
When loading weights in HDF5 format, returns None.
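A round-trip sketch of save_weights followed by topological load_weights, assuming TensorFlow 2.x with tf.keras. build_model is a hypothetical helper, and the file name model.weights.h5 is an arbitrary choice that ends in '.h5', so with save_format=None the weights are written as HDF5.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Rebuild the same architecture, as topological loading requires.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(3,)),
        tf.keras.layers.Dense(2),
    ])


source = build_model()
path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
source.save_weights(path)  # '.h5' suffix with save_format=None -> HDF5

target = build_model()  # freshly initialized weights
target.load_weights(path)  # by_name=False: topological loading

same = all(np.array_equal(a, b) for a, b in zip(source.get_weights(), target.get_weights()))
print(same)  # True
```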
Raises
- ImportError: If h5py is not available and the weight file is in HDF5 format.
- ValueError: If skip_mismatch is set to True when by_name is False.
get_config method
Model.get_config()
Returns the config of the Model.
Config is a Python dictionary (serializable) containing the configuration of an object, which in this case is a Model. This allows the Model to be reinstantiated later (without its trained weights) from this configuration.
Note that get_config() does not guarantee to return a fresh copy of dict every time it is called. The callers should make a copy of the returned dict if they want to modify it.
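A sketch of the defensive copy, assuming TensorFlow 2.x with tf.keras. The model names "original" and "renamed" are illustrative; copy.deepcopy is used so that editing the nested dict cannot leak back into whatever get_config() returned.

```python
import copy

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(5),
], name="original")

# Copy before editing, since get_config() may hand back a shared dict.
config = copy.deepcopy(model.get_config())
config["name"] = "renamed"

rebuilt = tf.keras.Sequential.from_config(config)  # new model, fresh weights
print(model.name, rebuilt.name)  # original renamed
```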
Developers of subclassed Model are advised to override this method, and continue to update the dict from super(MyModel, self).get_config() to provide the proper configuration of this Model. The default config is an empty dict. Optionally, raise NotImplementedError to allow Keras to attempt a default serialization.
Returns
Python dictionary containing the configuration of this Model.
from_config method
Model.from_config(config, custom_objects=None)
Creates a layer from its config.
This method is the reverse of get_config, capable of instantiating the same layer from the config dictionary. It does not handle layer connectivity (handled by Network), nor weights (handled by set_weights).
Arguments
- config: A Python dictionary, typically the output of get_config.
Returns
A layer instance.
model_from_config function
tf.keras.models.model_from_config(config, custom_objects=None)
Instantiates a Keras model from its config.
Usage:
# for a Functional API model
tf.keras.Model.from_config(model.get_config())
# for a Sequential model
tf.keras.Sequential.from_config(model.get_config())
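Arguments
- config: Configuration dictionary.
- custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization.
Returns
A Keras model instance (uncompiled).
Raises
- TypeError: if config is not a dictionary.
to_json method
Model.to_json(**kwargs)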
Returns a JSON string containing the network configuration.
To load a network from a JSON save file, use keras.models.model_from_json(json_string, custom_objects={}).
Arguments
- **kwargs: Additional keyword arguments to be passed to json.dumps().
Returns
A JSON string.
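A sketch of forwarding a json.dumps() keyword through to_json and reading the result back, assuming TensorFlow 2.x with tf.keras; indent=2 is just an example keyword.

```python
import json

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(5),
    tf.keras.layers.Softmax(),
])

# Keyword arguments such as indent are forwarded to json.dumps().
json_string = model.to_json(indent=2)
print(json.loads(json_string)["class_name"])  # Sequential

# The JSON holds only the architecture: the rebuilt model is uncompiled
# and its weights are freshly initialized.
rebuilt = tf.keras.models.model_from_json(json_string)
```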
model_from_json function
tf.keras.models.model_from_json(json_string, custom_objects=None)
Parses a JSON model configuration string and returns a model instance.
Usage:
>>> model = tf.keras.Sequential([
... tf.keras.layers.Dense(5, input_shape=(3,)),
... tf.keras.layers.Softmax()])
>>> config = model.to_json()
>>> loaded_model = tf.keras.models.model_from_json(config)
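Arguments
- json_string: JSON string encoding a model configuration.
- custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization.
Returns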
A Keras model instance (uncompiled).
clone_model function
tf.keras.models.clone_model(model, input_tensors=None, clone_function=None)
Clone a Functional or Sequential Model instance.
Model cloning is similar to calling a model on new inputs, except that it creates new layers (and thus new weights) instead of sharing the weights of the existing layers.
Note that clone_model will not preserve the uniqueness of shared objects within the model (e.g. a single variable attached to two distinct layers will be restored as two separate variables).
Arguments
- model: Instance of Model (could be a Functional model or a Sequential model).
- input_tensors: optional list of input tensors or InputLayer objects to build the model upon. If not provided, new Input objects will be created.
- clone_function: Callable to be used to clone each layer in the target model (except InputLayer instances). It takes as argument the layer instance to be cloned, and returns the corresponding layer instance to be used in the model copy. If unspecified, this callable defaults to the following serialization/deserialization function: lambda layer: layer.__class__.from_config(layer.get_config()). By passing a custom callable, you can customize your copy of the model, e.g. by wrapping certain layers of interest (you might want to replace all LSTM instances with equivalent Bidirectional(LSTM(...)) instances, for example).
Returns
An instance of Model reproducing the behavior of the original model, on top of new inputs tensors, using newly instantiated weights. The cloned model may behave differently from the original model if a custom clone_function modifies the layer.
Example
import keras
from keras.models import clone_model
# Create a test Sequential model.
model = keras.Sequential([
keras.Input(shape=(728,)),
keras.layers.Dense(32, activation='relu'),
keras.layers.Dense(1, activation='sigmoid'),
])
# Create a copy of the test model (with freshly initialized weights).
new_model = clone_model(model)
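A sketch of a custom clone_function, assuming TensorFlow 2.x with tf.keras. It follows the default serialization/deserialization approach but edits the config of one layer first; the layer names "hidden" and "output" and the helper tanh_hidden are hypothetical choices for the example.

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(4, activation="relu", name="hidden"),
    tf.keras.layers.Dense(1, activation="sigmoid", name="output"),
])


def tanh_hidden(layer):
    # Rebuild every layer from its config, as the default clone_function does,
    # but switch the activation of the layer named "hidden" to tanh.
    config = layer.get_config()
    if layer.name == "hidden":
        config["activation"] = "tanh"
    return layer.__class__.from_config(config)


tanh_model = tf.keras.models.clone_model(model, clone_function=tanh_hidden)
print(tanh_model.get_layer("hidden").activation.__name__)  # tanh
```

Because the clone is built from configs, its weights are freshly initialized; copy them over with set_weights if the original values are needed.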
Note that subclassed models cannot be cloned, since their internal layer structure is not known. To achieve equivalent functionality as clone_model in the case of a subclassed model, simply make sure that the model class implements get_config() (and optionally from_config()), and call:
new_model = model.__class__.from_config(model.get_config())