BinaryCrossentropy
class tf_keras.losses.BinaryCrossentropy(
from_logits=False,
label_smoothing=0.0,
axis=-1,
reduction="auto",
name="binary_crossentropy",
)
Computes the cross-entropy loss between true labels and predicted labels.
Use this cross-entropy loss for binary (0 or 1) classification applications. The loss function requires the following inputs:
- y_true (true label): This is either 0 or 1.
- y_pred (predicted value): This is the model's prediction, i.e., a single floating-point value which either represents a logit (i.e., a value in [-inf, inf] when from_logits=True) or a probability (i.e., a value in [0., 1.] when from_logits=False).

Recommended Usage: (set from_logits=True)

With tf.keras API:
model.compile(
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
....
)
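For a fuller picture of the recommended setup, here is a minimal, hedged sketch of an end-to-end model: the final Dense layer has no sigmoid activation, so its outputs are logits and match from_logits=True. The layer sizes and the random data are illustrative assumptions, not part of the API shown above.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),  # no sigmoid: the model outputs raw logits
])
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    optimizer='adam',
)

x = np.random.random((16, 4)).astype('float32')   # illustrative features
y = np.random.randint(0, 2, size=(16, 1))         # illustrative 0/1 labels
model.fit(x, y, epochs=1, verbose=0)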
As a standalone function:
>>> # Example 1: (batch_size = 1, number of samples = 4)
>>> y_true = [0, 1, 0, 0]
>>> y_pred = [-18.6, 0.51, 2.94, -12.8]
>>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
>>> bce(y_true, y_pred).numpy()
0.865
>>> # Example 2: (batch_size = 2, number of samples = 4)
>>> y_true = [[0, 1], [0, 0]]
>>> y_pred = [[-18.6, 0.51], [2.94, -12.8]]
>>> # Using default 'auto'/'sum_over_batch_size' reduction type.
>>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
>>> bce(y_true, y_pred).numpy()
0.865
>>> # Using 'sample_weight' attribute
>>> bce(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
0.243
>>> # Using 'sum' reduction type.
>>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
... reduction=tf.keras.losses.Reduction.SUM)
>>> bce(y_true, y_pred).numpy()
1.730
>>> # Using 'none' reduction type.
>>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
... reduction=tf.keras.losses.Reduction.NONE)
>>> bce(y_true, y_pred).numpy()
array([0.235, 1.496], dtype=float32)
Default Usage: (set from_logits=False)
>>> # Make the following updates to the above "Recommended Usage" section
>>> # 1. Set `from_logits=False`
>>> tf.keras.losses.BinaryCrossentropy()  # OR ...(from_logits=False)
>>> # 2. Update `y_pred` to use probabilities instead of logits
>>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]]
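Putting those two updates together, a minimal sketch of the default (probability) usage, reusing the illustrative values above:

import tensorflow as tf

y_true = [0., 1., 0., 0.]
y_pred = [0.6, 0.3, 0.2, 0.8]   # probabilities in [0., 1.]

bce = tf.keras.losses.BinaryCrossentropy()   # from_logits defaults to False
print(bce(y_true, y_pred).numpy())           # mean loss over the four samples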
CategoricalCrossentropy
class tf_keras.losses.CategoricalCrossentropy(
from_logits=False,
label_smoothing=0.0,
axis=-1,
reduction="auto",
name="categorical_crossentropy",
)
Computes the crossentropy loss between the labels and predictions.
Use this crossentropy loss function when there are two or more label classes. We expect labels to be provided in a one_hot representation. If you want to provide labels as integers, please use SparseCategoricalCrossentropy loss. There should be # classes floating point values per feature.

In the snippet below, there are # classes floating point values per example. The shape of both y_pred and y_true is [batch_size, num_classes].
Standalone usage:
>>> y_true = [[0, 1, 0], [0, 0, 1]]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> cce = tf.keras.losses.CategoricalCrossentropy()
>>> cce(y_true, y_pred).numpy()
1.177
>>> # Calling with 'sample_weight'.
>>> cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
0.814
>>> # Using 'sum' reduction type.
>>> cce = tf.keras.losses.CategoricalCrossentropy(
... reduction=tf.keras.losses.Reduction.SUM)
>>> cce(y_true, y_pred).numpy()
2.354
>>> # Using 'none' reduction type.
>>> cce = tf.keras.losses.CategoricalCrossentropy(
... reduction=tf.keras.losses.Reduction.NONE)
>>> cce(y_true, y_pred).numpy()
array([0.0513, 2.303], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd',
loss=tf.keras.losses.CategoricalCrossentropy())
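The constructor also accepts label_smoothing. Here is a hedged sketch of its effect, reusing the illustrative values from the standalone usage above: smoothing mixes each one-hot target toward the uniform distribution, so the two loss values differ slightly.

import tensorflow as tf

y_true = [[0., 1., 0.], [0., 0., 1.]]
y_pred = [[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]]

cce_plain = tf.keras.losses.CategoricalCrossentropy()
cce_smooth = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)

# Smoothing replaces the hard 0/1 targets with 0.1/3 and 0.9 + 0.1/3, so the
# smoothed loss differs slightly from the plain one.
print(cce_plain(y_true, y_pred).numpy(), cce_smooth(y_true, y_pred).numpy())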
SparseCategoricalCrossentropy
class tf_keras.losses.SparseCategoricalCrossentropy(
from_logits=False,
ignore_class=None,
reduction="auto",
name="sparse_categorical_crossentropy",
)
Computes the crossentropy loss between the labels and predictions.
Use this crossentropy loss function when there are two or more label classes. We expect labels to be provided as integers. If you want to provide labels using one-hot representation, please use CategoricalCrossentropy loss. There should be # classes floating point values per feature for y_pred and a single floating point value per feature for y_true.

In the snippet below, there is a single floating point value per example for y_true and # classes floating point values per example for y_pred. The shape of y_true is [batch_size] and the shape of y_pred is [batch_size, num_classes].
Standalone usage:
>>> y_true = [1, 2]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> scce = tf.keras.losses.SparseCategoricalCrossentropy()
>>> scce(y_true, y_pred).numpy()
1.177
>>> # Calling with 'sample_weight'.
>>> scce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
0.814
>>> # Using 'sum' reduction type.
>>> scce = tf.keras.losses.SparseCategoricalCrossentropy(
... reduction=tf.keras.losses.Reduction.SUM)
>>> scce(y_true, y_pred).numpy()
2.354
>>> # Using 'none' reduction type.
>>> scce = tf.keras.losses.SparseCategoricalCrossentropy(
... reduction=tf.keras.losses.Reduction.NONE)
>>> scce(y_true, y_pred).numpy()
array([0.0513, 2.303], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd',
loss=tf.keras.losses.SparseCategoricalCrossentropy())
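A minimal sketch of why the sparse variant is convenient with compile()/fit(): integer class indices are passed directly as labels, with no one-hot encoding step. The layer sizes and random data below are illustrative assumptions.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(3, activation='softmax', input_shape=(4,)),
])
model.compile(optimizer='sgd',
              loss=tf.keras.losses.SparseCategoricalCrossentropy())

x = np.random.random((8, 4)).astype('float32')
y = np.random.randint(0, 3, size=(8,))   # class indices, not one-hot vectors
model.fit(x, y, epochs=1, verbose=0)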
Poisson
class tf_keras.losses.Poisson(reduction="auto", name="poisson")
Computes the Poisson loss between y_true & y_pred.
loss = y_pred - y_true * log(y_pred)
Standalone usage:
>>> y_true = [[0., 1.], [0., 0.]]
>>> y_pred = [[1., 1.], [0., 0.]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> p = tf.keras.losses.Poisson()
>>> p(y_true, y_pred).numpy()
0.5
>>> # Calling with 'sample_weight'.
>>> p(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
0.4
>>> # Using 'sum' reduction type.
>>> p = tf.keras.losses.Poisson(
... reduction=tf.keras.losses.Reduction.SUM)
>>> p(y_true, y_pred).numpy()
0.999
>>> # Using 'none' reduction type.
>>> p = tf.keras.losses.Poisson(
... reduction=tf.keras.losses.Reduction.NONE)
>>> p(y_true, y_pred).numpy()
array([0.999, 0.], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tf.keras.losses.Poisson())
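As a sanity check of the formula loss = y_pred - y_true * log(y_pred), the hedged sketch below recomputes the standalone-usage result with NumPy; adding a small epsilon before the log is an assumption made here to avoid log(0) on zero predictions.

import numpy as np
import tensorflow as tf

y_true = np.array([[0., 1.], [0., 0.]])
y_pred = np.array([[1., 1.], [0., 0.]])

loss = tf.keras.losses.Poisson()(y_true, y_pred).numpy()
manual = np.mean(y_pred - y_true * np.log(y_pred + 1e-7))  # epsilon-guarded log
assert np.isclose(loss, manual, atol=1e-4)   # both are ~0.5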
binary_crossentropy
function tf_keras.losses.binary_crossentropy(
y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1
)
Computes the binary crossentropy loss.
Standalone usage:
>>> y_true = [[0, 1], [0, 0]]
>>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
>>> loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
>>> assert loss.shape == (2,)
>>> loss.numpy()
array([0.916 , 0.714], dtype=float32)
Arguments
- y_true: Ground truth values. shape = [batch_size, d0, .. dN].
- y_pred: The predicted values. shape = [batch_size, d0, .. dN].
- from_logits: Whether y_pred is expected to be a logits tensor. By default, we assume that y_pred encodes a probability distribution.
- label_smoothing: Float in [0, 1]. If > 0 then smooth the labels by squeezing them towards 0.5. That is, using 1. - 0.5 * label_smoothing for the target class and 0.5 * label_smoothing for the non-target class.
- axis: The axis along which the mean is computed. Defaults to -1.

Returns

Binary crossentropy loss value. shape = [batch_size, d0, .. dN-1].
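To make the label_smoothing description concrete, here is a hedged sketch checking that smoothing matches squeezing the targets toward 0.5 by hand; the data and smoothing factor are illustrative, and the equivalence is an assumption being verified rather than a documented guarantee.

import numpy as np
import tensorflow as tf

y_true = np.array([[0., 1.], [0., 0.]])
y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])
s = 0.2

smoothed = tf.keras.losses.binary_crossentropy(y_true, y_pred, label_smoothing=s)
# Squeeze the targets toward 0.5 manually: 1 -> 1 - 0.5*s, 0 -> 0.5*s.
manual = tf.keras.losses.binary_crossentropy(y_true * (1. - s) + 0.5 * s, y_pred)
assert np.allclose(smoothed.numpy(), manual.numpy(), atol=1e-6)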
categorical_crossentropy
function tf_keras.losses.categorical_crossentropy(
y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1
)
Computes the categorical crossentropy loss.
Standalone usage:
>>> y_true = [[0, 1, 0], [0, 0, 1]]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
>>> assert loss.shape == (2,)
>>> loss.numpy()
array([0.0513, 2.303], dtype=float32)
Arguments
- y_true: Tensor of one-hot true targets.
- y_pred: Tensor of predicted targets.
- from_logits: Whether y_pred is expected to be a logits tensor. By default, we assume that y_pred encodes a probability distribution.
- label_smoothing: Float in [0, 1]. If > 0 then smooth the labels. For example, if 0.1, use 0.1 / num_classes for non-target labels and 0.9 + 0.1 / num_classes for target labels.
- axis: Defaults to -1. The dimension along which the entropy is computed.

Returns

Categorical crossentropy loss value.
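A hedged sketch of the from_logits argument: passing raw logits with from_logits=True should give (approximately) the same per-sample losses as passing softmax(logits) with the default from_logits=False. The logit values below are illustrative.

import numpy as np
import tensorflow as tf

y_true = [[0., 1., 0.], [0., 0., 1.]]
logits = [[2.0, 5.0, 0.5], [1.0, 3.0, 2.0]]   # illustrative raw scores

from_logits = tf.keras.losses.categorical_crossentropy(
    y_true, logits, from_logits=True)
from_probs = tf.keras.losses.categorical_crossentropy(
    y_true, tf.nn.softmax(logits))
assert np.allclose(from_logits.numpy(), from_probs.numpy(), atol=1e-5)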
sparse_categorical_crossentropy
function tf_keras.losses.sparse_categorical_crossentropy(
y_true, y_pred, from_logits=False, axis=-1, ignore_class=None
)
Computes the sparse categorical crossentropy loss.
Standalone usage:
>>> y_true = [1, 2]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
>>> assert loss.shape == (2,)
>>> loss.numpy()
array([0.0513, 2.303], dtype=float32)
>>> y_true = [[[ 0, 2],
... [-1, -1]],
... [[ 0, 2],
... [-1, -1]]]
>>> y_pred = [[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
... [[0.2, 0.5, 0.3], [0.0, 1.0, 0.0]]],
... [[[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]],
... [[0.2, 0.5, 0.3], [0.0, 1.0, 0.0]]]]
>>> loss = tf.keras.losses.sparse_categorical_crossentropy(
... y_true, y_pred, ignore_class=-1)
>>> loss.numpy()
array([[[2.3841855e-07, 2.3841855e-07],
[0.0000000e+00, 0.0000000e+00]],
[[2.3841855e-07, 6.9314730e-01],
[0.0000000e+00, 0.0000000e+00]]], dtype=float32)
Arguments
- y_true: Ground truth values.
- y_pred: The predicted values.
- from_logits: Whether y_pred is expected to be a logits tensor. By default, we assume that y_pred encodes a probability distribution.
- axis: Defaults to -1. The dimension along which the entropy is computed.
- ignore_class: Optional integer. The ID of a class to be ignored during loss computation. This is useful, for example, in segmentation problems featuring a "void" class (commonly -1 or 255) in segmentation maps. By default (ignore_class=None), all classes are considered.

Returns

Sparse categorical crossentropy loss value.
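A hedged sketch relating the sparse and one-hot forms: the integer labels here should give the same per-sample losses as their one-hot encoding passed to categorical_crossentropy. The depth=3 conversion is an assumption matching the three-class example above.

import numpy as np
import tensorflow as tf

y_true = np.array([1, 2])
y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])

sparse = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
dense = tf.keras.losses.categorical_crossentropy(
    tf.one_hot(y_true, depth=3), y_pred)
assert np.allclose(sparse.numpy(), dense.numpy(), atol=1e-5)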
poisson
function tf_keras.losses.poisson(y_true, y_pred)
Computes the Poisson loss between y_true and y_pred.
The Poisson loss is the mean of the elements of the Tensor y_pred - y_true * log(y_pred).
Standalone usage:
>>> y_true = np.random.randint(0, 2, size=(2, 3))
>>> y_pred = np.random.random(size=(2, 3))
>>> loss = tf.keras.losses.poisson(y_true, y_pred)
>>> assert loss.shape == (2,)
>>> y_pred = y_pred + 1e-7
>>> assert np.allclose(
... loss.numpy(), np.mean(y_pred - y_true * np.log(y_pred), axis=-1),
... atol=1e-5)
Arguments
- y_true: Ground truth values. shape = [batch_size, d0, .. dN].
- y_pred: The predicted values. shape = [batch_size, d0, .. dN].

Returns

Poisson loss value. shape = [batch_size, d0, .. dN-1].
Raises
- InvalidArgumentError: If y_true and y_pred have incompatible shapes.

KLDivergence
class tf_keras.losses.KLDivergence(reduction="auto", name="kl_divergence")
Computes Kullback-Leibler divergence loss between y_true & y_pred.
loss = y_true * log(y_true / y_pred)
See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
Standalone usage:
>>> y_true = [[0, 1], [0, 0]]
>>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> kl = tf.keras.losses.KLDivergence()
>>> kl(y_true, y_pred).numpy()
0.458
>>> # Calling with 'sample_weight'.
>>> kl(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
0.366
>>> # Using 'sum' reduction type.
>>> kl = tf.keras.losses.KLDivergence(
... reduction=tf.keras.losses.Reduction.SUM)
>>> kl(y_true, y_pred).numpy()
0.916
>>> # Using 'none' reduction type.
>>> kl = tf.keras.losses.KLDivergence(
... reduction=tf.keras.losses.Reduction.NONE)
>>> kl(y_true, y_pred).numpy()
array([0.916, -3.08e-06], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tf.keras.losses.KLDivergence())
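As a sanity check of loss = y_true * log(y_true / y_pred), the hedged sketch below recomputes the default-reduction result with NumPy; clipping both tensors to a small epsilon is an assumption made here to avoid log(0) on zero entries.

import numpy as np
import tensorflow as tf

y_true = np.array([[0., 1.], [0., 0.]])
y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])

kl = tf.keras.losses.KLDivergence()(y_true, y_pred).numpy()
yt = np.clip(y_true, 1e-7, 1.)
yp = np.clip(y_pred, 1e-7, 1.)
manual = np.mean(np.sum(yt * np.log(yt / yp), axis=-1))
assert np.isclose(kl, manual, atol=1e-4)   # both are ~0.458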
kl_divergence
function tf_keras.losses.kl_divergence(y_true, y_pred)
Computes Kullback-Leibler divergence loss between y_true & y_pred.
loss = y_true * log(y_true / y_pred)
See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
Standalone usage:
>>> y_true = np.random.randint(0, 2, size=(2, 3)).astype(np.float64)
>>> y_pred = np.random.random(size=(2, 3))
>>> loss = tf.keras.losses.kullback_leibler_divergence(y_true, y_pred)
>>> assert loss.shape == (2,)
>>> y_true = tf.keras.backend.clip(y_true, 1e-7, 1)
>>> y_pred = tf.keras.backend.clip(y_pred, 1e-7, 1)
>>> assert np.array_equal(
... loss.numpy(), np.sum(y_true * np.log(y_true / y_pred), axis=-1))
Arguments

- y_true: Tensor of true targets.
- y_pred: Tensor of predicted targets.

Returns

A Tensor with loss.

Raises

- TypeError: If y_true cannot be cast to the y_pred.dtype.