Sampler base class


Sampler class


Base sampler class.


Arguments

  • temperature: float, optional. Controls the randomness of sampling. A higher temperature produces more diverse samples. Defaults to 1.0.
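To make the effect of temperature concrete, here is a minimal sketch in plain Python. It assumes the common convention of dividing logits by the temperature before applying softmax; the source does not show how the library applies temperature internally, so treat this as an illustration rather than the library's implementation. The function name `softmax_with_temperature` and the example logits are hypothetical.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by `temperature`.

    Assumed convention (not taken from the library source): logits are
    scaled by 1 / temperature before the softmax.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    # Subtract the max for numerical stability before exponentiating.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # hypothetical scores for a 3-token vocabulary
sharp = softmax_with_temperature(logits, temperature=0.5)
default = softmax_with_temperature(logits, temperature=1.0)
flat = softmax_with_temperature(logits, temperature=2.0)
```

Under this convention, a temperature below 1.0 concentrates probability on the top-scoring token, while a temperature above 1.0 spreads it more evenly, which is why sampled outputs become more diverse.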


This base class can be extended to implement different auto-regressive sampling methods. To do so, override the get_next_token() method, which computes the next token based on a probability distribution over all possible vocab entries.


import keras
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Greedy search with some tokens forbidden.
class CustomSampler(keras_nlp.samplers.Sampler):
    def __init__(self, forbidden_tokens, **kwargs):
        # Let the base class handle shared options such as `temperature`.
        super().__init__(**kwargs)
        self.forbidden_tokens = forbidden_tokens

    def get_next_token(self, probabilities):
        batch_size, vocab_size = keras.ops.shape(probabilities)
        for token_id in self.forbidden_tokens:
            # Zero out the forbidden token's column for every row in the batch.
            update = keras.ops.zeros((batch_size, 1))
            probabilities = keras.ops.slice_update(
                probabilities, (0, token_id), update
            )
        # Greedy choice among the remaining tokens.
        return keras.ops.argmax(probabilities, axis=-1)

# 257 = "a" with a leading space, 262 = "the" with a leading space.
causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262]))
causal_lm.generate(["That's strange"])
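The term "auto-regressive" means that each chosen token is appended to the sequence and fed back in to produce the next distribution. The sketch below illustrates that loop in plain Python. It is not the library's generation code: `toy_next_token_probs` is a hypothetical stand-in for a language model, and `generate` and `greedy_next_token` are illustrative names only.

```python
def toy_next_token_probs(tokens, vocab_size=4):
    """Hypothetical stand-in for a model: favors (last token + 1) % vocab_size."""
    favored = (tokens[-1] + 1) % vocab_size
    return [0.7 if i == favored else 0.1 for i in range(vocab_size)]

def greedy_next_token(probs):
    """Pick the index with the highest probability."""
    return max(range(len(probs)), key=lambda i: probs[i])

def generate(prompt, max_new_tokens, next_token_fn):
    """Auto-regressive loop: each chosen token is appended and fed back in."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        probs = toy_next_token_probs(tokens)
        tokens.append(next_token_fn(probs))
    return tokens

result = generate([0], max_new_tokens=5, next_token_fn=greedy_next_token)
```

In this picture, `next_token_fn` plays the role that `get_next_token()` plays in a sampler subclass: swapping it out changes the sampling strategy without touching the loop.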


get_next_token method


Get the next token based on the given probability distribution over tokens. Subclasses must implement this method.

Arguments

  • probabilities: a Tensor, the probability distribution for the next token over all vocab tokens.
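As a rough illustration of this contract, the sketch below uses nested Python lists in place of tensors and shows two possible selection rules: a greedy choice and a random draw weighted by the probabilities. The batched `(batch_size, vocab_size)` layout is inferred from the example earlier on this page, which unpacks both dimensions from the input; the function names and sample values are hypothetical, and a real subclass would use tensor operations such as `keras.ops` instead.

```python
import random

def greedy_over_batch(probabilities):
    """Return the most likely token index for each row in the batch."""
    return [max(range(len(row)), key=lambda i: row[i]) for row in probabilities]

def sample_over_batch(probabilities, rng):
    """Draw one token index per row, weighted by that row's probabilities."""
    return [
        rng.choices(range(len(row)), weights=row, k=1)[0]
        for row in probabilities
    ]

# Hypothetical batch of two distributions over a 3-token vocabulary.
batch = [
    [0.1, 0.6, 0.3],
    [0.8, 0.1, 0.1],
]

greedy_ids = greedy_over_batch(batch)
sampled_ids = sample_over_batch(batch, random.Random(0))
```

Both rules map a batch of distributions to one token index per row. The greedy rule is deterministic, while the random rule can return different tokens for the same input depending on the random state.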