Retrieval class
keras_rs.layers.Retrieval(k: int = 10, return_scores: bool = True, **kwargs: Any)
Retrieval base abstract class.
This layer provides a common interface for all retrieval layers. In order to implement a custom retrieval layer, this abstract class should be subclassed.
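The sketch below illustrates the subclassing contract without any framework. It is not the keras_rs implementation. RetrievalSketch and BruteForceSketch are hypothetical names, plain Python lists stand in for tensors, and call handles a single query vector rather than a batch. A real subclass would instead extend keras_rs.layers.Retrieval, operate on tensors, and be invoked as layer(query).

```python
import abc
from typing import Any, List, Optional, Sequence


class RetrievalSketch(abc.ABC):
    """Hypothetical stand-in mirroring the documented Retrieval interface."""

    def __init__(self, k: int = 10, return_scores: bool = True) -> None:
        self.k = k
        self.return_scores = return_scores

    @abc.abstractmethod
    def update_candidates(
        self, candidate_embeddings: Any, candidate_ids: Optional[Any] = None
    ) -> None:
        """Store the candidates that call() searches over."""

    @abc.abstractmethod
    def call(self, inputs: Any) -> Any:
        """Return the top candidates for the query."""


class BruteForceSketch(RetrievalSketch):
    """Hypothetical subclass: scores every candidate and keeps the best k."""

    def update_candidates(
        self,
        candidate_embeddings: Sequence[Sequence[float]],
        candidate_ids: Optional[Sequence[Any]] = None,
    ) -> None:
        self.embeddings: List[List[float]] = [list(e) for e in candidate_embeddings]
        # Without explicit IDs, report each candidate's position instead.
        if candidate_ids is None:
            self.ids = list(range(len(self.embeddings)))
        else:
            self.ids = list(candidate_ids)

    def call(self, inputs: Sequence[float]) -> Any:
        # Dot-product score of the single query against every candidate.
        scores = [sum(q * c for q, c in zip(inputs, e)) for e in self.embeddings]
        # Candidate positions ordered from highest to lowest score.
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        top = order[: self.k]
        top_ids = [self.ids[i] for i in top]
        if self.return_scores:
            return [scores[i] for i in top], top_ids
        return top_ids
```

As an example, take BruteForceSketch(k=2) after update_candidates([[1, 0], [0, 1], [1, 1]], ["a", "b", "c"]). In that case call([2, 1]) returns ([3, 2], ["c", "a"]). Instantiating RetrievalSketch directly raises TypeError because its abstract methods are not implemented.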
Arguments
k: int. Number of top candidates to retrieve.
return_scores: bool. When True, this layer returns a tuple with the top scores and the top identifiers. When False, this layer returns a single tensor with the top identifiers.

call method
Retrieval.call(inputs: Any)
Returns the top candidates for the query passed as input.
Arguments
inputs: The query for which to return the top candidates.
Returns
A tuple with the top scores and the top identifiers if return_scores is True, otherwise a tensor with the top identifiers.
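The sketch below shows how return_scores shapes the output for a batch of queries. select_top_k is a hypothetical helper and is not part of keras_rs. It assumes the scores arrive as a nested list of shape [num_queries, num_candidates] and keeps the k highest-scoring identifiers per query, best first.

```python
from typing import Any, List, Sequence, Tuple, Union


def select_top_k(
    scores: Sequence[Sequence[float]],
    ids: Sequence[Any],
    k: int,
    return_scores: bool = True,
) -> Union[Tuple[List[List[float]], List[List[Any]]], List[List[Any]]]:
    """Hypothetical helper: per-query top-k over a [num_queries, num_candidates] matrix."""
    top_scores: List[List[float]] = []
    top_ids: List[List[Any]] = []
    for row in scores:
        # Positions of the k best candidates for this query, highest score first.
        # Slicing caps k at the number of candidates (a choice of this sketch).
        best = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        top_scores.append([row[j] for j in best])
        top_ids.append([ids[j] for j in best])
    if return_scores:
        return top_scores, top_ids
    return top_ids
```

Take scores [[0.1, 0.9, 0.4], [0.7, 0.2, 0.8]] and ids [10, 20, 30] with k=2. The result is ([[0.9, 0.4], [0.8, 0.7]], [[20, 30], [30, 10]]), or just [[20, 30], [30, 10]] when return_scores is False.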
update_candidates method
Retrieval.update_candidates(
    candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)
Update the set of candidates and optionally their candidate IDs.
Arguments
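candidate_embeddings: The candidate embeddings to retrieve from.
candidate_ids: Optional identifiers for the candidates. If None, the indices of the candidates are returned instead.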
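The sketch below shows the fallback used when candidate_ids is omitted. resolve_candidate_ids is a hypothetical helper. Its length check is an assumption added for illustration, not documented behavior.

```python
from typing import Any, List, Optional, Sequence


def resolve_candidate_ids(
    num_candidates: int, candidate_ids: Optional[Sequence[Any]] = None
) -> List[Any]:
    """Hypothetical helper: the identifiers a retrieval layer would report."""
    if candidate_ids is None:
        # No IDs supplied: fall back to each candidate's row index.
        return list(range(num_candidates))
    if len(candidate_ids) != num_candidates:
        # Assumed sanity check: one ID per candidate embedding.
        raise ValueError(
            f"expected {num_candidates} candidate IDs, got {len(candidate_ids)}"
        )
    return list(candidate_ids)
```

For example, resolve_candidate_ids(3) gives [0, 1, 2], while resolve_candidate_ids(2, ["x", "y"]) keeps the supplied identifiers.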
compute_score method
Retrieval.compute_score(query_embedding: Any, candidate_embedding: Any)
Computes the standard dot product score from queries and candidates.
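Arguments
query_embedding: Tensor of query embeddings for which to retrieve the top candidates.
candidate_embedding: Tensor of candidate embeddings.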
Returns
The dot product of queries and candidates.
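Scoring every query against every candidate amounts to multiplying the query matrix by the transposed candidate matrix. The result holds one score per (query, candidate) pair. The sketch below assumes both inputs are 2-D, with shapes [num_queries, dim] and [num_candidates, dim]. dot_product_scores is a hypothetical pure-Python helper. With Keras 3 tensors, the same computation can be written as keras.ops.matmul(queries, keras.ops.transpose(candidates)).

```python
from typing import List, Sequence


def dot_product_scores(
    query_embedding: Sequence[Sequence[float]],
    candidate_embedding: Sequence[Sequence[float]],
) -> List[List[float]]:
    """Hypothetical helper: scores[i][j] = dot(query i, candidate j), i.e. Q @ C^T."""
    if not query_embedding:
        return []
    dim = len(query_embedding[0])
    # zip() would silently truncate mismatched vectors, so check dimensions first.
    for vector in list(query_embedding) + list(candidate_embedding):
        if len(vector) != dim:
            raise ValueError("queries and candidates must share the embedding dimension")
    return [
        [sum(q * c for q, c in zip(query, cand)) for cand in candidate_embedding]
        for query in query_embedding
    ]
```

For example, queries [[1, 2], [0, 1]] against candidates [[3, 4], [1, 0], [2, 2]] give [[11, 1, 6], [4, 0, 2]]: one row per query, one column per candidate.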