BruteForceRetrieval
class keras_rs.layers.BruteForceRetrieval(
candidate_embeddings: Optional[Any] = None,
candidate_ids: Optional[Any] = None,
k: int = 10,
return_scores: bool = True,
**kwargs: Any
)
Brute-force top-k retrieval.
This layer maintains a set of candidates and retrieves the exact top-k candidates for a given query. It does this by computing the score of every candidate for the query and keeping the k highest. The returned top-k candidates are sorted by score, highest first.
By default, this layer returns a tuple with the top scores and top identifiers, but it can be configured to return a single tensor with the top identifiers.
The identifiers for the candidates can be specified as a tensor. If not provided, the IDs used are simply the candidate indices.
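The behavior described above can be sketched in NumPy. This is a minimal illustration, not the keras_rs implementation: it assumes dot-product scores (the text only says a score is computed for every candidate), and `brute_force_top_k` is a hypothetical helper. It scores all candidates, keeps the k highest per query in descending order, maps positions to identifiers (falling back to the raw indices when no IDs are given), and honors a `return_scores` flag.

```python
import numpy as np


def brute_force_top_k(queries, candidates, k, candidate_ids=None, return_scores=True):
    """Hypothetical exact top-k retrieval; dot-product scoring is an assumption."""
    queries = np.asarray(queries, dtype=np.float64)
    candidates = np.asarray(candidates, dtype=np.float64)
    # Score every candidate for every query: shape (num_queries, num_candidates).
    scores = queries @ candidates.T
    # Positions of the k highest scores per row, highest first.
    top_indices = np.argsort(-scores, axis=1, kind="stable")[:, :k]
    top_scores = np.take_along_axis(scores, top_indices, axis=1)
    # Map positions to identifiers; without IDs, the indices are returned as-is.
    if candidate_ids is None:
        top_ids = top_indices
    else:
        top_ids = np.asarray(candidate_ids)[top_indices]
    return (top_scores, top_ids) if return_scores else top_ids


candidates = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
queries = np.array([[1.0, 0.2], [0.0, 1.0]])
top_scores, top_ids = brute_force_top_k(queries, candidates, k=2, candidate_ids=[10, 20, 30])
# top_ids -> [[10, 30], [20, 30]]; row i of top_scores lines up with row i of top_ids.
```

Because every candidate is scored for every query, the cost grows linearly with the number of candidates; that is the price of getting exact rather than approximate results.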
Note that serializing this layer does not preserve the candidates; only the k and return_scores arguments are saved. You must call update_candidates again after deserializing the layer.
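The round trip described in this note can be sketched with a hypothetical pure-Python stand-in, `ToyRetrieval`. It is not the keras_rs class; it only mimics the documented contract that the saved configuration carries `k` and `return_scores` but not the candidates. With the real layer, the analogous steps would be the standard Keras `get_config()` / `from_config()` pair followed by a call to `update_candidates`.

```python
import numpy as np


class ToyRetrieval:
    """Hypothetical stand-in mimicking the documented serialization contract."""

    def __init__(self, k=10, return_scores=True):
        self.k = k
        self.return_scores = return_scores
        self.candidate_embeddings = None  # not part of the saved config
        self.candidate_ids = None  # not part of the saved config

    def update_candidates(self, candidate_embeddings, candidate_ids=None):
        self.candidate_embeddings = np.asarray(candidate_embeddings)
        self.candidate_ids = None if candidate_ids is None else np.asarray(candidate_ids)

    def get_config(self):
        # Only the constructor arguments are saved; the candidates are dropped.
        return {"k": self.k, "return_scores": self.return_scores}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


original = ToyRetrieval(k=5, return_scores=False)
original.update_candidates(np.ones((100, 16)), candidate_ids=np.arange(100))

restored = ToyRetrieval.from_config(original.get_config())
print(restored.k, restored.return_scores)  # 5 False
print(restored.candidate_embeddings is None)  # True: candidates were not restored

# Re-supply the candidates before using the restored object.
restored.update_candidates(original.candidate_embeddings, original.candidate_ids)
```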
Arguments
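candidate_embeddings: The candidate embeddings. If None, candidates must be provided using update_candidates before using this layer.
candidate_ids: The identifiers for the candidates. If None, the indices of the candidates are returned instead.
k: Number of top candidates to retrieve.
return_scores: When True, this layer returns a tuple with the top scores and the top identifiers. When False, this layer returns a single tensor with the top identifiers.
Example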
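```python
import keras
import keras_rs

retrieval = keras_rs.layers.BruteForceRetrieval(k=100)
# At some later point, we update the candidates (random data for illustration).
candidate_embeddings = keras.random.normal((1000, 32))
candidate_ids = keras.ops.arange(1000)
retrieval.update_candidates(candidate_embeddings, candidate_ids)
# We can then retrieve the top candidates for any number of queries.
# Scores are sorted highest first. Scores correspond to ids in the same row.
query_embeddings = keras.random.normal((8, 32))
top_scores, top_ids = retrieval(query_embeddings)
```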
call
method BruteForceRetrieval.call(inputs: Any)
Returns the top candidates for the query passed as input.
Arguments
inputs: The query embeddings for which to retrieve the top candidates.
Returns
A tuple with the top scores and the top identifiers if return_scores is True, otherwise a tensor with the top identifiers.
update_candidates
method BruteForceRetrieval.update_candidates(
candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)
Update the set of candidates and optionally their candidate IDs.
Arguments
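candidate_embeddings: The candidate embeddings.
candidate_ids: The identifiers for the candidates. If None, the indices of the candidates are returned instead.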