Model plotting utilities

plot_model function

tf.keras.utils.plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
)

Converts a Keras model to dot format and saves it to a file.

Example

import tensorflow as tf

# Functional model: integer token IDs -> embedding -> LSTM -> dense stack.
inputs = tf.keras.Input(shape=(100,), dtype='int32', name='input')
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(inputs)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[inputs], outputs=[output])
# Requires pydot and Graphviz; renders the graph and writes it to dot_img_file.
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

Arguments

  • model: A Keras model instance.
  • to_file: File name of the plot image.
  • show_shapes: whether to display shape information.
  • show_dtype: whether to display layer dtypes.
  • show_layer_names: whether to display layer names.
  • rankdir: rankdir argument passed to PyDot, a string specifying the direction of the plot: 'TB' creates a vertical (top-to-bottom) plot; 'LR' creates a horizontal (left-to-right) plot.
  • expand_nested: whether to expand nested models into clusters.
  • dpi: Dots per inch.

Returns

A Jupyter notebook Image object if Jupyter is installed. This enables in-line display of the model plots in notebooks.
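To make the options above more concrete, here is a minimal sketch in plain Python that builds DOT text for a simple chain of layers. The helper `toy_model_to_dot` and its `(name, class_name, output_shape)` input format are hypothetical and not part of Keras; the real function walks a model object through pydot, and its labels and layout differ in detail. The sketch only illustrates, under that assumption, how `show_shapes`, `show_layer_names`, `rankdir`, and `dpi` can map onto DOT attributes.

```python
# Hypothetical helper, NOT part of Keras: emits DOT text for a linear chain
# of layers to show how the plotting options map onto Graphviz attributes.
def toy_model_to_dot(layers, show_shapes=False, show_layer_names=True,
                     rankdir="TB", dpi=96):
    """layers: list of (name, class_name, output_shape) tuples, in order."""
    lines = [
        "digraph G {",
        # rankdir and dpi are graph-level attributes in DOT.
        f'  graph [rankdir="{rankdir}", dpi={dpi}];',
        # "record" nodes let a label hold several "|"-separated fields.
        "  node [shape=record];",
    ]
    for i, (name, class_name, shape) in enumerate(layers):
        label = f"{name}: {class_name}" if show_layer_names else class_name
        if show_shapes:
            label += f" | output: {shape}"
        lines.append(f'  n{i} [label="{label}"];')
    # Connect each layer to the next one.
    for i in range(len(layers) - 1):
        lines.append(f"  n{i} -> n{i + 1};")
    lines.append("}")
    return "\n".join(lines)


layers = [
    ("input", "InputLayer", "(None, 100)"),
    ("embedding", "Embedding", "(None, 100, 512)"),
    ("output", "Dense", "(None, 1)"),
]
print(toy_model_to_dot(layers, show_shapes=True, rankdir="LR"))
```

The resulting text can be saved to a `.dot` file and rendered with the Graphviz command line, for example `dot -Tpng model.dot -o model.png`.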


model_to_dot function

tf.keras.utils.model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    subgraph=False,
)

Converts a Keras model to dot format.

Arguments

  • model: A Keras model instance.
  • show_shapes: whether to display shape information.
  • show_dtype: whether to display layer dtypes.
  • show_layer_names: whether to display layer names.
  • rankdir: rankdir argument passed to PyDot, a string specifying the direction of the plot: 'TB' creates a vertical (top-to-bottom) plot; 'LR' creates a horizontal (left-to-right) plot.
  • expand_nested: whether to expand nested models into clusters.
  • dpi: Dots per inch.
  • subgraph: whether to return a pydot.Cluster instance.

Returns

A pydot.Dot instance representing the Keras model, or a pydot.Cluster instance representing the nested model if subgraph=True.
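The difference between the two return types can be pictured in plain DOT text. In the sketch below, the helper `toy_to_dot` is hypothetical and works on strings rather than pydot objects: a top-level model becomes a `digraph`, while a model rendered with `subgraph=True` becomes a `subgraph cluster_...` block that can be embedded inside an enclosing graph. Graphviz draws a box around a subgraph only when its name starts with `cluster`.

```python
def toy_to_dot(name, layer_names, subgraph=False):
    """Hypothetical helper (not Keras): DOT text for a chain of layers."""
    # Graphviz boxes a subgraph only if its name starts with "cluster".
    header = f"subgraph cluster_{name} {{" if subgraph else "digraph G {"
    body = [f'  label="{name}";'] if subgraph else []
    # Prefix node IDs with the model name so nested models don't clash.
    body += [f'  "{name}/{layer}" [label="{layer}"];' for layer in layer_names]
    body += [f'  "{name}/{a}" -> "{name}/{b}";'
             for a, b in zip(layer_names, layer_names[1:])]
    return "\n".join([header, *body, "}"])


# A nested model rendered as a cluster, then spliced into the outer graph.
inner = toy_to_dot("encoder", ["dense_1", "dense_2"], subgraph=True)
outer = toy_to_dot("classifier", ["input", "output"])
# Insert the cluster just before the outer graph's closing brace.
combined = outer[:outer.rfind("}")] + inner + "\n}"
print(combined)
```

Edges between the outer layers and the cluster are omitted here to keep the sketch short.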

Raises

  • ImportError: if graphviz or pydot are not available.
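Because both functions fail with an ImportError when pydot or Graphviz is missing, it can help to check beforehand. The sketch below is one assumed pre-flight check, not the check Keras itself performs: it looks for the `pydot` package with `importlib.util.find_spec` and for Graphviz's `dot` executable on the `PATH` with `shutil.which`. A Graphviz install that is not on the `PATH` would be reported as missing.

```python
import importlib.util
import shutil


def plotting_available():
    """Best-effort guess (not the Keras check) at whether plotting will work."""
    # pydot is a Python package; find_spec returns None if it is not installed.
    has_pydot = importlib.util.find_spec("pydot") is not None
    # Graphviz provides the "dot" executable that renders the graph.
    has_graphviz = shutil.which("dot") is not None
    return has_pydot and has_graphviz


if plotting_available():
    print("pydot and Graphviz found; plot_model should be able to render.")
else:
    print("Install pydot (pip install pydot) and Graphviz to plot models.")
```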