
GreedySampler


GreedySampler class

keras_nlp.samplers.GreedySampler(**kwargs)

Greedy sampler class.

This sampler implements greedy search: at each decoding step, it always picks the token with the largest probability as the next token.
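To make the per-step rule concrete, here is a minimal, framework-free sketch of greedy decoding. It is an illustration only, not the KerasNLP implementation. The helpers `greedy_next_token` and `greedy_decode`, and the toy next-token table `TRANSITIONS`, are hypothetical names chosen for this example. In a real causal LM, the model would supply the next-token probabilities (or logits; the argmax is the same either way because softmax is monotonic).

```python
from typing import Callable, List, Optional, Sequence


def greedy_next_token(probs: Sequence[float]) -> int:
    """Return the index of the highest-probability token.

    Python's max() returns the first maximal element, so ties resolve
    to the lowest token index.
    """
    return max(range(len(probs)), key=lambda i: probs[i])


def greedy_decode(
    next_probs: Callable[[List[int]], Sequence[float]],
    prompt: List[int],
    max_new_tokens: int,
    end_token_id: Optional[int] = None,
) -> List[int]:
    """Extend `prompt` one token at a time, always taking the argmax."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        next_id = greedy_next_token(next_probs(tokens))
        tokens.append(next_id)
        # Stop early once the (optional) end token has been generated.
        if end_token_id is not None and next_id == end_token_id:
            break
    return tokens


# Hypothetical toy "model": next-token probabilities over a 4-token
# vocabulary that depend only on the last token generated so far.
TRANSITIONS = {
    0: [0.1, 0.6, 0.2, 0.1],
    1: [0.05, 0.05, 0.7, 0.2],
    2: [0.1, 0.1, 0.1, 0.7],
    3: [0.25, 0.25, 0.25, 0.25],
}


def toy_next_probs(tokens: List[int]) -> Sequence[float]:
    return TRANSITIONS[tokens[-1]]


print(greedy_decode(toy_next_probs, prompt=[0], max_new_tokens=5, end_token_id=3))
# prints [0, 1, 2, 3]
```

Because no randomness is involved, the same prompt and model always produce the same continuation under this rule.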

Examples

import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="greedy")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_nlp.samplers.GreedySampler()
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])