greedy_search function

keras_nlp.utils.greedy_search(
    token_probability_fn, prompt, max_length, end_token_id=None, pad_token_id=0
)

Text generation utility based on greedy search.

Greedy search always appends the token with the highest probability to the existing sequence.
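Conceptually, the decoding loop can be sketched as follows. This is a simplified illustration, not the library's implementation; next_token_probs is a stand-in for the token_probability_fn argument described below, and end-token handling is omitted.

import tensorflow as tf

def greedy_decode_sketch(next_token_probs, prompt, max_length):
    # prompt: 2D int Tensor of shape (batch_size, current_length).
    sequence = prompt
    while sequence.shape[1] < max_length:
        # Distribution (probabilities or logits) over the vocabulary for the
        # next token of each sequence: shape (batch_size, vocab_size).
        probs = next_token_probs(sequence)
        # Greedy step: pick the single most likely token id and append it.
        next_token = tf.cast(tf.argmax(probs, axis=-1), sequence.dtype)
        sequence = tf.concat([sequence, next_token[:, tf.newaxis]], axis=1)
    return sequence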

Arguments

  • token_probability_fn: a callable that takes the current input sequence and outputs the probability distribution (or the logits) of the next token.
  • prompt: a list or a Tensor, 1D or 2D, containing the initial tokens that generated tokens are appended to.
  • max_length: int. The maximum length of the generated sequence.
  • end_token_id: int, defaults to None. The token marking the end of a sequence; once it is encountered, generation stops for that sequence. If None, every sequence is generated up to max_length. If set, all tokens after end_token_id is encountered are replaced with pad_token_id.
  • pad_token_id: int, defaults to 0. The token used to pad positions after end_token_id is encountered.

Returns

A 1D int Tensor or a 2D int RaggedTensor containing the generated sequences.

Examples

import keras_nlp
import tensorflow as tf

BATCH_SIZE = 8
VOCAB_SIZE = 10
FEATURE_SIZE = 16
START_ID = 1
END_ID = 2

# Create a dummy model to predict the next token.
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=[None]),
        tf.keras.layers.Embedding(
            input_dim=VOCAB_SIZE,
            output_dim=FEATURE_SIZE,
        ),
        tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
    ]
)

# Define a function that outputs the next token's probability given the
# input sequence.
def token_probability_fn(inputs):
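    # The model output has shape (batch_size, sequence_length, vocab_size);
    # keep only the distribution for the last position of each sequence.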
    return model(inputs)[:, -1, :]

prompt = tf.fill((BATCH_SIZE, 1), START_ID)
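# Every sequence in the batch starts from the single START_ID token,
# so `prompt` has shape (BATCH_SIZE, 1).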

# Print the generated sequences (token ids).
print(
    keras_nlp.utils.greedy_search(
        token_probability_fn,
        prompt,
        max_length=10,
        end_token_id=END_ID,
    )
)
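
As described under end_token_id above, leaving that argument as None makes every sequence run to the full max_length. For comparison, the same call without early stopping (this usage sketch reuses the model, token_probability_fn, and prompt defined above):

print(
    keras_nlp.utils.greedy_search(
        token_probability_fn,
        prompt,
        max_length=10,
    )
)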