plot_model function
tf_keras.utils.plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
)
Converts a TF-Keras model to dot format and saves it to a file.
Example
import tensorflow as tf

# Named `inputs` to avoid shadowing Python's built-in input().
inputs = tf.keras.Input(shape=(100,), dtype='int32', name='input')
# Map each of the 100 integer token IDs (vocabulary of 10,000) to a 512-d vector.
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(inputs)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[inputs], outputs=[output])

# Plotting requires pydot and Graphviz to be installed.
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
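The plot is produced from a graph description in Graphviz's DOT language, which is then rendered to the image file. To show roughly what such DOT text looks like, here is a minimal hand-written sketch for a linear stack of layers, with rankdir and dpi set as graph attributes. The helper chain_to_dot is hypothetical and not part of tf_keras; the real graph is built with pydot and carries more per-node information (shapes, dtypes, activations).

```python
# Hypothetical helper for illustration only; tf_keras builds its graph with pydot.
GRAPHVIZ_RANKDIRS = ("TB", "LR", "BT", "RL")  # the docs above describe TB and LR


def chain_to_dot(layer_names, rankdir="TB", dpi=96):
    """Return DOT source for a linear stack of layers (one edge per adjacent pair)."""
    if rankdir not in GRAPHVIZ_RANKDIRS:
        raise ValueError(f"Unsupported rankdir: {rankdir!r}")
    lines = [
        "digraph G {",
        # Graph-level attributes: layout direction and output resolution.
        f'  graph [rankdir="{rankdir}", dpi={dpi}];',
        "  node [shape=box];",
    ]
    # Positional node IDs keep the DOT valid even if two layers share a label.
    for i, name in enumerate(layer_names):
        lines.append(f'  n{i} [label="{name}"];')
    # Connect each layer to the next one in the stack.
    for i in range(len(layer_names) - 1):
        lines.append(f"  n{i} -> n{i + 1};")
    lines.append("}")
    return "\n".join(lines)


print(chain_to_dot(["input", "embedding", "lstm", "dense", "output"], rankdir="LR"))
```

Rendering the printed text with the Graphviz command-line tool (for example, dot -Tpng) gives a left-to-right chain; with rankdir="TB" the same nodes are stacked vertically.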
Arguments
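model: A TF-Keras model instance.
to_file: File name of the plot image.
show_shapes: Whether to display shape information.
show_dtype: Whether to display layer dtypes.
show_layer_names: Whether to display layer names.
rankdir: rankdir argument passed to PyDot, a string specifying the format of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot.
expand_nested: Whether to expand nested models into clusters.
dpi: Dots per inch.
layer_range: A list containing two str items, the starting layer name and the ending layer name (both inclusive), indicating the range of layers for which the plot will be generated. Regex patterns are also accepted instead of exact names; in that case, the start predicate is the first layer that matches layer_range[0] and the end predicate is the last layer that matches layer_range[1]. Defaults to None, which considers all layers of the model. The range must be chosen so that the resulting subgraph is complete.
show_layer_activations: Whether to display layer activations (only for layers that have an activation property).
show_trainable: Whether to display if each layer is trainable.
Raises
ValueError: if plot_model is called before the model is built.
Returns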
A Jupyter notebook Image object if Jupyter is installed. This enables in-line display of the model plots in notebooks.
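The layer_range rule described above (the first layer matching layer_range[0] starts the range, the last layer matching layer_range[1] ends it, both inclusive) can be sketched over a plain list of layer names. resolve_layer_range is a hypothetical helper, not the tf_keras function. It assumes re.match semantics (each pattern is anchored at the start of the name) and simply raises when the start would come after the end, which the real implementation may handle differently.

```python
import re


def resolve_layer_range(layer_names, layer_range=None):
    """Hypothetical sketch of the documented layer_range rule.

    Returns (start, end) indices into layer_names, both inclusive.
    Assumes re.match semantics; the real tf_keras helper may differ.
    """
    if layer_range is None:
        # None means "all layers of the model".
        return 0, len(layer_names) - 1
    if (not isinstance(layer_range, (list, tuple)) or len(layer_range) != 2
            or not all(isinstance(p, str) for p in layer_range)):
        raise ValueError("layer_range must be a list of two str items")
    start_pattern, end_pattern = layer_range
    starts = [i for i, n in enumerate(layer_names) if re.match(start_pattern, n)]
    ends = [i for i, n in enumerate(layer_names) if re.match(end_pattern, n)]
    if not starts or not ends:
        raise ValueError("layer_range does not match any layer names")
    # Start predicate: first match of layer_range[0];
    # end predicate: last match of layer_range[1].
    start, end = starts[0], ends[-1]
    if start > end:
        raise ValueError("layer_range start comes after its end")
    return start, end


names = ["input", "embedding", "lstm", "dense", "dense_1", "dense_2", "output"]
print(resolve_layer_range(names, ["embedding", "dense.*"]))  # (1, 5)
```

Here "dense.*" matches dense, dense_1 and dense_2, so the range ends at the last of them; the output layer is left out.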
model_to_dot function
tf_keras.utils.model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    subgraph=False,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
)
Convert a TF-Keras model to dot format.
Arguments
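model: A TF-Keras model instance.
show_shapes: Whether to display shape information.
show_dtype: Whether to display layer dtypes.
show_layer_names: Whether to display layer names.
rankdir: rankdir argument passed to PyDot, a string specifying the format of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot.
expand_nested: Whether to expand nested models into clusters.
dpi: Dots per inch.
subgraph: Whether to return a pydot.Cluster instance.
layer_range: A list containing two str items, the starting layer name and the ending layer name (both inclusive), indicating the range of layers for which the pydot.Dot will be generated. Regex patterns are also accepted instead of exact names; in that case, the start predicate is the first layer that matches layer_range[0] and the end predicate is the last layer that matches layer_range[1]. Defaults to None, which considers all layers of the model. The range must be chosen so that the resulting subgraph is complete.
show_layer_activations: Whether to display layer activations (only for layers that have an activation property).
show_trainable: Whether to display if each layer is trainable.
Returns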
A pydot.Dot instance representing the TF-Keras model, or a pydot.Cluster instance representing a nested model if subgraph=True.
Raises
ValueError: if model_to_dot is called before the model is built.
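With subgraph=True, model_to_dot returns a pydot.Cluster rather than a pydot.Dot. In DOT terms, Graphviz draws any subgraph whose name begins with "cluster" as a boxed group, which is what lets a nested model appear as its own group inside a parent graph. The sketch below is a hypothetical, dependency-free illustration of that idea for a linear stack containing nested models; nested_to_dot is not a tf_keras function and keeps nothing but layer names.

```python
# Hypothetical sketch: drawing a nested model as a Graphviz cluster.
# NOT the tf_keras implementation, which works with pydot.Dot / pydot.Cluster objects.
def nested_to_dot(layers, rankdir="TB"):
    """Return DOT source for a linear stack that may contain nested models.

    Each item in `layers` is either a layer name (str) or a
    (model_name, [inner_layer_names]) tuple for a nested model.
    Model names are assumed to be valid DOT identifiers (letters, digits, _).
    """
    body, chain = [], []  # chain: node IDs in forward order
    for item in layers:
        if isinstance(item, str):
            node_id = f"n{len(chain)}"
            body.append(f'  {node_id} [label="{item}"];')
            chain.append(node_id)
        else:
            model_name, inner = item
            # A "cluster_" prefix makes Graphviz draw a box around these nodes.
            body.append(f"  subgraph cluster_{model_name} {{")
            body.append(f'    label="{model_name}";')
            for name in inner:
                node_id = f"n{len(chain)}"
                body.append(f'    {node_id} [label="{name}"];')
                chain.append(node_id)
            body.append("  }")
    # Connect consecutive nodes, crossing cluster boundaries where needed.
    edges = [f"  {a} -> {b};" for a, b in zip(chain, chain[1:])]
    header = ["digraph G {", f'  graph [rankdir="{rankdir}"];']
    return "\n".join(header + body + edges + ["}"])


print(nested_to_dot(["input", ("encoder", ["dense", "dense_1"]), "output"]))
```

Edges are emitted outside the cluster block so that a single chain can enter and leave the nested model without duplicating nodes.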