plot_model
tf.keras.utils.plot_model(
model,
to_file="model.png",
show_shapes=False,
show_dtype=False,
show_layer_names=True,
rankdir="TB",
expand_nested=False,
dpi=96,
)
Converts a Keras model to dot format and saves it to a file.
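"Dot format" means DOT, the plain-text graph description language that Graphviz renders into images. The sketch below is a hypothetical illustration, assuming a model can be reduced to a linear chain of named layers. The helper `chain_to_dot` is not part of Keras, which walks the real layer graph through pydot instead. It shows how a `rankdir` value ends up as a graph attribute and how a shape can be attached to each node label when something like `show_shapes` is enabled.

```python
def chain_to_dot(layers, rankdir="TB", show_shapes=False):
    """Emit DOT text for a linear chain of (name, output_shape) pairs.

    Hypothetical helper for illustration only; not a Keras or pydot API.
    """
    # This sketch accepts only the two rankdir values documented for plot_model.
    if rankdir not in ("TB", "LR"):
        raise ValueError("rankdir must be 'TB' or 'LR'")
    lines = ["digraph G {", f"  rankdir={rankdir};", "  node [shape=record];"]
    for i, (name, shape) in enumerate(layers):
        # In a record node, '|' splits the label into name and shape fields.
        label = f"{name}|{shape}" if show_shapes else name
        lines.append(f'  n{i} [label="{label}"];')
    # Connect each layer to the next one in the chain.
    for i in range(len(layers) - 1):
        lines.append(f"  n{i} -> n{i + 1};")
    lines.append("}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(chain_to_dot(
        [("input", "(None, 100)"), ("embedding", "(None, 100, 512)"), ("output", "(None, 1)")],
        rankdir="LR",
        show_shapes=True,
    ))
```

Feeding the printed text to the Graphviz `dot` command would draw the chain left to right; switching to `rankdir="TB"` stacks it vertically.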
Example
import tensorflow as tf

# plot_model needs the pydot package and the Graphviz binaries to be installed.
inputs = tf.keras.Input(shape=(100,), dtype='int32', name='input')
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(inputs)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[inputs], outputs=[output])
dot_img_file = '/tmp/model_1.png'
# Writes the rendered graph to dot_img_file, including each layer's shapes.
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
Arguments
rankdir: the rankdir argument passed to PyDot, a string specifying the orientation of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot.
Returns
A Jupyter notebook Image object if Jupyter is installed. This enables in-line display of the model plots in notebooks.
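The Returns note describes an optional dependency: the inline image is only available when the IPython display machinery can be imported. The sketch below shows that pattern in isolation, assuming the availability check is a plain import test. `inline_image_or_none` is a hypothetical helper, not the code Keras itself runs.

```python
import importlib.util


def inline_image_or_none(path):
    """Return an IPython Image for `path` if IPython is importable, else None.

    Hypothetical helper illustrating an optional-import fallback; not Keras code.
    """
    if importlib.util.find_spec("IPython") is None:
        # No IPython: the file written to disk is the only output.
        return None
    from IPython.display import Image
    # In a notebook, returning this object from a cell displays the image inline.
    return Image(filename=path)
```

Outside a notebook, callers can simply ignore the return value and open the saved file.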
model_to_dot
tf.keras.utils.model_to_dot(
model,
show_shapes=False,
show_dtype=False,
show_layer_names=True,
rankdir="TB",
expand_nested=False,
dpi=96,
subgraph=False,
)
Converts a Keras model to dot format.
Arguments
rankdir: the rankdir argument passed to PyDot, a string specifying the orientation of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot.
subgraph: whether to return a pydot.Cluster instance.
Returns
A pydot.Dot instance representing the Keras model, or a pydot.Cluster instance representing a nested model if subgraph=True.
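The Dot/Cluster distinction maps onto a Graphviz convention: a standalone graph is a top-level `digraph`, while a subgraph whose name starts with `cluster` is drawn inside its own box when embedded in a parent graph. The sketch below emits raw DOT text for both cases under the simplifying assumption that each model is a linear chain of layer names. `model_block_to_dot` and `nested_model_to_dot` are hypothetical helpers, not Keras or pydot APIs, and the `cluster_` prefix follows the Graphviz naming rule rather than any documented detail of model_to_dot's output.

```python
def model_block_to_dot(name, layers, subgraph=False):
    """Emit DOT text for a chain of layer names.

    subgraph=False gives a standalone digraph; subgraph=True gives a cluster
    meant to be embedded in a parent graph. Hypothetical helper, not a real API.
    """
    header = f"subgraph cluster_{name} {{" if subgraph else f"digraph {name} {{"
    # Clusters get a visible label so the nested model is identifiable in the plot.
    body = [f'  label="{name}";'] if subgraph else []
    nodes = [f'  {name}_{i} [label="{layer}"];' for i, layer in enumerate(layers)]
    edges = [f"  {name}_{i} -> {name}_{i + 1};" for i in range(len(layers) - 1)]
    return "\n".join([header, *body, *nodes, *edges, "}"])


def nested_model_to_dot(outer, inner, inner_layers):
    """Standalone graph with the inner model embedded as a cluster."""
    cluster = model_block_to_dot(inner, inner_layers, subgraph=True)
    # Indent the cluster so it sits inside the outer graph's braces.
    indented = "\n".join("  " + line for line in cluster.splitlines())
    return "\n".join([
        f"digraph {outer} {{",
        '  input [label="input"];',
        '  output [label="output"];',
        indented,
        # Wire the outer graph's nodes to the first and last inner layers.
        f"  input -> {inner}_0;",
        f"  {inner}_{len(inner_layers) - 1} -> output;",
        "}",
    ])
```

Returning the nested model as a cluster rather than a full graph is what allows it to be spliced into the parent's DOT text, which is the reason a separate subgraph mode is useful.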
Raises