
SamplingProbabilityCorrection layer

SamplingProbabilityCorrection class

keras_rs.layers.SamplingProbabilityCorrection(epsilon: float = 1e-06, **kwargs: Any)

Sampling probability correction.

Corrects the logits to account for the probability with which each negative candidate was sampled.
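For intuition, the standard logQ correction used with sampled softmax can be written as below. Whether this layer uses exactly this form is an assumption, based on the epsilon argument (described under Arguments) being added to the sampling probability before taking the log. Here $s_{ij}$ is the logit for query $i$ and candidate $j$, $p_{ij}$ is that candidate's sampling probability, and $\epsilon$ is `epsilon`.

$$
\tilde{s}_{ij} = s_{ij} - \log\left(p_{ij} + \epsilon\right)
$$

Candidates with a small sampling probability receive a larger boost, which offsets how rarely they appear as sampled negatives.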

Arguments

  • epsilon: float. Small float added to sampling probability to avoid taking the log of zero. Defaults to 1e-6.
  • **kwargs: Additional keyword arguments passed to the base layer class.
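A minimal NumPy sketch of the correction, assuming it computes `logits - log(candidate_sampling_probability + epsilon)`. The function name `sampling_probability_correction_reference` is hypothetical and this is not the keras_rs source; it only illustrates how `epsilon` keeps a zero probability from producing an infinite logit.

```python
import numpy as np


def sampling_probability_correction_reference(
    logits, candidate_sampling_probability, epsilon=1e-6
):
    """Hypothetical NumPy reference for the logQ correction (not keras_rs code).

    Assumes corrected = logits - log(probability + epsilon), where epsilon is
    added to the probability so that log(0) is never evaluated.
    """
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.asarray(candidate_sampling_probability, dtype=np.float64)
    if logits.shape != probs.shape:
        raise ValueError(
            "candidate_sampling_probability must have the same shape as logits"
        )
    return logits - np.log(probs + epsilon)


# Three candidates with the same raw score but very different sampling rates;
# the last one has probability 0.0, which epsilon keeps finite.
logits = np.array([[2.0, 2.0, 2.0]])
probs = np.array([[0.9, 0.09, 0.0]])
corrected = sampling_probability_correction_reference(logits, probs)
print(corrected)
```

The rarely sampled candidates end up with the largest corrected logits, and the zero-probability candidate gets a large but finite value of roughly `2 - log(1e-6)`.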

Example

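import numpy as np
import keras_rs

# Create the layer.
sampling_probability_correction = (
    keras_rs.layers.SamplingProbabilityCorrection()
)

# Example logits for 2 queries x 3 candidates, with matching sampling probabilities.
logits = np.array([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]], dtype="float32")
probabilities = np.array([[0.5, 0.25, 0.25], [0.1, 0.1, 0.8]], dtype="float32")

# Correct the logits based on the provided candidate sampling probability.
corrected_logits = sampling_probability_correction(logits, probabilities)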

call method

SamplingProbabilityCorrection.call(logits: Any, candidate_sampling_probability: Any)

Corrects the input logits to account for the candidate sampling probabilities.

Arguments

  • logits: The logits tensor to correct, typically of shape [batch_size, num_candidates], though it may have more dimensions or be 1D with shape [num_candidates].
  • candidate_sampling_probability: The sampling probability of each candidate, with the same shape as logits.
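Because candidate_sampling_probability must have the same shape as logits, a per-candidate probability vector has to be expanded before use. The NumPy sketch below assumes a hypothetical in-batch setup where every query is scored against the same candidates, expands the vector with `np.broadcast_to`, and applies the assumed correction `logits - log(p + epsilon)`. All variable names and probability values are illustrative, not taken from keras_rs.

```python
import numpy as np

# Hypothetical in-batch setup: 3 queries scored against the same 4 candidates.
batch_size, num_candidates = 3, 4
rng = np.random.default_rng(0)
logits = rng.normal(size=(batch_size, num_candidates))

# One sampling probability per candidate, shape [num_candidates]
# (for example, estimated from how often each item appears in training batches).
per_candidate_probability = np.array([0.4, 0.3, 0.2, 0.1])

# The probabilities must match the logits shape, so repeat the
# per-candidate vector across the batch dimension (a read-only view).
candidate_sampling_probability = np.broadcast_to(
    per_candidate_probability, logits.shape
)

# Assumed logQ correction with epsilon added before the log.
epsilon = 1e-6
corrected = logits - np.log(candidate_sampling_probability + epsilon)

# 1D case: a single query's logits of shape [num_candidates] need no expansion.
corrected_1d = logits[0] - np.log(per_candidate_probability + epsilon)
```

With the layer itself, the expanded array would be passed as the second argument, e.g. `layer(logits, candidate_sampling_probability)`, and the result keeps the shape of `logits`.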

Returns

The corrected logits with the same shape as the input logits.