Retrieval layer

Retrieval class

keras_rs.layers.Retrieval(k: int = 10, return_scores: bool = True, **kwargs: Any)

Abstract base class for retrieval layers.

This layer defines the common interface shared by all retrieval layers. To implement a custom retrieval layer, subclass this class.

Arguments

  • k: int. Number of top-scoring candidates to retrieve for each query.
  • return_scores: bool. When True, this layer returns a tuple with the top scores and the top identifiers. When False, this layer returns a single tensor with the top identifiers.
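
The subclassing contract can be illustrated with a plain-Python analogue. This is a hypothetical sketch, not the KerasRS source: ToyRetrieval and ToyBruteForceRetrieval are illustrative names, embeddings are plain lists rather than tensors, call handles a single query rather than a batch, and treating update_candidates and call as the methods a subclass must supply is an assumption based on the methods documented on this page.

```python
import abc


class ToyRetrieval(abc.ABC):
    """Hypothetical plain-Python analogue of the Retrieval interface (not KerasRS code)."""

    def __init__(self, k=10, return_scores=True):
        self.k = k
        self.return_scores = return_scores

    @abc.abstractmethod
    def update_candidates(self, candidate_embeddings, candidate_ids=None):
        """Subclasses store or index the candidates."""

    @abc.abstractmethod
    def call(self, inputs):
        """Subclasses return the top candidates for the query `inputs`."""

    def compute_score(self, query_embedding, candidate_embedding):
        # Dot product of one query vector with one candidate vector.
        return sum(q * c for q, c in zip(query_embedding, candidate_embedding))


class ToyBruteForceRetrieval(ToyRetrieval):
    """Hypothetical subclass that scores every candidate exhaustively."""

    def update_candidates(self, candidate_embeddings, candidate_ids=None):
        self.embeddings = list(candidate_embeddings)
        # Fall back to positional indices when no IDs are given.
        if candidate_ids is None:
            self.ids = list(range(len(self.embeddings)))
        else:
            self.ids = list(candidate_ids)

    def call(self, inputs):
        # Score the query against all candidates and keep the best k.
        scored = sorted(
            ((self.compute_score(inputs, e), i) for e, i in zip(self.embeddings, self.ids)),
            key=lambda pair: pair[0],
            reverse=True,
        )[: self.k]
        top_scores = [s for s, _ in scored]
        top_ids = [i for _, i in scored]
        # return_scores selects between a (scores, ids) tuple and the ids alone.
        return (top_scores, top_ids) if self.return_scores else top_ids


layer = ToyBruteForceRetrieval(k=2)
layer.update_candidates([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], candidate_ids=["a", "b", "c"])
print(layer.call([1.0, 0.5]))  # ([1.5, 1.0], ['c', 'a'])
```

Instantiating ToyRetrieval directly raises TypeError because its abstract methods are unimplemented, which is how the sketch enforces subclassing.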

call method

Retrieval.call(inputs: Any)

Returns the top candidates for the query passed as input.

Arguments

  • inputs: the query for which to return top candidates.

Returns

A tuple with the top scores and the top identifiers if return_scores is True, otherwise a tensor with the top identifiers.
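
To make the two return formats concrete, the following plain-Python sketch selects the top k entries from each row of a precomputed score matrix, one row per query. The function name select_top_k is hypothetical, and nested lists stand in for the tensors the layer actually returns.

```python
from typing import List, Sequence, Tuple, Union


def select_top_k(
    score_rows: Sequence[Sequence[float]], k: int, return_scores: bool = True
) -> Union[Tuple[List[List[float]], List[List[int]]], List[List[int]]]:
    """Hypothetical helper: top-k scores and candidate indices for each query row."""
    top_scores, top_ids = [], []
    for row in score_rows:
        # Rank candidate indices by score, highest first, and keep k of them.
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        top_ids.append(ranked)
        top_scores.append([row[j] for j in ranked])
    # Mirror return_scores: a (scores, ids) tuple, or the ids alone.
    return (top_scores, top_ids) if return_scores else top_ids


scores = [[0.1, 0.9, 0.4], [0.7, 0.2, 0.8]]  # 2 queries x 3 candidates
print(select_top_k(scores, k=2))                       # ([[0.9, 0.4], [0.8, 0.7]], [[1, 2], [2, 0]])
print(select_top_k(scores, k=2, return_scores=False))  # [[1, 2], [2, 0]]
```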

update_candidates method

Retrieval.update_candidates(
    candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)

Updates the set of candidates and, optionally, their identifiers.

Arguments

  • candidate_embeddings: The candidate embeddings.
  • candidate_ids: The identifiers of the candidates. If None, retrieval returns each candidate's positional index in candidate_embeddings as its identifier.
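
The candidate_ids fallback can be illustrated with a small plain-Python helper that converts the positional indices produced by a top-k search into identifiers. resolve_ids is a hypothetical name, and the lookup-by-position behavior is an assumption consistent with the description above rather than the KerasRS implementation.

```python
from typing import Any, List, Optional, Sequence


def resolve_ids(
    top_indices: Sequence[Sequence[int]], candidate_ids: Optional[Sequence[Any]] = None
) -> List[List[Any]]:
    """Hypothetical helper: map per-query top-k indices to candidate identifiers."""
    if candidate_ids is None:
        # No IDs supplied: the positional indices themselves serve as identifiers.
        return [list(row) for row in top_indices]
    # Otherwise look up each index in the user-supplied ID list.
    return [[candidate_ids[j] for j in row] for row in top_indices]


top_indices = [[1, 2], [2, 0]]
print(resolve_ids(top_indices))                                  # [[1, 2], [2, 0]]
print(resolve_ids(top_indices, ["item_a", "item_b", "item_c"]))  # [['item_b', 'item_c'], ['item_c', 'item_a']]
```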

compute_score method

Retrieval.compute_score(query_embedding: Any, candidate_embedding: Any)

Computes the dot-product scores between queries and candidates.

Arguments

  • query_embedding: Tensor of query embeddings for the queries whose top candidates are being retrieved.
  • candidate_embedding: Tensor of candidate embeddings.

Returns

The dot products of the queries with the candidates, one score per query-candidate pair.
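
For a batch of queries q of shape (B, D) and candidates c of shape (N, D), the score matrix has shape (B, N), and each entry s[b][n] is the sum over d of q[b][d] * c[n][d]; in other words, the queries multiplied by the transposed candidates. The plain-Python sketch below computes that matrix. dot_scores is a hypothetical name, and the sketch assumes both inputs are two-dimensional with the embedding dimension last. With a Keras backend, the same result would typically come from keras.ops.matmul(query_embedding, keras.ops.transpose(candidate_embedding)).

```python
from typing import List, Sequence


def dot_scores(
    queries: Sequence[Sequence[float]], candidates: Sequence[Sequence[float]]
) -> List[List[float]]:
    """Hypothetical helper: (B x D) queries times transposed (N x D) candidates -> (B x N)."""
    if not queries:
        return []
    dim = len(queries[0])
    # zip() would silently truncate mismatched vectors, so check dimensions first.
    if any(len(v) != dim for v in list(queries) + list(candidates)):
        raise ValueError("all embeddings must have the same dimension")
    # One row per query, one column per candidate.
    return [[sum(qd * cd for qd, cd in zip(q, c)) for c in candidates] for q in queries]


queries = [[1.0, 2.0], [0.0, 1.0]]
candidates = [[1.0, 0.0], [0.5, 0.5], [2.0, 1.0]]
print(dot_scores(queries, candidates))  # [[1.0, 1.5, 4.0], [0.0, 0.5, 1.0]]
```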