MeanSquaredError
class tf_keras.losses.MeanSquaredError(reduction="auto", name="mean_squared_error")
Computes the mean of squares of errors between labels and predictions.
loss = mean(square(y_true - y_pred))
Standalone usage:
>>> y_true = [[0., 1.], [0., 0.]]
>>> y_pred = [[1., 1.], [1., 0.]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> mse = tf.keras.losses.MeanSquaredError()
>>> mse(y_true, y_pred).numpy()
0.5
>>> # Calling with 'sample_weight'.
>>> mse(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy()
0.25
>>> # Using 'sum' reduction type.
>>> mse = tf.keras.losses.MeanSquaredError(
... reduction=tf.keras.losses.Reduction.SUM)
>>> mse(y_true, y_pred).numpy()
1.0
>>> # Using 'none' reduction type.
>>> mse = tf.keras.losses.MeanSquaredError(
... reduction=tf.keras.losses.Reduction.NONE)
>>> mse(y_true, y_pred).numpy()
array([0.5, 0.5], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tf.keras.losses.MeanSquaredError())
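The reduction behaviour in the example above can be reproduced by hand. The NumPy sketch below (not part of the API) mirrors the doctest values for the same inputs:

import numpy as np

y_true = np.array([[0., 1.], [0., 0.]])
y_pred = np.array([[1., 1.], [1., 0.]])
# Per-sample loss: mean of squared errors over the last axis.
per_sample = np.mean(np.square(y_true - y_pred), axis=-1)  # [0.5, 0.5]
print(per_sample.mean())  # 0.5  ('auto'/'sum_over_batch_size')
print(per_sample.sum())   # 1.0  ('sum')
# With sample_weight, the weighted per-sample losses are summed and
# divided by the batch size.
weights = np.array([0.7, 0.3])
print((per_sample * weights).sum() / len(per_sample))  # 0.25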
MeanAbsoluteError
class tf_keras.losses.MeanAbsoluteError(reduction="auto", name="mean_absolute_error")
Computes the mean absolute difference between labels and predictions.
loss = mean(abs(y_true - y_pred))
Standalone usage:
>>> y_true = [[0., 1.], [0., 0.]]
>>> y_pred = [[1., 1.], [1., 0.]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> mae = tf.keras.losses.MeanAbsoluteError()
>>> mae(y_true, y_pred).numpy()
0.5
>>> # Calling with 'sample_weight'.
>>> mae(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy()
0.25
>>> # Using 'sum' reduction type.
>>> mae = tf.keras.losses.MeanAbsoluteError(
... reduction=tf.keras.losses.Reduction.SUM)
>>> mae(y_true, y_pred).numpy()
1.0
>>> # Using 'none' reduction type.
>>> mae = tf.keras.losses.MeanAbsoluteError(
... reduction=tf.keras.losses.Reduction.NONE)
>>> mae(y_true, y_pred).numpy()
array([0.5, 0.5], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tf.keras.losses.MeanAbsoluteError())
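As an aside (a sketch under the assumption of a toy model), compile() also accepts this loss by its string alias, which is equivalent to passing the class instance:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
# Equivalent ways of selecting mean absolute error at compile time.
model.compile(optimizer='sgd', loss='mean_absolute_error')  # 'mae' also works
model.compile(optimizer='sgd', loss=tf.keras.losses.MeanAbsoluteError())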
MeanAbsolutePercentageError
class tf_keras.losses.MeanAbsolutePercentageError(
reduction="auto", name="mean_absolute_percentage_error"
)
Computes the mean absolute percentage error between y_true and y_pred.
Formula:
loss = 100 * abs((y_true - y_pred) / y_true)
Note that to avoid dividing by zero, a small epsilon value is added to the denominator.
Standalone usage:
>>> y_true = [[2., 1.], [2., 3.]]
>>> y_pred = [[1., 1.], [1., 0.]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> mape = tf.keras.losses.MeanAbsolutePercentageError()
>>> mape(y_true, y_pred).numpy()
50.
>>> # Calling with 'sample_weight'.
>>> mape(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy()
20.
>>> # Using 'sum' reduction type.
>>> mape = tf.keras.losses.MeanAbsolutePercentageError(
... reduction=tf.keras.losses.Reduction.SUM)
>>> mape(y_true, y_pred).numpy()
100.
>>> # Using 'none' reduction type.
>>> mape = tf.keras.losses.MeanAbsolutePercentageError(
... reduction=tf.keras.losses.Reduction.NONE)
>>> mape(y_true, y_pred).numpy()
array([25., 75.], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tf.keras.losses.MeanAbsolutePercentageError())
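Because of the epsilon guard mentioned above, zeros in y_true produce a very large but finite loss rather than an error. A minimal sketch (the exact magnitude depends on the backend epsilon, so treat the printed value as indicative):

import tensorflow as tf

y_true = [[0., 1.]]  # the zero target triggers the epsilon guard
y_pred = [[1., 1.]]
mape = tf.keras.losses.MeanAbsolutePercentageError()
print(mape(y_true, y_pred).numpy())  # finite but huge (roughly 1e8-1e9), not inf or nan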
MeanSquaredLogarithmicError
class tf_keras.losses.MeanSquaredLogarithmicError(
reduction="auto", name="mean_squared_logarithmic_error"
)
Computes the mean squared logarithmic error between y_true and y_pred.
loss = square(log(y_true + 1.) - log(y_pred + 1.))
Standalone usage:
>>> y_true = [[0., 1.], [0., 0.]]
>>> y_pred = [[1., 1.], [1., 0.]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> msle = tf.keras.losses.MeanSquaredLogarithmicError()
>>> msle(y_true, y_pred).numpy()
0.240
>>> # Calling with 'sample_weight'.
>>> msle(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy()
0.120
>>> # Using 'sum' reduction type.
>>> msle = tf.keras.losses.MeanSquaredLogarithmicError(
... reduction=tf.keras.losses.Reduction.SUM)
>>> msle(y_true, y_pred).numpy()
0.480
>>> # Using 'none' reduction type.
>>> msle = tf.keras.losses.MeanSquaredLogarithmicError(
... reduction=tf.keras.losses.Reduction.NONE)
>>> msle(y_true, y_pred).numpy()
array([0.240, 0.240], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tf.keras.losses.MeanSquaredLogarithmicError())
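Equivalently, MSLE is the mean squared error computed on log1p-transformed values. The NumPy sketch below (assuming non-negative inputs) reproduces the 0.240 figure from the doctest:

import numpy as np

y_true = np.array([[0., 1.], [0., 0.]])
y_pred = np.array([[1., 1.], [1., 0.]])
# MSLE == MSE on log1p-transformed values for non-negative inputs.
per_sample = np.mean(np.square(np.log1p(y_true) - np.log1p(y_pred)), axis=-1)
print(per_sample)         # ~[0.240, 0.240]
print(per_sample.mean())  # ~0.240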
CosineSimilarity
class tf_keras.losses.CosineSimilarity(axis=-1, reduction="auto", name="cosine_similarity")
Computes the cosine similarity between labels and predictions.
Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 indicates orthogonality and values closer to -1 indicate greater similarity. Values closer to 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you try to maximize the proximity between predictions and targets. If either y_true or y_pred is a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and targets.
loss = -sum(l2_norm(y_true) * l2_norm(y_pred))
Standalone usage:
>>> y_true = [[0., 1.], [1., 1.]]
>>> y_pred = [[1., 0.], [1., 1.]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1)
>>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
>>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
>>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
>>> # loss = -mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
>>> # = -((0. + 0.) + (0.5 + 0.5)) / 2
>>> cosine_loss(y_true, y_pred).numpy()
-0.5
>>> # Calling with 'sample_weight'.
>>> cosine_loss(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
-0.0999
>>> # Using 'sum' reduction type.
>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1,
... reduction=tf.keras.losses.Reduction.SUM)
>>> cosine_loss(y_true, y_pred).numpy()
-0.999
>>> # Using 'none' reduction type.
>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1,
... reduction=tf.keras.losses.Reduction.NONE)
>>> cosine_loss(y_true, y_pred).numpy()
array([-0., -0.999], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tf.keras.losses.CosineSimilarity(axis=1))
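The -0.5 value in the example can be checked by L2-normalizing both inputs and negating the mean dot product; a minimal NumPy sketch:

import numpy as np

y_true = np.array([[0., 1.], [1., 1.]])
y_pred = np.array([[1., 0.], [1., 1.]])
# Normalize along the feature axis, then negate the mean dot product.
t = y_true / np.linalg.norm(y_true, axis=1, keepdims=True)
p = y_pred / np.linalg.norm(y_pred, axis=1, keepdims=True)
print(-np.mean(np.sum(t * p, axis=1)))  # -0.5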
Arguments
- axis: The axis along which the cosine similarity is computed (the features axis). Defaults to -1.
- reduction: Type of tf.keras.losses.Reduction to apply to loss. Default value is AUTO. AUTO indicates that the reduction option will be determined by the usage context. For almost all cases this defaults to SUM_OVER_BATCH_SIZE. When used under a tf.distribute.Strategy, except via Model.compile() and Model.fit(), using AUTO or SUM_OVER_BATCH_SIZE will raise an error. Please see this custom training tutorial for more details.
- name: Optional name for the instance.
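In a custom training loop under tf.distribute.Strategy, the usual pattern (sketched below with an assumed toy model and global batch size) is to pass an explicit NONE reduction and average over the global batch size yourself, e.g. via tf.nn.compute_average_loss:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64  # assumption for illustration

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
    # Explicit NONE reduction: one loss value per example ...
    loss_obj = tf.keras.losses.CosineSimilarity(
        axis=1, reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(y_true, y_pred):
    per_example = loss_obj(y_true, y_pred)
    # ... which we average over the *global* batch size ourselves.
    return tf.nn.compute_average_loss(
        per_example, global_batch_size=GLOBAL_BATCH_SIZE)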
mean_squared_error
function tf_keras.losses.mean_squared_error(y_true, y_pred)
Computes the mean squared error between labels and predictions.
After computing the squared distance between the inputs, the mean value over the last dimension is returned.
loss = mean(square(y_true - y_pred), axis=-1)
Standalone usage:
>>> y_true = np.random.randint(0, 2, size=(2, 3))
>>> y_pred = np.random.random(size=(2, 3))
>>> loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
>>> assert loss.shape == (2,)
>>> assert np.array_equal(
... loss.numpy(), np.mean(np.square(y_true - y_pred), axis=-1))
Arguments
- y_true: Ground truth values. shape = [batch_size, d0, .. dN].
- y_pred: The predicted values. shape = [batch_size, d0, .. dN].
Returns
Mean squared error values. shape = [batch_size, d0, .. dN-1].
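Because the function returns one loss value per sample, a custom training step has to reduce it explicitly. A minimal sketch (toy model and optimizer assumed):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD()

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        per_sample = tf.keras.losses.mean_squared_error(y, y_pred)
        loss = tf.reduce_mean(per_sample)  # explicit reduction to a scalar
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss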
mean_absolute_error
function tf_keras.losses.mean_absolute_error(y_true, y_pred)
Computes the mean absolute error between labels and predictions.
loss = mean(abs(y_true - y_pred), axis=-1)
Standalone usage:
>>> y_true = np.random.randint(0, 2, size=(2, 3))
>>> y_pred = np.random.random(size=(2, 3))
>>> loss = tf.keras.losses.mean_absolute_error(y_true, y_pred)
>>> assert loss.shape == (2,)
>>> assert np.array_equal(
... loss.numpy(), np.mean(np.abs(y_true - y_pred), axis=-1))
Arguments
- y_true: Ground truth values. shape = [batch_size, d0, .. dN].
- y_pred: The predicted values. shape = [batch_size, d0, .. dN].
Returns
Mean absolute error values. shape = [batch_size, d0, .. dN-1].
mean_absolute_percentage_error
function tf_keras.losses.mean_absolute_percentage_error(y_true, y_pred)
Computes the mean absolute percentage error between y_true and y_pred.
loss = 100 * mean(abs((y_true - y_pred) / y_true), axis=-1)
Standalone usage:
>>> y_true = np.random.random(size=(2, 3))
>>> y_true = np.maximum(y_true, 1e-7) # Prevent division by zero
>>> y_pred = np.random.random(size=(2, 3))
>>> loss = tf.keras.losses.mean_absolute_percentage_error(y_true, y_pred)
>>> assert loss.shape == (2,)
>>> assert np.array_equal(
... loss.numpy(),
... 100. * np.mean(np.abs((y_true - y_pred) / y_true), axis=-1))
Arguments
- y_true: Ground truth values. shape = [batch_size, d0, .. dN].
- y_pred: The predicted values. shape = [batch_size, d0, .. dN].
Returns
Mean absolute percentage error values. shape = [batch_size, d0, .. dN-1].
mean_squared_logarithmic_error
function tf_keras.losses.mean_squared_logarithmic_error(y_true, y_pred)
Computes the mean squared logarithmic error between y_true and y_pred.
loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)
Standalone usage:
>>> y_true = np.random.randint(0, 2, size=(2, 3))
>>> y_pred = np.random.random(size=(2, 3))
>>> loss = tf.keras.losses.mean_squared_logarithmic_error(y_true, y_pred)
>>> assert loss.shape == (2,)
>>> y_true = np.maximum(y_true, 1e-7)
>>> y_pred = np.maximum(y_pred, 1e-7)
>>> assert np.allclose(
... loss.numpy(),
... np.mean(
... np.square(np.log(y_true + 1.) - np.log(y_pred + 1.)), axis=-1))
Arguments
- y_true: Ground truth values. shape = [batch_size, d0, .. dN].
- y_pred: The predicted values. shape = [batch_size, d0, .. dN].
Returns
Mean squared logarithmic error values. shape = [batch_size, d0, .. dN-1].
cosine_similarity
function tf_keras.losses.cosine_similarity(y_true, y_pred, axis=-1)
Computes the cosine similarity between labels and predictions.
Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 indicates orthogonality and values closer to -1 indicate greater similarity. Values closer to 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you try to maximize the proximity between predictions and targets. If either y_true or y_pred is a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and targets.
loss = -sum(l2_norm(y_true) * l2_norm(y_pred))
Standalone usage:
>>> y_true = [[0., 1.], [1., 1.], [1., 1.]]
>>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]]
>>> loss = tf.keras.losses.cosine_similarity(y_true, y_pred, axis=1)
>>> loss.numpy()
array([-0., -0.999, 0.999], dtype=float32)
Arguments
- y_true: Tensor of true targets.
- y_pred: Tensor of predicted targets.
- axis: Axis along which to determine similarity.
Returns
Cosine similarity tensor.
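Note that the returned values are the negative of the conventional cosine similarity, so negating them recovers a similarity in [-1, 1]; a short sketch:

import tensorflow as tf

y_true = tf.constant([[0., 1.], [1., 1.]])
y_pred = tf.constant([[1., 0.], [1., 1.]])
loss = tf.keras.losses.cosine_similarity(y_true, y_pred, axis=1)
similarity = -loss  # ~[0., 1.]: orthogonal pair, then identical pair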
Huber
class tf_keras.losses.Huber(delta=1.0, reduction="auto", name="huber_loss")
Computes the Huber loss between y_true and y_pred.
For each value x in error = y_true - y_pred:
loss = 0.5 * x^2 if |x| <= d
loss = 0.5 * d^2 + d * (|x| - d) if |x| > d
where d is delta. See: https://en.wikipedia.org/wiki/Huber_loss
Standalone usage:
>>> y_true = [[0, 1], [0, 0]]
>>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> h = tf.keras.losses.Huber()
>>> h(y_true, y_pred).numpy()
0.155
>>> # Calling with 'sample_weight'.
>>> h(y_true, y_pred, sample_weight=[1, 0]).numpy()
0.09
>>> # Using 'sum' reduction type.
>>> h = tf.keras.losses.Huber(
... reduction=tf.keras.losses.Reduction.SUM)
>>> h(y_true, y_pred).numpy()
0.31
>>> # Using 'none' reduction type.
>>> h = tf.keras.losses.Huber(
... reduction=tf.keras.losses.Reduction.NONE)
>>> h(y_true, y_pred).numpy()
array([0.18, 0.13], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tf.keras.losses.Huber())
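The piecewise definition can be applied by hand to the inputs above; the NumPy sketch below reproduces 0.155 with the default delta of 1.0:

import numpy as np

y_true = np.array([[0., 1.], [0., 0.]])
y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])
delta = 1.0
x = y_true - y_pred
# Quadratic inside |x| <= delta, linear outside.
per_element = np.where(np.abs(x) <= delta,
                       0.5 * np.square(x),
                       0.5 * delta**2 + delta * (np.abs(x) - delta))
per_sample = per_element.mean(axis=-1)  # [0.18, 0.13]
print(per_sample.mean())                # 0.155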
huber
function tf_keras.losses.huber(y_true, y_pred, delta=1.0)
Computes Huber loss value.
For each value x in error = y_true - y_pred:
loss = 0.5 * x^2 if |x| <= d
loss = d * |x| - 0.5 * d^2 if |x| > d
where d is delta. See: https://en.wikipedia.org/wiki/Huber_loss
Arguments
- y_true: Tensor of true targets.
- y_pred: Tensor of predicted targets.
- delta: A float, the threshold at which the loss changes from quadratic to linear.
Returns
Tensor with one scalar loss entry per sample.
LogCosh
class tf_keras.losses.LogCosh(reduction="auto", name="log_cosh")
Computes the logarithm of the hyperbolic cosine of the prediction error.
logcosh = log((exp(x) + exp(-x))/2),
where x is the error y_pred - y_true.
Standalone usage:
>>> y_true = [[0., 1.], [0., 0.]]
>>> y_pred = [[1., 1.], [0., 0.]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> l = tf.keras.losses.LogCosh()
>>> l(y_true, y_pred).numpy()
0.108
>>> # Calling with 'sample_weight'.
>>> l(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
0.087
>>> # Using 'sum' reduction type.
>>> l = tf.keras.losses.LogCosh(
... reduction=tf.keras.losses.Reduction.SUM)
>>> l(y_true, y_pred).numpy()
0.217
>>> # Using 'none' reduction type.
>>> l = tf.keras.losses.LogCosh(
... reduction=tf.keras.losses.Reduction.NONE)
>>> l(y_true, y_pred).numpy()
array([0.217, 0.], dtype=float32)
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tf.keras.losses.LogCosh())
log_cosh
function tf_keras.losses.log_cosh(y_true, y_pred)
Logarithm of the hyperbolic cosine of the prediction error.
log(cosh(x)) is approximately equal to (x ** 2) / 2 for small x and to abs(x) - log(2) for large x. This means that 'logcosh' works mostly like the mean squared error, but will not be so strongly affected by the occasional wildly incorrect prediction.
Standalone usage:
>>> y_true = np.random.random(size=(2, 3))
>>> y_pred = np.random.random(size=(2, 3))
>>> loss = tf.keras.losses.logcosh(y_true, y_pred)
>>> assert loss.shape == (2,)
>>> x = y_pred - y_true
>>> assert np.allclose(
... loss.numpy(),
... np.mean(x + np.log(np.exp(-2. * x) + 1.) - tf.math.log(2.),
... axis=-1),
... atol=1e-5)
Arguments
- y_true: Ground truth values. shape = [batch_size, d0, .. dN].
- y_pred: The predicted values. shape = [batch_size, d0, .. dN].
Returns
Logcosh error values. shape = [batch_size, d0, .. dN-1].
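To see the robustness claim above in practice, the sketch below (with illustrative data) compares log-cosh and squared error on a batch containing one wildly wrong prediction; the outlier inflates the squared error far more:

import numpy as np

y_true = np.array([[1.0, 1.0, 1.0]])
y_pred = np.array([[1.1, 0.9, 11.0]])  # last entry is a wild outlier
x = y_pred - y_true
logcosh = np.mean(np.log(np.cosh(x)), axis=-1)
mse = np.mean(np.square(x), axis=-1)
print(logcosh)  # ~[3.1]  (the outlier contributes roughly |x| - log 2)
print(mse)      # ~[33.3] (the outlier contributes x**2)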