HardNegativeMining
class keras_rs.layers.HardNegativeMining(num_hard_negatives: int, **kwargs: Any)
Filter logits and labels to return hard negatives.
The output will include logits and labels for the requested number of hard negatives as well as the positive candidate.
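To make "hard negatives" concrete, here is a minimal pure-Python sketch for a single query. It assumes that a label greater than 0 marks the positive candidate. The helper `mine_hard_negatives_row` is hypothetical and not part of keras_rs. It keeps the positive plus the `num_hard_negatives` negatives with the highest logits, which are the negatives the model currently finds hardest to separate from the positive. Putting the positive first is this sketch's choice; this page does not specify the layer's output order.

```python
def mine_hard_negatives_row(logits, labels, num_hard_negatives):
    """Keep the positive candidate plus the highest-scoring negatives.

    Hypothetical helper for illustration only; not part of keras_rs.
    Assumes labels > 0 mark positives.
    """
    # Split candidate indices into positives and negatives.
    positives = [i for i, y in enumerate(labels) if y > 0]
    negatives = [i for i, y in enumerate(labels) if y <= 0]
    # "Hard" negatives are the negatives with the largest logits.
    hard = sorted(negatives, key=lambda i: logits[i], reverse=True)
    keep = positives + hard[:num_hard_negatives]
    return [logits[i] for i in keep], [labels[i] for i in keep]


logits = [0.1, 2.0, 0.5, 3.0, -1.0]
labels = [0, 1, 0, 0, 0]  # candidate 1 is the positive
out_logits, out_labels = mine_hard_negatives_row(logits, labels, num_hard_negatives=2)
print(out_logits)  # [2.0, 3.0, 0.5]
print(out_labels)  # [1, 0, 0]
```

The easy negatives (logits 0.1 and -1.0) are dropped because they contribute little to a softmax-style loss compared with the negative scored 3.0.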
Arguments
num_hard_negatives: How many hard negatives to return.
**kwargs: Args to pass to the base class.
Example
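```python
import keras
import keras_rs

# Example inputs: 2 queries, each scored against 20 candidates, with
# candidate 0 as the positive for both queries.
in_logits = keras.random.uniform((2, 20))
in_labels = keras.ops.one_hot(keras.ops.zeros((2,), dtype="int32"), 20)

# Create layer with the configured number of hard negatives to mine.
hard_negative_mining = keras_rs.layers.HardNegativeMining(num_hard_negatives=10)
# Retrieve the 10 highest-scoring negative candidates plus the positive
# candidate from `labels` for each row; both outputs have shape (2, 11).
out_logits, out_labels = hard_negative_mining(in_logits, in_labels)
```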
call
method HardNegativeMining.call(logits: Any, labels: Any)
Filters logits and labels with per-query hard negative mining.
The result will include logits and labels for num_hard_negatives
negatives as well as the positive candidate.
Arguments
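logits: The logits tensor, typically [batch_size, num_candidates], but it can have more dimensions or be 1D as [num_candidates].
labels: The one-hot labels tensor, which must have the same shape as logits.
Returns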
A tuple containing two tensors with the last dimension of num_candidates replaced with num_hard_negatives + 1:
logits: [..., num_hard_negatives + 1] tensor of logits.
labels: [..., num_hard_negatives + 1] one-hot tensor of labels.
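The shape contract above (any number of leading dimensions, with mining along the last axis) can be sketched with NumPy. This is an illustration under stated assumptions, not the keras_rs implementation. `hard_negative_mining` is a hypothetical helper. It assumes labels greater than 0 mark positives, and it clamps the output width to num_candidates when num_hard_negatives + 1 is larger, which is an assumption of this sketch. Positives are given the largest float32 value so that a single top-k along the last axis returns the positive plus the hardest negatives without explicit masking.

```python
import numpy as np


def hard_negative_mining(logits, labels, num_hard_negatives):
    """Vectorized sketch over the last axis (hypothetical, not keras_rs)."""
    logits = np.asarray(logits, dtype=np.float32)
    labels = np.asarray(labels)
    # Output width: positive + hard negatives, clamped to num_candidates.
    k = min(num_hard_negatives + 1, logits.shape[-1])
    # Force positives to the top so they always survive the top-k.
    scores = np.where(labels > 0, np.finfo(np.float32).max, logits)
    # Indices of the k largest scores along the last axis, descending.
    indices = np.argsort(-scores, axis=-1)[..., :k]
    # Gather logits and labels with the same indices so they stay aligned.
    out_logits = np.take_along_axis(logits, indices, axis=-1)
    out_labels = np.take_along_axis(labels, indices, axis=-1)
    return out_logits, out_labels


# [batch_size, num_candidates] = [2, 5] inputs.
batch_logits = np.array([[0.1, 2.0, 0.5, 3.0, -1.0],
                         [1.5, -0.2, 0.3, 0.0, 4.0]])
batch_labels = np.array([[0, 1, 0, 0, 0],
                         [0, 0, 1, 0, 0]])
out_logits, out_labels = hard_negative_mining(batch_logits, batch_labels, 2)
print(out_logits.shape)      # (2, 3)
print(out_labels.tolist())   # [[1, 0, 0], [1, 0, 0]]

# A 1D [num_candidates] input works the same way.
row_logits, row_labels = hard_negative_mining([0.1, 2.0, 0.5, 3.0, -1.0],
                                              [0, 1, 0, 0, 0], 2)
print(row_logits.shape)      # (3,)
```

Computing the indices once and reusing them for both tensors is what keeps each mined logit paired with its label.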