TopKSampler

TopKSampler class

keras_nlp.samplers.TopKSampler(k=5, seed=None, **kwargs)

Top-K Sampler class.

This sampler implements the top-k search algorithm. At each generation step, top-k sampling keeps only the k tokens with the highest probability and randomly selects the next token from among them, with each token's chance of selection proportional to its probability.
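The selection step described above can be sketched in plain Python. This is a minimal illustration of the idea, not the KerasNLP implementation: the helper `top_k_sample`, its `rng` parameter, and the example logits are all hypothetical and exist only for this sketch.

```python
import math
import random


def top_k_sample(logits, k=5, temperature=1.0, rng=None):
    """Sample one token id from the k highest-scoring entries of `logits`.

    Hypothetical helper for illustration; not part of keras_nlp.
    """
    if rng is None:
        rng = random.Random()
    k = min(k, len(logits))
    # Indices of the k largest logits, highest first.
    top_ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax numerators over the surviving logits only
    # (shifted by the max for numerical stability).
    scaled = [logits[i] / temperature for i in top_ids]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    # Draw one id with chance proportional to its renormalized probability.
    return rng.choices(top_ids, weights=weights, k=1)[0]


logits = [2.0, 0.5, 1.5, -1.0, 0.1]  # made-up scores for a 5-token vocabulary
rng = random.Random(0)
samples = [top_k_sample(logits, k=2, rng=rng) for _ in range(1000)]
print(sorted(set(samples)))  # only ids 0 and 2 can appear with k=2
```

With `k=1` the sketch reduces to greedy selection, while a `k` equal to the vocabulary size samples from the full distribution.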

Arguments

  • k: int. The number of highest-probability tokens to sample from. Defaults to 5.
  • seed: int. The random seed. Defaults to None.

Call arguments

{{call_args}}

Examples

import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="top_k")
causal_lm.generate(["Keras is a"])

# Pass by object to compile. `temperature` is not a named argument in the
# signature above; it is accepted through **kwargs.
sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])