beam_search function


keras_nlp.utils.beam_search(
    token_probability_fn,
    prompt,
    max_length,
    num_beams,
    from_logits=False,
    end_token_id=None,
    pad_token_id=0,
)

Text generation utility based on the beam search algorithm.

At each time step, beam search keeps the num_beams sequences (beams) with the highest accumulated probabilities, and uses each beam to predict candidate next tokens.
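The scoring step described above can be sketched in plain Python. This is an illustrative toy (the helper name and data layout are invented for this sketch, not part of KerasNLP): each beam carries an accumulated log-probability, every beam is extended by every vocabulary token, and only the top num_beams extensions survive.

```python
import math

def beam_search_step(beams, next_token_log_probs, num_beams):
    """One expansion step of beam search (illustrative sketch).

    beams: list of (token_list, accumulated_log_prob) pairs.
    next_token_log_probs: per beam, a list of log-probabilities over
        the vocabulary for the next token.
    Returns the num_beams highest-scoring extended beams.
    """
    candidates = []
    for (tokens, score), log_probs in zip(beams, next_token_log_probs):
        for token_id, lp in enumerate(log_probs):
            # Accumulated score is the sum of per-token log-probs.
            candidates.append((tokens + [token_id], score + lp))
    # Keep only the num_beams best candidates.
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[:num_beams]

# Toy vocabulary of 3 tokens; a single initial beam holding token 0.
beams = [([0], 0.0)]
log_probs = [[math.log(0.2), math.log(0.7), math.log(0.1)]]
top = beam_search_step(beams, log_probs, num_beams=2)
# Best beam extends with token 1 (prob 0.7), runner-up with token 0 (0.2).
```

Repeating this step until max_length (or an end token) is reached yields the full search; the real utility performs the same top-k selection in batched TensorFlow ops.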

Arguments

  • token_probability_fn: a callable which takes in an input sequence and outputs the probability distribution of the next token. If from_logits is set to True, it should output the logits of the next token instead. The input shape is [batch_size, length] and the output shape should be [batch_size, vocab_size], where batch_size is variable.
  • prompt: a list or a Tensor, 1D or 2D, containing the initial tokens to which generated tokens are appended. This serves as the initial beam for beam search.
  • max_length: int. The max length of generated text.
  • num_beams: int. The number of beams that should be kept at each time-step. num_beams should be strictly positive.
  • from_logits: bool. Indicates whether token_probability_fn outputs logits or probabilities.
  • end_token_id: int, defaults to None. The token marking the end of a sequence; once it is encountered, generation is finished for that sequence. If None, every sequence is generated up to max_length. If set, all tokens after the first end_token_id will be replaced with pad_token_id.
  • pad_token_id: int, defaults to 0. The token used to pad a sequence after end_token_id is encountered.
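The interaction between end_token_id and pad_token_id can be sketched with a small pure-Python helper (a hypothetical function written for illustration, not part of the KerasNLP API): the end token itself is kept, and everything after it is overwritten with the pad token.

```python
def pad_after_end_token(sequence, end_token_id, pad_token_id=0):
    """Replace every token after the first end_token_id with pad_token_id.

    Mirrors the documented behavior: generation stops at end_token_id
    and the remainder of the sequence is padded.
    """
    if end_token_id not in sequence:
        # No end token: the sequence is left untouched (full max_length).
        return list(sequence)
    cut = sequence.index(end_token_id) + 1  # keep the end token itself
    return list(sequence[:cut]) + [pad_token_id] * (len(sequence) - cut)

pad_after_end_token([5, 3, 2, 7, 9], end_token_id=2)
# -> [5, 3, 2, 0, 0]
```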

Returns

A 1D or 2D int Tensor representing the generated sequences.

Examples

import tensorflow as tf

import keras_nlp

BATCH_SIZE = 8
VOCAB_SIZE = 10
FEATURE_SIZE = 16
START_ID = 1
END_ID = 2

# Create a dummy model to predict the next token.
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=[None]),
        tf.keras.layers.Embedding(
            input_dim=VOCAB_SIZE,
            output_dim=FEATURE_SIZE,
        ),
        tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
    ]
)

# Define a function that outputs the next token's probability given the
# input sequence.
def token_probability_fn(inputs):
    return model(inputs)[:, -1, :]

prompt = tf.fill((BATCH_SIZE, 1), START_ID)

# Generate and return the sequences (token ids).
keras_nlp.utils.beam_search(
    token_probability_fn,
    prompt,
    max_length=10,
    num_beams=5,
    end_token_id=END_ID,
)