SamplingProbabilityCorrection
class keras_rs.layers.SamplingProbabilityCorrection(epsilon: float = 1e-06, **kwargs: Any)
Sampling probability correction.
Corrects the logits to account for the probability with which each negative candidate was sampled.
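In sampled-softmax training, candidates that are sampled often (for example, popular items appearing as in-batch negatives) show up as negatives more frequently than rare ones, which biases the sampled softmax. The usual "logQ" correction, which this layer is assumed to implement, subtracts the log of each candidate's sampling probability from its logit, with `epsilon` keeping the logarithm finite. Here $s_{ij}$ is the logit of candidate $j$ for query $i$ and $p_{ij}$ is the matching entry of `candidate_sampling_probability`; the symbols and the clipping to $[\epsilon, 1]$ are assumptions of this sketch:

$$\tilde{s}_{ij} = s_{ij} - \log\big(\min(\max(p_{ij}, \epsilon), 1)\big)$$

Under the softmax, subtracting $\log p_{ij}$ multiplies each candidate's unnormalized weight $e^{s_{ij}}$ by $1/p_{ij}$, so frequently sampled candidates are down-weighted relative to rarely sampled ones.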
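Arguments
epsilon: float. Small constant used to keep the sampling probability away from zero so that its logarithm stays finite. Defaults to 1e-6.
**kwargs: Arguments passed to the base keras.layers.Layer.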
Example
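import keras
import keras_rs

# Illustrative inputs: 2 queries x 3 candidates, plus sampling probabilities.
logits = keras.ops.array([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
probabilities = keras.ops.array([[0.1, 0.2, 0.7], [0.5, 0.25, 0.25]])
# Create the layer.
sampling_probability_correction = (
    keras_rs.layers.SamplingProbabilityCorrection()
)
# Correct the logits based on the provided candidate sampling probability.
logits = sampling_probability_correction(logits, probabilities)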
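To make the effect of the correction concrete without depending on Keras, the following is a minimal pure-Python sketch. It assumes the layer computes `logit - log(clip(p, epsilon, 1))` elementwise; the helpers `correct_logits` and `softmax` are hypothetical and not part of `keras_rs`. With identical raw logits, the candidate that was sampled least often ends up with the largest corrected logit.

```python
import math


def correct_logits(logits, probabilities, epsilon=1e-6):
    """Hypothetical reference for 2D inputs (assumed behavior, not the keras_rs source).

    Subtracts log(clip(p, epsilon, 1)) from each logit.
    """
    return [
        [s - math.log(min(max(p, epsilon), 1.0)) for s, p in zip(row_s, row_p)]
        for row_s, row_p in zip(logits, probabilities)
    ]


def softmax(row):
    # Numerically stable softmax over one row of logits.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# One query scored against three candidates with identical raw logits.
logits = [[2.0, 2.0, 2.0]]
# Candidate 0 is popular (sampled often); candidate 2 is rare.
probabilities = [[0.7, 0.2, 0.1]]

corrected = correct_logits(logits, probabilities)
print([round(x, 3) for x in corrected[0]])
print([round(x, 3) for x in softmax(corrected[0])])
```

Because the raw logits are equal, the softmax of the corrected row is simply proportional to 1/p for each candidate, which is the debiasing the layer is meant to provide.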
call method
SamplingProbabilityCorrection.call(logits: Any, candidate_sampling_probability: Any)
Corrects input logits to account for candidate sampling probability.
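Arguments
logits: The logits tensor to correct, typically [batch_size, num_candidates] but can have more dimensions or be 1D as [num_candidates].
candidate_sampling_probability: The sampling probability of each candidate, with the same shape as logits.
Returns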
The corrected logits with the same shape as the input logits.
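The shape contract above (any rank in, same shape out) follows from the correction being elementwise. A minimal sketch, again assuming the clip-then-log behavior and using a hypothetical recursive helper `correct_nested` on nested Python lists rather than tensors, shows that 1D and 3D inputs keep their shape and that a zero probability is clipped to `epsilon` instead of producing an infinite logit.

```python
import math


def correct_nested(logits, probabilities, epsilon=1e-6):
    """Hypothetical elementwise correction on nested lists of any depth.

    Assumes corrected = logit - log(clip(p, epsilon, 1)); the output
    mirrors the nesting (shape) of the input.
    """
    if isinstance(logits, list):
        if not isinstance(probabilities, list) or len(logits) != len(probabilities):
            raise ValueError("logits and probabilities must have the same shape")
        return [correct_nested(s, p, epsilon) for s, p in zip(logits, probabilities)]
    # Scalar case: clip into [epsilon, 1] so log() stays finite.
    clipped = min(max(probabilities, epsilon), 1.0)
    return logits - math.log(clipped)


def shape(x):
    # Shape of a regular nested list, e.g. [[1, 2, 3]] -> (1, 3).
    if isinstance(x, list):
        return (len(x),) + (shape(x[0]) if x else ())
    return ()


# 1D input: [num_candidates]; the second probability is zero.
one_d = correct_nested([0.0, 0.0], [0.5, 0.0])
# 3D input with a hypothetical extra middle axis: [batch_size, 1, num_candidates].
three_d = correct_nested(
    [[[1.0, 1.0]], [[1.0, 1.0]]],
    [[[0.5, 0.5]], [[0.25, 1.0]]],
)

print(shape(one_d), one_d)
print(shape(three_d), three_d)
```

A candidate with probability 1.0 is left unchanged (log 1 = 0), while a zero probability yields a large but finite boost of -log(epsilon).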