
BruteForceRetrieval layer


BruteForceRetrieval class

keras_rs.layers.BruteForceRetrieval(
    candidate_embeddings: Optional[Any] = None,
    candidate_ids: Optional[Any] = None,
    k: int = 10,
    return_scores: bool = True,
    **kwargs: Any
)

Brute force top-k retrieval.

This layer maintains a set of candidates and retrieves the exact top-k candidates for a given query. It does this by computing the score of every candidate for the query and keeping the k highest. The returned top-k candidates are sorted by descending score.

By default, this layer returns a tuple with the top scores and the top identifiers. Setting return_scores=False makes it return a single tensor with the top identifiers instead.

The identifiers for the candidates can be specified as a tensor. If not provided, the IDs used are simply the candidate indices.
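The scoring and selection described above can be sketched in NumPy as follows. This is an illustrative sketch, not the KerasRS implementation: it assumes dot-product scores between query and candidate embeddings, and the function name brute_force_retrieve is hypothetical. It covers both return_scores modes and the fallback to candidate indices when no identifiers are given.

```python
import numpy as np


def brute_force_retrieve(query_embeddings, candidate_embeddings, k,
                         candidate_ids=None, return_scores=True):
    """Hypothetical sketch of exact top-k retrieval (dot-product scores assumed)."""
    # Score every (query, candidate) pair: shape (num_queries, num_candidates).
    scores = np.asarray(query_embeddings) @ np.asarray(candidate_embeddings).T
    # Column indices of the k highest scores per row, highest first.
    # A stable sort on the negated scores keeps tied candidates in index order.
    top_indices = np.argsort(-scores, axis=1, kind="stable")[:, :k]
    top_scores = np.take_along_axis(scores, top_indices, axis=1)
    # Without explicit identifiers, the candidate indices serve as the IDs.
    if candidate_ids is None:
        top_ids = top_indices
    else:
        top_ids = np.asarray(candidate_ids)[top_indices]
    return (top_scores, top_ids) if return_scores else top_ids


queries = np.array([[1.0, 0.0], [0.0, 1.0]])
candidates = np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])

top_scores, top_ids = brute_force_retrieve(
    queries, candidates, k=2, candidate_ids=[100, 101, 102]
)
# top_scores -> [[0.9, 0.5], [0.8, 0.5]]; top_ids -> [[100, 102], [101, 102]]

ids_only = brute_force_retrieve(queries, candidates, k=2, return_scores=False)
# ids_only -> [[0, 2], [1, 2]]
```

Because every candidate is scored, the cost grows linearly with the number of candidates per query; that is the price of getting exact rather than approximate results.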

Note that serializing this layer does not preserve the candidates; only the k and return_scores arguments are saved. After deserializing the layer, call update_candidates to supply the candidates again.
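The serialization note can be illustrated with a hypothetical pure-Python stand-in. ToyBruteForceRetrieval is not part of KerasRS; it only borrows the Keras get_config/from_config naming to show that a config holding k and return_scores rebuilds an object with no candidates, which then has to be repopulated with update_candidates.

```python
class ToyBruteForceRetrieval:
    """Hypothetical stand-in that mimics only what serialization keeps."""

    def __init__(self, k=10, return_scores=True):
        self.k = k
        self.return_scores = return_scores
        self.candidate_embeddings = None
        self.candidate_ids = None

    def get_config(self):
        # Only the constructor arguments k and return_scores are saved.
        return {"k": self.k, "return_scores": self.return_scores}

    @classmethod
    def from_config(cls, config):
        # A fresh object is built from the saved arguments, without candidates.
        return cls(**config)

    def update_candidates(self, candidate_embeddings, candidate_ids=None):
        self.candidate_embeddings = candidate_embeddings
        self.candidate_ids = candidate_ids


candidates = [[0.9, 0.1], [0.2, 0.8]]
original = ToyBruteForceRetrieval(k=100, return_scores=False)
original.update_candidates(candidates)

config = original.get_config()  # {"k": 100, "return_scores": False}
restored = ToyBruteForceRetrieval.from_config(config)
missing_before = restored.candidate_embeddings is None  # True

# The restored object is usable only after the candidates are supplied again.
restored.update_candidates(candidates)
```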

Arguments

  • candidate_embeddings: The candidate embeddings. If None, candidates must be provided using update_candidates before using this layer.
  • candidate_ids: The identifiers for the candidates. If None, the indices of the candidates are returned as identifiers instead.
  • k: Number of candidates to retrieve.
  • return_scores: When True, this layer returns a tuple with the top scores and the top identifiers. When False, this layer returns a single tensor with the top identifiers.
  • **kwargs: Args to pass to the base class.

Example

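import keras
import keras_rs

# Hypothetical data: 1,000 candidates and 32 queries with 64-dim embeddings.
candidate_embeddings = keras.random.normal((1000, 64))
candidate_ids = keras.ops.arange(1000, 2000)
query_embeddings = keras.random.normal((32, 64))

retrieval = keras_rs.layers.BruteForceRetrieval(k=100)

# At some later point, we update the candidates.
retrieval.update_candidates(candidate_embeddings, candidate_ids)

# We can then retrieve the top candidates for any number of queries.
# Scores are sorted highest first. Scores correspond to IDs in the same row.
top_scores, top_ids = retrieval(query_embeddings)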


call method

BruteForceRetrieval.call(inputs: Any)

Returns the top candidates for the query passed as input.

Arguments

  • inputs: The query embeddings for which to return the top candidates.

Returns

A tuple with the top scores and the top identifiers if return_scores is True; otherwise, a single tensor with the top identifiers.



update_candidates method

BruteForceRetrieval.update_candidates(
    candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)

Update the set of candidates and optionally their candidate IDs.

Arguments

  • candidate_embeddings: The candidate embeddings.
  • candidate_ids: The identifiers for the candidates. If None, the indices of the candidates are returned instead.
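The bookkeeping implied by these arguments can be sketched with a hypothetical helper, resolve_candidates, which is not a KerasRS function. It assumes the candidates form a (num_candidates, embedding_dim) matrix and that exactly one identifier is given per candidate row; the ValueError checks are assumptions of this sketch, not documented behavior of the layer.

```python
import numpy as np


def resolve_candidates(candidate_embeddings, candidate_ids=None):
    """Hypothetical helper: normalize candidates and their identifiers."""
    embeddings = np.asarray(candidate_embeddings, dtype=np.float32)
    if embeddings.ndim != 2:
        # Assumption: candidates are a (num_candidates, embedding_dim) matrix.
        raise ValueError(f"Expected 2D candidate embeddings, got {embeddings.shape}.")
    num_candidates = embeddings.shape[0]
    if candidate_ids is None:
        # Default: identify each candidate by its row index.
        ids = np.arange(num_candidates)
    else:
        ids = np.asarray(candidate_ids)
        if ids.shape[0] != num_candidates:
            # Assumption: exactly one identifier per candidate row.
            raise ValueError(f"Got {ids.shape[0]} ids for {num_candidates} candidates.")
    return embeddings, ids


embeddings, ids = resolve_candidates([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])
# ids -> [0, 1, 2]
```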