Perplexity metric

Perplexity class

keras_nlp.metrics.Perplexity(
    from_logits=False, mask_token_id=None, dtype=None, name="perplexity", **kwargs
)

Perplexity metric.

This class implements the perplexity metric. In short, it computes the cross-entropy loss and exponentiates it. Note: this implementation is not suitable for fixed-size windows.
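As a rough illustration of "cross entropy, then exponent", the following standard-library sketch computes perplexity from per-token probability distributions. The helper name `perplexity_from_probs` is hypothetical and not part of keras_nlp; the sketch assumes an unweighted mean over all tokens and is not the library's implementation.

```python
import math


def perplexity_from_probs(targets, probs):
    """Hypothetical helper: exp of the mean per-token cross entropy.

    targets: nested list of token IDs, shape (batch, seq_len).
    probs:   nested list of distributions, shape (batch, seq_len, vocab).
    """
    total_cross_entropy = 0.0
    token_count = 0
    for target_row, prob_row in zip(targets, probs):
        for token_id, distribution in zip(target_row, prob_row):
            # Cross entropy for one token: -log p(correct token).
            total_cross_entropy += -math.log(distribution[token_id])
            token_count += 1
    # Perplexity is the exponential of the mean cross entropy.
    return math.exp(total_cross_entropy / token_count)


# A model that is uniformly unsure over 4 tokens has perplexity close to 4.
uniform = [0.25, 0.25, 0.25, 0.25]
targets = [[0, 3], [1, 2]]
probs = [[uniform, uniform], [uniform, uniform]]
uniform_perplexity = perplexity_from_probs(targets, probs)
```

A uniform distribution over N tokens gives a perplexity of N, which is why perplexity is often read as the effective number of choices the model is hesitating between.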

Arguments

  • from_logits: bool. If True, y_pred (input to update_state()) should be the logits as returned by the model. Otherwise, y_pred is a tensor of probabilities.
  • mask_token_id: int. ID of the token to be masked. If provided, this class computes the mask itself, excluding positions where the target equals this ID. If sample_weight is also passed to update_state(), the final sample_weight is the element-wise product of the mask and the sample_weight.
  • dtype: string or tf.dtypes.DType. Precision of metric computation. If not specified, it defaults to tf.float32.
  • name: string. Name of the metric instance.
  • **kwargs: Other keyword arguments.
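The interaction between from_logits, mask_token_id, and sample_weight described above can be sketched in plain Python. The function `masked_perplexity` and its `softmax` helper are hypothetical names introduced for illustration. The sketch assumes the weighted cross entropy is normalized by the sum of the final weights, and that an all-zero weight raises an error; neither is confirmed here as the library's behavior.

```python
import math


def softmax(logits):
    """Convert raw logits to probabilities (numerically stable)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def masked_perplexity(targets, preds, from_logits=False,
                      mask_token_id=None, sample_weight=None):
    """Hypothetical sketch of weighted, masked perplexity for one batch."""
    weighted_cross_entropy = 0.0
    weight_sum = 0.0
    for i, (target_row, pred_row) in enumerate(zip(targets, preds)):
        for j, (token_id, scores) in enumerate(zip(target_row, pred_row)):
            # Start from the caller's weight, or 1.0 if none was given.
            weight = 1.0 if sample_weight is None else sample_weight[i][j]
            # Element-wise product with the mask: masked tokens get weight 0.
            if mask_token_id is not None and token_id == mask_token_id:
                weight *= 0.0
            distribution = softmax(scores) if from_logits else scores
            weighted_cross_entropy += -weight * math.log(distribution[token_id])
            weight_sum += weight
    if weight_sum == 0.0:
        # Assumption of this sketch: refuse to divide by zero.
        raise ValueError("All tokens are masked or have zero weight.")
    return math.exp(weighted_cross_entropy / weight_sum)


targets = [[0, 1, 2]]
probs = [[[0.1, 0.8, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]]

unmasked = masked_perplexity(targets, probs)
masked = masked_perplexity(targets, probs, mask_token_id=0)
weighted = masked_perplexity(
    targets, probs, mask_token_id=0, sample_weight=[[1.0, 1.0, 0.0]])
from_zero_logits = masked_perplexity(
    targets, [[[0.0, 0.0, 0.0]] * 3], from_logits=True)
```

In this toy batch the token with ID 0 is predicted poorly, so masking it lowers the perplexity; passing a sample_weight as well simply multiplies into the same per-token weights.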

Examples

  1. Calculate perplexity by calling update_state() and result().

1.1. sample_weight and mask_token_id are not provided.

>>> import tensorflow as tf
>>> import keras_nlp
>>> tf.random.set_seed(42)
>>> perplexity = keras_nlp.metrics.Perplexity(name="perplexity")
>>> target = tf.random.uniform(
...     shape=[2, 5],  maxval=10, dtype=tf.int32, seed=42)
>>> logits = tf.random.uniform(shape=(2, 5, 10), seed=42)
>>> perplexity.update_state(target, logits)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=11.8781595>

1.2. sample_weight is provided (masking the token with ID 0).

>>> tf.random.set_seed(42)
>>> perplexity = keras_nlp.metrics.Perplexity(name="perplexity")
>>> target = tf.random.uniform(
...     shape=[2, 5],  maxval=10, dtype=tf.int32, seed=42)
>>> logits = tf.random.uniform(shape=(2, 5, 10), seed=42)
>>> sample_weight = tf.cast(
...     tf.math.logical_not(tf.equal(target, 0)), tf.float32)
>>> perplexity.update_state(target, logits, sample_weight)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=13.1128>
  2. Call perplexity directly.
>>> tf.random.set_seed(42)
>>> perplexity = keras_nlp.metrics.Perplexity(name="perplexity")
>>> target = tf.random.uniform(
...     shape=[2, 5],  maxval=10, dtype=tf.int32, seed=42)
>>> logits = tf.random.uniform(shape=(2, 5, 10), seed=42)
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=11.8781595>
  3. Provide the padding token ID and let the class compute the mask on its own.
>>> tf.random.set_seed(42)
>>> perplexity = keras_nlp.metrics.Perplexity(
...     name="perplexity", mask_token_id=0)
>>> target = tf.random.uniform(
...     shape=[2, 5],  maxval=10, dtype=tf.int32, seed=42)
>>> logits = tf.random.uniform(shape=(2, 5, 10), seed=42)
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=13.1128>