A metric is a function that is used to judge the performance of your model.

Metric functions are similar to loss functions, except that the results from evaluating a metric are not used when training the model. Note that you may use any loss function as a metric.

Available metrics

  • Base Metric class
  • Accuracy metrics
  • Probabilistic metrics
  • Regression metrics
  • Classification metrics based on True/False positives & negatives
  • Image segmentation metrics
  • Hinge metrics for "maximum-margin" classification
  • Metric wrappers and reduction metrics

Usage with compile() & fit()

The compile() method takes a metrics argument, which is a list of metrics:
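Here is a minimal sketch of such a call. It assumes a toy binary classifier, and the layer sizes and the random training data are placeholders that do not come from the original text:

```python
import numpy as np
import keras

# Placeholder model: 4 input features -> 1 sigmoid output (binary classifier).
model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(1, activation="sigmoid"),
])

# Pass metric instances in a list; each one is tracked separately.
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=[
        keras.metrics.BinaryAccuracy(),
        keras.metrics.AUC(),
    ],
)

# Random placeholder data, only to show that the metrics get computed.
rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = rng.integers(0, 2, size=(32, 1)).astype("float32")
history = model.fit(x, y, epochs=2, batch_size=8, verbose=0)
print(sorted(history.history))
```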


Metric values are displayed during fit() and logged to the History object returned by fit(). They are also returned by model.evaluate().
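As a rough illustration of where those values end up, the sketch below reads them back from the History object and from evaluate(). It again uses a placeholder model and random data, which are assumptions made only for the example:

```python
import numpy as np
import keras

# Placeholder model and random data, for illustration only.
rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = rng.integers(0, 2, size=(32, 1)).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=[keras.metrics.BinaryAccuracy()],
)

# History.history maps each metric name to a list with one value per epoch.
history = model.fit(x, y, epochs=3, batch_size=8, verbose=0)
print(history.history["binary_accuracy"])

# evaluate() reports the same metrics; return_dict=True keys them by name.
results = model.evaluate(x, y, verbose=0, return_dict=True)
print(results)
```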

Note that the best way to monitor your metrics during training is via TensorBoard.

To track metrics under a specific name, you can pass the name argument to the metric constructor:
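A minimal sketch of this, assuming the same kind of placeholder model and random data as above. The names my_acc and my_auc are arbitrary and become the keys under which the values are logged:

```python
import numpy as np
import keras

rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = rng.integers(0, 2, size=(32, 1)).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=[
        # The `name` argument controls the key used in logs and History.
        keras.metrics.BinaryAccuracy(name="my_acc"),
        keras.metrics.AUC(name="my_auc"),
    ],
)
history = model.fit(x, y, epochs=1, batch_size=8, verbose=0)
print(sorted(history.history))
```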


All built-in metrics may also be passed via their string identifier (in this case, default constructor argument values are used, including a default metric name):
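For example, the class names can be passed as strings. This sketch assumes the same placeholder model and data as the previous examples and does not rely on the exact log keys the strings produce:

```python
import numpy as np
import keras

rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = rng.integers(0, 2, size=(32, 1)).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    # String identifiers: each metric is built with its default arguments.
    metrics=["BinaryAccuracy", "AUC"],
)
history = model.fit(x, y, epochs=1, batch_size=8, verbose=0)
print(sorted(history.history))  # the loss plus one entry per metric
```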


Standalone usage

Unlike losses, metrics are stateful. You update their state using the update_state() method, and you query the scalar metric result using the result() method:

import keras

m = keras.metrics.AUC()
m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
print('Intermediate result:', float(m.result()))

m.update_state([1, 1, 1, 1], [0, 1, 1, 0])
print('Final result:', float(m.result()))

The internal state can be cleared via metric.reset_state().
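The accumulate-then-reset cycle is easiest to see with a metric whose result is simple to check by hand. This sketch uses keras.metrics.Mean purely as an illustration; the input values are made up:

```python
import keras

m = keras.metrics.Mean()

m.update_state([1.0, 2.0, 3.0])
first = float(m.result())  # mean of 1, 2, 3

# State carries over: the new value is averaged with everything seen so far.
m.update_state([6.0])
second = float(m.result())  # (1 + 2 + 3 + 6) / 4

# Clear the accumulated total and count.
m.reset_state()
after_reset = float(m.result())

print(first, second, after_reset)  # 2.0 3.0 0.0
```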

Here's how you would use a metric as part of a simple custom training loop:

import tensorflow as tf
import keras

# Assumes the TensorFlow backend and an existing `model` and `dataset`.
accuracy = keras.metrics.CategoricalAccuracy()
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for step, (x, y) in enumerate(dataset):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        # Compute the loss value for this batch.
        loss_value = loss_fn(y, logits)

    # Update the state of the `accuracy` metric. CategoricalAccuracy compares
    # argmax indices, so logits give the same result as probabilities.
    accuracy.update_state(y, logits)

    # Update the weights of the model to minimize the loss value.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

    # Log the running accuracy every 100 steps.
    if step % 100 == 0:
        print('Step:', step)
        print('Total running accuracy so far: %.3f' % float(accuracy.result()))

Creating custom metrics

As simple callables (stateless)

Much like loss functions, any callable with the signature metric_fn(y_true, y_pred) that returns an array of values (one per sample in the input batch) can be passed to compile() as a metric. Sample weighting is automatically supported for any such metric.

Here's a simple example:

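from keras import ops

def my_metric_fn(y_true, y_pred):
    squared_difference = ops.square(y_true - y_pred)
    # `axis=-1` averages over the last axis only, so the function returns
    # one value per sample rather than a single scalar.
    return ops.mean(squared_difference, axis=-1)

# `model` is an existing Keras model.
model.compile(optimizer='adam', loss='mean_squared_error', metrics=[my_metric_fn])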

In this case, the scalar metric value you are tracking during training and evaluation is the average of the per-batch metric values for all batches seen during a given epoch (or during a given call to model.evaluate()).

As subclasses of Metric (stateful)

Not all metrics can be expressed as stateless callables. During training and evaluation, metrics are computed batch by batch, and for some metrics the average of the per-batch values is not the quantity you are interested in.

Let's say that you want to compute AUC over a given evaluation dataset: the average of the per-batch AUC values isn't the same as the AUC over the entire dataset.
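To make this concrete, here is a small pure-Python sketch. It computes AUC with the exact pairwise (Mann-Whitney) definition, which is an assumption made for illustration; keras.metrics.AUC uses a threshold-based approximation instead. Both made-up batches are perfectly ranked on their own, yet the pooled data is not:

```python
def pairwise_auc(y_true, scores):
    """Exact AUC: fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for y, s in zip(y_true, scores) if y == 1]
    negatives = [s for y, s in zip(y_true, scores) if y == 0]
    correct = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5  # ties count as half
    return correct / (len(positives) * len(negatives))


# Two made-up batches, each perfectly ranked on its own.
batch_a = ([0, 1], [0.1, 0.2])
batch_b = ([0, 1], [0.8, 0.9])

per_batch = [pairwise_auc(*batch_a), pairwise_auc(*batch_b)]
mean_of_batches = sum(per_batch) / len(per_batch)

# Pool both batches and compute AUC over the whole "dataset".
all_labels = batch_a[0] + batch_b[0]
all_scores = batch_a[1] + batch_b[1]
global_auc = pairwise_auc(all_labels, all_scores)

print(mean_of_batches)  # 1.0
print(global_auc)       # 0.75
```

The pooled negative scored 0.8 outranks the pooled positive scored 0.2. Neither batch can detect this mistake on its own.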

For such metrics, you're going to want to subclass the Metric class, which can maintain a state across batches. It's easy:

  • Create the state variables in __init__
  • Update the variables given y_true and y_pred in update_state()
  • Return the scalar metric result in result()
  • Clear the state in reset_state()

Here's a simple example computing binary true positives:

import keras
from keras import ops

class BinaryTruePositives(keras.metrics.Metric):

  def __init__(self, name='binary_true_positives', **kwargs):
    super().__init__(name=name, **kwargs)
    self.true_positives = self.add_weight(name='tp', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = ops.cast(y_true, "bool")
    y_pred = ops.cast(y_pred, "bool")

    values = ops.logical_and(ops.equal(y_true, True), ops.equal(y_pred, True))
    values = ops.cast(values, self.dtype)
    if sample_weight is not None:
      sample_weight = ops.cast(sample_weight, self.dtype)
      values = values * sample_weight
    # Add this batch's count to the running total.
    self.true_positives.assign_add(ops.sum(values))

  def result(self):
    return self.true_positives

  def reset_state(self):
    # Set the running count back to zero.
    self.true_positives.assign(0.0)

m = BinaryTruePositives()
m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
print(f'Intermediate result: {m.result().numpy()}')

m.update_state([1, 1, 1, 1], [0, 1, 1, 0])
print(f'Final result: {m.result().numpy()}')