GPT text generation with KerasNLP

Author: Jesse Chan
Date created: 2022/07/25
Last modified: 2022/07/25
Description: Using KerasNLP to train a mini-GPT model for text generation.


In this example, we will use KerasNLP to build a scaled-down Generative Pre-trained Transformer (GPT) model. GPT is a Transformer-based model that allows you to generate sophisticated text from a prompt.

We will train the model on the simplebooks-92 corpus, which is a dataset made from several novels. It is a good dataset for this example since it has a small vocabulary and high word frequency, which is beneficial when training a model with few parameters.

This example combines concepts from the example "Text generation with a miniature GPT" with KerasNLP abstractions. We will demonstrate how KerasNLP tokenization, layers and metrics simplify the training process, and then show how to generate output text using the KerasNLP sampling utilities.

Note: If you are running this example on Colab, make sure to enable the GPU runtime for faster training.

This example requires KerasNLP. You can install it via the following command: pip install keras-nlp


import os
import keras_nlp
import tensorflow as tf
from tensorflow import keras

Settings & hyperparameters

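# Data
BATCH_SIZE = 64  # Assumed value.
SEQ_LEN = 128
MIN_TRAINING_SEQ_LEN = 450  # Assumed value; minimum line length in bytes.

# Model
EMBED_DIM = 256
FEED_FORWARD_DIM = 256
NUM_HEADS = 3
NUM_LAYERS = 2
VOCAB_SIZE = 5000  # Limits parameters in model.

# Training
EPOCHS = 6

# Inference
NUM_TOKENS_TO_GENERATE = 80  # Assumed value.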

Load the data

Now, let's download the dataset! The SimpleBooks dataset consists of 1,573 Gutenberg books, and has one of the smallest ratios of vocabulary size to word-level token count. It has a vocabulary size of ~98k, a third of WikiText-103's, with around the same number of tokens (~100M). This makes it easy to fit a small model.

dir = os.path.expanduser("~/.keras/datasets/simplebooks/")

# Load simplebooks-92 train set and filter out short lines.
raw_train_ds = (
    tf.data.TextLineDataset(dir + "simplebooks-92-raw/train.txt")
    .filter(lambda x: tf.strings.length(x) > MIN_TRAINING_SEQ_LEN)
    .batch(BATCH_SIZE)
)

# Load simplebooks-92 validation set and filter out short lines.
raw_val_ds = (
    tf.data.TextLineDataset(dir + "simplebooks-92-raw/valid.txt")
    .filter(lambda x: tf.strings.length(x) > MIN_TRAINING_SEQ_LEN)
    .batch(BATCH_SIZE)
)

Train the tokenizer

We train the tokenizer from the training dataset for a vocabulary size of VOCAB_SIZE, which is a tuned hyperparameter. We want to limit the vocabulary as much as possible, since, as we will see later on, it has a large effect on the number of model parameters. We also don't want to include too few vocabulary terms, or there would be too many out-of-vocabulary (OOV) sub-words. In addition, three tokens are reserved in the vocabulary:

  • "[PAD]" for padding sequences to SEQ_LEN. This token has index 0 in both reserved_tokens and vocab, since WordPieceTokenizer (and other layers) consider 0/vocab[0] as the default padding.
  • "[UNK]" for OOV sub-words, which should match the default oov_token="[UNK]" in WordPieceTokenizer.
  • "[BOS]" stands for beginning of sentence, but here technically it is a token representing the beginning of each line of training data.
# Train tokenizer vocabulary
vocab = keras_nlp.tokenizers.compute_word_piece_vocabulary(
    raw_train_ds, vocabulary_size=VOCAB_SIZE, lowercase=True,
    reserved_tokens=["[PAD]", "[UNK]", "[BOS]"],
)

Load tokenizer

We use the vocabulary data to initialize keras_nlp.tokenizers.WordPieceTokenizer. WordPieceTokenizer is an efficient implementation of the WordPiece algorithm used by BERT and other models. It will strip, lower-case and do other irreversible preprocessing operations.

tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=vocab, sequence_length=SEQ_LEN, lowercase=True
)
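
To make the sub-word idea concrete, here is a minimal pure-Python sketch of the greedy longest-match-first lookup that WordPiece performs on each word. It is an illustration only, not the KerasNLP implementation: the toy vocabulary and the helper name wordpiece_tokenize are hypothetical, and the "##" prefix for word-internal pieces follows the convention used by BERT.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Split one lower-cased word into sub-words by greedy longest match."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it.
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # marks a word-internal piece
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No piece fits, so the whole word maps to the OOV token.
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary; reserved tokens come first, as in the example.
toy_vocab = ["[PAD]", "[UNK]", "[BOS]", "sail", "##or", "##s", "the"]
toy_vocab_set = set(toy_vocab)

sailors = wordpiece_tokenize("sailors", toy_vocab_set)
captain = wordpiece_tokenize("captain", toy_vocab_set)
print(sailors)  # ['sail', '##or', '##s']
print(captain)  # ['[UNK]']
```

A word that cannot be covered by vocabulary pieces collapses to "[UNK]", which is why a vocabulary that is too small hurts: more of the training text becomes indistinguishable OOV tokens.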

Tokenize data

We preprocess the dataset by tokenizing and splitting it into features and labels.

# packer adds a start token
start_packer = keras_nlp.layers.StartEndPacker(
    sequence_length=SEQ_LEN,
    start_value=tokenizer.token_to_id("[BOS]"),
)

def preprocess(inputs):
    outputs = tokenizer(inputs)
    features = start_packer(outputs)
    labels = outputs
    return features, labels

# Tokenize and split into train and label sequences.
train_ds = raw_train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).prefetch(
    tf.data.AUTOTUNE
)
val_ds = raw_val_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).prefetch(
    tf.data.AUTOTUNE
)
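
The effect of preprocess is that the features are the labels shifted right by one position, so at each position the model sees the preceding tokens and is trained to predict the current one. A minimal pure-Python sketch of that shift follows. The token ids are hypothetical (2 stands in for "[BOS]" and 0 for "[PAD]"), and the helper pack_with_start is an illustrative stand-in that mimics only the prepend-and-truncate behaviour relied on here.

```python
def pack_with_start(token_ids, start_id, sequence_length, pad_id=0):
    """Prepend start_id, then truncate or pad to sequence_length."""
    packed = [start_id] + list(token_ids)
    packed = packed[:sequence_length]
    return packed + [pad_id] * (sequence_length - len(packed))


BOS_ID = 2  # hypothetical id for "[BOS]"
labels = [11, 12, 13, 14, 0, 0]  # hypothetical tokenized line, padded with 0
features = pack_with_start(labels, BOS_ID, sequence_length=len(labels))

# Each input position is paired with the token the model should predict there.
for feature, label in zip(features, labels):
    print(f"input {feature:>2} -> target {label:>2}")
```

The first input is the start token and its target is the first real token, which is what later allows generation to begin from a prompt containing only "[BOS]".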

Build the model

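We create our scaled-down GPT model with the following layers:

  • One keras_nlp.layers.TokenAndPositionEmbedding layer, which combines the token embedding with a position embedding.
  • NUM_LAYERS keras_nlp.layers.TransformerDecoder layers, with the default causal masking. Called with a decoder sequence only, the layer skips cross-attention.
  • One final dense linear layer that projects to VOCAB_SIZE logits.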

inputs = keras.layers.Input(shape=(None,), dtype=tf.int32)
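# Embedding.
embedding_layer = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=SEQ_LEN,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)
x = embedding_layer(inputs)
# Transformer decoders.
for _ in range(NUM_LAYERS):
    decoder_layer = keras_nlp.layers.TransformerDecoder(
        num_heads=NUM_HEADS,
        intermediate_dim=FEED_FORWARD_DIM,
    )
    x = decoder_layer(x)  # Giving one argument only skips cross-attention.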
# Output.
outputs = keras.layers.Dense(VOCAB_SIZE)(x)
model = keras.Model(inputs=inputs, outputs=outputs)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
perplexity = keras_nlp.metrics.Perplexity(from_logits=True, mask_token_id=0)
model.compile(optimizer="adam", loss=loss_fn, metrics=[perplexity])

Let's take a look at our model summary. A large majority of the parameters are in the token_and_position_embedding and the output dense layer! This means that the vocabulary size (VOCAB_SIZE) has a large effect on the size of the model, while the number of Transformer decoder layers (NUM_LAYERS) doesn't affect it as much.

Model: "model"
 Layer (type)                                               Output Shape          Param #
 input_1 (InputLayer)                                       [(None, None)]        0
 token_and_position_embedding (TokenAndPositionEmbedding)   (None, None, 256)     1312768
 transformer_decoder (TransformerDecoder)                   (None, None, 256)     394749
 transformer_decoder_1 (TransformerDecoder)                 (None, None, 256)     394749
 dense (Dense)                                              (None, None, 5000)    1285000

Total params: 3,387,266
Trainable params: 3,387,266
Non-trainable params: 0
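
The figures in the summary can be reproduced by hand, which makes the role of VOCAB_SIZE explicit. The sketch below assumes an embedding width of 256 (the last output dimension in the summary) together with the SEQ_LEN and VOCAB_SIZE settings; the per-decoder count is copied from the summary rather than derived, and the variable names are illustrative.

```python
VOCAB_SIZE = 5000
SEQ_LEN = 128
EMBED_DIM = 256  # assumed from the (None, None, 256) output shapes

# Token embedding table plus position embedding table.
embedding_params = VOCAB_SIZE * EMBED_DIM + SEQ_LEN * EMBED_DIM
# Dense output layer: one weight per (feature, token) pair plus one bias per token.
dense_params = EMBED_DIM * VOCAB_SIZE + VOCAB_SIZE
# Per-decoder count, copied from the summary rather than derived.
decoder_params = 394_749
num_decoders = 2

total_params = embedding_params + num_decoders * decoder_params + dense_params
embedding_and_dense_share = (embedding_params + dense_params) / total_params

print(embedding_params)  # 1312768
print(dense_params)      # 1285000
print(total_params)      # 3387266
print(f"{embedding_and_dense_share:.1%} of parameters sit in the embedding and dense layers")
```

Both large terms scale linearly with VOCAB_SIZE, so halving the vocabulary would remove roughly 1.3 million parameters, whereas removing one decoder layer removes fewer than 0.4 million.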


Now that we have our model, let's train it with the fit() method.

model.fit(train_ds, validation_data=val_ds, verbose=2, epochs=EPOCHS)
Epoch 1/6
3169/3169 - 220s - loss: 4.5285 - perplexity: 98.5510 - val_loss: 4.0127 - val_perplexity: 63.4425 - 220s/epoch - 69ms/step
Epoch 2/6
3169/3169 - 219s - loss: 4.0143 - perplexity: 58.5463 - val_loss: 3.8603 - val_perplexity: 54.1763 - 219s/epoch - 69ms/step
Epoch 3/6
3169/3169 - 220s - loss: 3.8997 - perplexity: 52.1239 - val_loss: 3.8035 - val_perplexity: 51.1345 - 220s/epoch - 69ms/step
Epoch 4/6
3169/3169 - 219s - loss: 3.8381 - perplexity: 48.9710 - val_loss: 3.7728 - val_perplexity: 49.3502 - 219s/epoch - 69ms/step
Epoch 5/6
3169/3169 - 220s - loss: 3.7946 - perplexity: 46.8604 - val_loss: 3.7239 - val_perplexity: 46.9923 - 220s/epoch - 69ms/step
Epoch 6/6
3169/3169 - 219s - loss: 3.7634 - perplexity: 45.3980 - val_loss: 3.7166 - val_perplexity: 46.7066 - 219s/epoch - 69ms/step

<keras.callbacks.History at 0x7f74b1d543d0>


With our trained model, we can test it out to gauge its performance. Since this model is built with a "[BOS]" token, we can have an empty starting prompt for text generation.

# Unpadded bos token.
prompt_tokens = tf.convert_to_tensor([tokenizer.token_to_id("[BOS]")])

We will use the keras_nlp.utils module for inference. Every text generation utility requires a token_logits_fn() wrapper around the model. This wrapper takes in an unpadded token sequence and returns the logits of the next token.

def token_logits_fn(inputs):
    cur_len = inputs.shape[1]
    output = model(inputs)
    return output[:, cur_len - 1, :]  # return next token logits

Creating the wrapper function is the most complex part of using these functions. Now that it's done, let's test out the different utilities, starting with greedy search.

We greedily pick the most probable token at each timestep. In other words, we get the argmax of the model output.
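
The argmax rule can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the KerasNLP implementation: toy_logits_fn is a hypothetical stand-in for the model that looks only at the last token, and the four-token vocabulary is made up.

```python
def toy_logits_fn(tokens):
    """Hypothetical stand-in for the model: next-token logits from the last token."""
    table = {
        0: [0.1, 2.0, 0.5, 0.3],
        1: [0.2, 0.1, 1.5, 1.4],
        2: [0.3, 2.5, 0.2, 0.1],
        3: [3.0, 0.1, 0.2, 0.3],
    }
    return table[tokens[-1]]


def greedy_search(logits_fn, prompt, max_length):
    """Extend prompt until it holds max_length tokens, taking the argmax each step."""
    tokens = list(prompt)
    while len(tokens) < max_length:
        logits = logits_fn(tokens)
        next_token = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(next_token)
    return tokens


generated = greedy_search(toy_logits_fn, prompt=[0], max_length=8)
print(generated)  # [0, 1, 2, 1, 2, 1, 2, 1]
```

Because the choice depends only on the logits, the same context always yields the same continuation, and the toy output settles into a 1, 2, 1, 2 cycle.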

output_tokens = keras_nlp.utils.greedy_search(
    token_logits_fn, prompt_tokens, max_length=NUM_TOKENS_TO_GENERATE
)
txt = tokenizer.detokenize(output_tokens)
print(f"Greedy search generated text: \n{txt}\n")
Greedy search generated text: 
b'[BOS] " i have no doubt that , " the captain said , " and i have no doubt that the captain of the united states will be a very different from the english . the captain has been a very good sailor , and he has been a sailor , and he has been a sailor , and he has been a sailor , and he has been a sailor , and he has been a sailor , and'

As you can see, greedy search starts out making some sense, but quickly starts repeating itself. This is a common problem with text generation that can be fixed by some of the probabilistic text generation utilities shown later on!

At a high level, beam search keeps track of the num_beams most probable sequences at each timestep, and predicts the best next token from all sequences. It is an improvement over greedy search since it stores more possibilities. However, it is less efficient than greedy search since it has to compute and store multiple potential sequences.

Note: beam search with num_beams=1 is identical to greedy search.
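
A minimal pure-Python sketch of the idea follows, scoring each candidate by its cumulative log-probability. It is an illustration only and may differ in detail from the KerasNLP utility; toy_logits_fn is a hypothetical stand-in for the model over a made-up four-token vocabulary.

```python
import math


def log_softmax(logits):
    """Convert logits to log-probabilities in a numerically stable way."""
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_total for x in logits]


def toy_logits_fn(tokens):
    """Hypothetical stand-in for the model: next-token logits from the last token."""
    table = {
        0: [0.1, 2.0, 0.5, 0.3],
        1: [0.2, 0.1, 1.5, 1.4],
        2: [0.3, 2.5, 0.2, 0.1],
        3: [3.0, 0.1, 0.2, 0.3],
    }
    return table[tokens[-1]]


def beam_search(logits_fn, prompt, max_length, num_beams):
    """Keep the num_beams highest-scoring sequences at every step."""
    beams = [(0.0, list(prompt))]  # (cumulative log-probability, tokens)
    while len(beams[0][1]) < max_length:
        candidates = []
        for score, tokens in beams:
            log_probs = log_softmax(logits_fn(tokens))
            for token_id, log_prob in enumerate(log_probs):
                candidates.append((score + log_prob, tokens + [token_id]))
        # Sort by score, best first, and prune to the beam width.
        candidates.sort(key=lambda item: item[0], reverse=True)
        beams = candidates[:num_beams]
    return beams[0][1]


best = beam_search(toy_logits_fn, prompt=[0], max_length=8, num_beams=3)
single_beam = beam_search(toy_logits_fn, prompt=[0], max_length=8, num_beams=1)
print(best)
print(single_beam)  # [0, 1, 2, 1, 2, 1, 2, 1], the greedy result
```

With a single beam, pruning keeps only the most probable extension at each step, which is exactly the argmax rule of greedy search.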

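output_tokens = keras_nlp.utils.beam_search(
    token_logits_fn, prompt_tokens, max_length=NUM_TOKENS_TO_GENERATE,
    num_beams=10, from_logits=True,  # num_beams=10 is an assumed beam width.
)
txt = tokenizer.detokenize(output_tokens)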
print(f"Beam search generated text: \n{txt}\n")
Beam search generated text: 
b'[BOS] " i don \' t suppose that , " the captain said , with a smile . " i \' ll tell you what i \' ll have to do . i \' ll tell you what i \' ll do . i \' ll tell you about it . i \' ll tell you what i \' ll do . i \' ll tell you about it . i \' ll tell you about it . i \''

Similar to greedy search, beam search quickly starts repeating itself, since it is still a deterministic method.

Random search is our first probabilistic method. At each time step, it samples the next token using the softmax probabilities provided by the model.
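
In plain Python, sampling from the softmax distribution can be sketched as below. This is an illustration rather than the KerasNLP implementation; toy_logits_fn is a hypothetical stand-in for the model, and the seed argument exists only to make the sketch reproducible.

```python
import math
import random


def softmax(logits):
    """Convert logits to probabilities in a numerically stable way."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_logits_fn(tokens):
    """Hypothetical stand-in for the model: next-token logits from the last token."""
    table = {
        0: [0.1, 2.0, 0.5, 0.3],
        1: [0.2, 0.1, 1.5, 1.4],
        2: [0.3, 2.5, 0.2, 0.1],
        3: [3.0, 0.1, 0.2, 0.3],
    }
    return table[tokens[-1]]


def random_search(logits_fn, prompt, max_length, seed=None):
    """Extend prompt by sampling each next token from the softmax distribution."""
    rng = random.Random(seed)
    tokens = list(prompt)
    while len(tokens) < max_length:
        probs = softmax(logits_fn(tokens))
        next_token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        tokens.append(next_token)
    return tokens


sample_a = random_search(toy_logits_fn, prompt=[0], max_length=8, seed=1)
sample_b = random_search(toy_logits_fn, prompt=[0], max_length=8, seed=2)
print(sample_a)
print(sample_b)
```

Every token with non-zero probability can be drawn, including very unlikely ones, so different seeds can produce different continuations from the same prompt.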

output_tokens = keras_nlp.utils.random_search(
    token_logits_fn, prompt_tokens, max_length=NUM_TOKENS_TO_GENERATE,
    from_logits=True,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Random search generated text: \n{txt}\n")
Random search generated text: 
b'[BOS] he described it to him that he was trying to do all this morning at the time and made him look quite pleased . " i know now that he mentioned his name to our men . i know they have crossed to their homes , and obtained anything like that to hear someone else in being signed by his own as they take the law - house which was still crowded by the search . to my'

Voila, no repetitions! However, with random search, we may see some nonsensical words appearing since any word in the vocabulary has a chance of appearing with this sampling method. This is fixed by our next search utility, top-k search.

Similar to random search, we sample the next token from the probability distribution provided by the model. The only difference is that here, we select the top k most probable tokens and redistribute the probability mass over them before sampling. This way, we won't be sampling from low-probability tokens, and hence we should see fewer nonsensical words!
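
The select-and-renormalize step can be sketched in plain Python as follows. It is an illustration only: the logits are hypothetical, and top_k_filter is an illustrative helper name rather than a KerasNLP function.

```python
import math
import random


def softmax(logits):
    """Convert logits to probabilities in a numerically stable way."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_filter(probs, k):
    """Zero out all but the k most probable tokens and renormalize the rest."""
    top_ids = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:k]
    kept_mass = sum(probs[i] for i in top_ids)
    filtered = [0.0] * len(probs)
    for i in top_ids:
        filtered[i] = probs[i] / kept_mass
    return filtered


logits = [2.0, 1.0, 0.5, -1.0, -3.0]  # hypothetical next-token logits
probs = softmax(logits)
filtered = top_k_filter(probs, k=2)
print([round(p, 3) for p in filtered])  # [0.731, 0.269, 0.0, 0.0, 0.0]

# Sample the next token from the renormalized distribution.
rng = random.Random(0)
next_token = rng.choices(range(len(filtered)), weights=filtered, k=1)[0]
print(next_token)
```

Tokens outside the top k receive zero probability, so they can never be drawn, while the relative odds among the surviving tokens are unchanged.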

output_tokens = keras_nlp.utils.top_k_search(
    token_logits_fn, prompt_tokens, max_length=NUM_TOKENS_TO_GENERATE,
    k=10, from_logits=True,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Top-K search generated text: \n{txt}\n")
Top-K search generated text: 
b'[BOS] " you have got into our hands when he was out , and i thought of you . he had a great pleasure and happiness for his sake . you are very fond of his professication , though he is not a boy of sixteen ; but he will be glad that he is not to have a good time in this country for his services . the young man and his wife ,'

Even with top-k search, there is room for improvement. The number k is fixed, so the search selects the same number of tokens for any probability distribution. Consider two scenarios: one where the probability mass is concentrated on 2 words, and another where it is spread evenly across 10. Should we choose k=2 or k=10? No single k fits both.

This is where top-p search comes in! Instead of choosing a k, we choose a probability p that we want the probabilities of the top tokens to sum up to. This way, k adjusts dynamically to the probability distribution. With p=0.9, if 90% of the probability mass is concentrated on the top 2 tokens, the search keeps only those 2 tokens to sample from. If instead the 90% is distributed over 10 tokens, it keeps those 10 tokens to sample from.
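
A minimal sketch of how the cutoff adapts to the distribution, in plain Python. The two distributions are hypothetical, top_p_filter is an illustrative helper name, and details such as whether the token that crosses the threshold is kept may differ from the KerasNLP utility.

```python
def top_p_filter(probs, p):
    """Keep the smallest set of most probable tokens whose total mass reaches p."""
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    kept, mass = [], 0.0
    for token_id in order:
        kept.append(token_id)
        mass += probs[token_id]
        if mass >= p:
            break
    # Renormalize the surviving tokens so they sum to 1.
    filtered = [0.0] * len(probs)
    for token_id in kept:
        filtered[token_id] = probs[token_id] / mass
    return filtered


peaked = [0.6, 0.35, 0.02, 0.02, 0.01]  # mass concentrated on two tokens
flat = [0.2, 0.2, 0.2, 0.2, 0.2]        # mass spread evenly

peaked_kept = sum(1 for x in top_p_filter(peaked, p=0.9) if x > 0)
flat_kept = sum(1 for x in top_p_filter(flat, p=0.9) if x > 0)
print(peaked_kept, flat_kept)  # 2 5
```

The same p yields a candidate set of 2 tokens for the peaked distribution and 5 for the flat one, which is the adaptive behaviour a fixed k cannot provide.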

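output_tokens = keras_nlp.utils.top_p_search(
    token_logits_fn, prompt_tokens, max_length=NUM_TOKENS_TO_GENERATE,
    p=0.5, from_logits=True,  # p=0.5 is an assumed threshold.
)
txt = tokenizer.detokenize(output_tokens)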
print(f"Top-P search generated text: \n{txt}\n")
Top-P search generated text: 
b'[BOS] at the end of this the two sisters were so startled that the dog would be in the tree . when they were gone , they were caught in a hint of their clothes , they did not want to stay until they were out of sight of the door , but it was not a little black dog . then they went on their way home , and they went off to the top of the tree'

Using callbacks for text generation

We can also wrap the utilities in a callback, which allows you to print out a prediction sequence for every epoch of the model! Here is an example of a callback for top-k search:

class TopKTextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model using top-k."""

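    def __init__(self, k):
        super().__init__()
        self.k = k

    def on_epoch_end(self, epoch, logs=None):
        output_tokens = keras_nlp.utils.top_k_search(
            token_logits_fn, prompt_tokens, max_length=NUM_TOKENS_TO_GENERATE,
            k=self.k, from_logits=True,
        )
        txt = tokenizer.detokenize(output_tokens)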
        print(f"Top-K search generated text: \n{txt}\n")

text_generation_callback = TopKTextGenerator(k=10)
# Dummy training loop to demonstrate callback.
model.fit(train_ds.take(1), verbose=2, epochs=2, callbacks=[text_generation_callback])
Epoch 1/2
Top-K search generated text: 
b'[BOS] as the corrals were very different . the prominent features of this province is in a state of great importance and concessions , the spacious state of affairs of promineering state in the extreme , interpreter , to collect the establishment , in the most economical composition'
1/1 - 10s - loss: 3.8154 - perplexity: 46.5370 - 10s/epoch - 10s/step
Epoch 2/2
Top-K search generated text: 
b'[BOS] " we will be a man of great value to the condema - cove . we will not be in a very short time , but we have some of our men . we must be in our hands , as it is , as the province , and we will find the way that a large number is to be made . there is an indian canoe on the shore . if'
1/1 - 11s - loss: 3.6902 - perplexity: 42.6255 - 11s/epoch - 11s/step

<keras.callbacks.History at 0x7f74306a9310>


To recap, in this example, we use KerasNLP layers to train a sub-word vocabulary, tokenize training data, create a miniature GPT model, and perform inference with the text generation library.

If you would like to understand how Transformers work, or learn more about training the full GPT model, here are some further readings: