HardNegativeMining layer

HardNegativeMining class

keras_rs.layers.HardNegativeMining(num_hard_negatives: int, **kwargs: Any)

Filter logits and labels to return hard negatives.

The output includes the logits and labels for the requested number of hard negatives, i.e. the highest-scoring negative candidates for each query, as well as for the positive candidate.

Arguments

  • num_hard_negatives: Number of hard negatives to return for each query.
  • **kwargs: Keyword arguments to pass to the base class.

Example

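import keras
import keras_rs

# Example inputs: 4 queries scored against 100 candidates, with one positive
# candidate per query marked in the one-hot `in_labels`.
in_logits = keras.random.uniform((4, 100))
in_labels = keras.ops.one_hot(keras.ops.arange(4), 100)

# Create layer with the configured number of hard negatives to mine.
hard_negative_mining = keras_rs.layers.HardNegativeMining(
    num_hard_negatives=10
)

# For each row, this returns the 10 highest-scoring negative candidates plus
# the positive candidate marked in `in_labels`.
out_logits, out_labels = hard_negative_mining(in_logits, in_labels)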

call method

HardNegativeMining.call(logits: Any, labels: Any)

Filters logits and labels with per-query hard negative mining.

The result will include logits and labels for num_hard_negatives negatives as well as the positive candidate.
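
The selection for a single query can be pictured with a plain-Python sketch. This illustrates the behavior described above and is not the layer's implementation: `mine_row` is a hypothetical helper, it assumes exactly one positive per row, and the order of the returned candidates (positive first here) is an assumption, since this page does not specify an output order.

```python
def mine_row(logits, labels, num_hard_negatives):
    """Keep the positive candidate plus the highest-scoring negatives.

    Hypothetical reference helper for one query row; `labels` is one-hot.
    """
    positive = labels.index(1)
    # Indices of the negative candidates, highest logit first.
    negatives = sorted(
        (i for i in range(len(logits)) if i != positive),
        key=lambda i: logits[i],
        reverse=True,
    )
    # Assumed ordering: positive first, then the hardest negatives.
    keep = [positive] + negatives[:num_hard_negatives]
    return [logits[i] for i in keep], [labels[i] for i in keep]


row_logits = [0.1, 2.0, -1.0, 0.7, 1.5]
row_labels = [0, 0, 0, 1, 0]
out_logits, out_labels = mine_row(row_logits, row_labels, num_hard_negatives=2)
print(out_logits)  # [0.7, 2.0, 1.5]
print(out_labels)  # [1, 0, 0]
```

The two negatives with the largest logits (2.0 and 1.5) are kept because they are the ones a model is most likely to confuse with the positive; the low-scoring negatives are dropped.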

Arguments

  • logits: The logits tensor, typically of shape [batch_size, num_candidates], but it can have more dimensions or be 1D with shape [num_candidates].
  • labels: The one-hot labels tensor. Must have the same shape as logits.

Returns

A tuple of two tensors in which the last dimension, num_candidates, is replaced with num_hard_negatives + 1.

  • logits: [..., num_hard_negatives + 1] tensor of logits.
  • labels: [..., num_hard_negatives + 1] one-hot tensor of labels.
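
As a quick check of the shape rule above, the following sketch computes the expected output shape from an input shape. `mined_shape` is a hypothetical helper written for illustration, not part of `keras_rs`, and it assumes num_candidates is larger than num_hard_negatives; the behavior for fewer candidates is not described on this page and is not modeled here.

```python
def mined_shape(input_shape, num_hard_negatives):
    """Replace the trailing num_candidates axis with num_hard_negatives + 1.

    Hypothetical helper; assumes input_shape[-1] > num_hard_negatives.
    """
    return tuple(input_shape[:-1]) + (num_hard_negatives + 1,)


print(mined_shape((32, 1000), 10))     # (32, 11)
print(mined_shape((1000,), 10))        # (11,)
print(mined_shape((4, 32, 1000), 10))  # (4, 32, 11)
```

The same shape applies to both returned tensors, since the labels are filtered with the same per-query selection as the logits.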