Float8 training and inference with a simple Transformer model

Author: Hongyu Chiu
Date created: 2024/05/14
Last modified: 2024/05/14
Description: Train a simple Transformer model with float8 quantization.

ⓘ This example uses Keras 3


As the number of parameters in Transformer models continues to grow, training and inference become highly memory- and compute-intensive. 8-bit floating point (FP8) was introduced in response, offering improved performance over 16-bit floating point with almost no degradation in accuracy.

In detail, there are two distinct types of FP8: E4M3 and E5M2, useful in different parts of training.

  • E4M3: It consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and nan.
  • E5M2: It consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/-inf and nan. The tradeoff of the increased dynamic range is lower precision of the stored values.
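
The maximum values quoted in the list above follow from the bit layouts. The sketch below recomputes them, assuming the common FP8 conventions: E4M3 uses bias 7 and reserves only the all-ones bit pattern for nan, while E5M2 uses bias 15 and reserves the whole top exponent code for inf and nan, as IEEE 754 formats do. The helper name `fp8_max` is illustrative and not part of any library.

```python
def fp8_max(exp_bits, man_bits, ieee_like):
    """Largest finite value of a small float format (illustrative helper).

    ieee_like=True:  the top exponent code is reserved for inf/nan (E5M2).
    ieee_like=False: only the all-ones mantissa in the top exponent code
                     is nan, so that exponent still holds finite values (E4M3).
    """
    bias = 2 ** (exp_bits - 1) - 1
    top_code = 2**exp_bits - 1
    if ieee_like:
        max_exp_code = top_code - 1
        max_mantissa = 2**man_bits - 1
    else:
        max_exp_code = top_code
        max_mantissa = 2**man_bits - 2
    return 2.0 ** (max_exp_code - bias) * (1 + max_mantissa / 2**man_bits)


print(fp8_max(4, 3, ieee_like=False))  # E4M3 -> 448.0
print(fp8_max(5, 2, ieee_like=True))  # E5M2 -> 57344.0
```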

Typically, E4M3 is best used during the forward pass because activations and weights require more precision. In the backward pass, however, E5M2 is utilized because gradients are less susceptible to the loss of precision but require higher dynamic range.
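
To see the precision side of this tradeoff, the following sketch rounds a value to 3 and to 2 explicit mantissa bits. It is a simplification that ignores exponent limits, subnormals and overflow, and `round_to_mantissa` is a hypothetical helper written for this illustration.

```python
import math


def round_to_mantissa(x, man_bits):
    """Round x to the nearest value with `man_bits` explicit mantissa bits.

    Simplified: exponent range, subnormals and overflow are ignored.
    """
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)  # x = m * 2**e with 0.5 <= |m| < 1
    steps = 2 ** (man_bits + 1)  # implicit leading bit plus explicit bits
    return math.ldexp(round(m * steps) / steps, e)


x = 0.27
print(round_to_mantissa(x, 3))  # E4M3-like precision -> 0.28125
print(round_to_mantissa(x, 2))  # E5M2-like precision -> 0.25
```

With one more mantissa bit, the rounding error for this input drops from 0.02 to about 0.011.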

It is worth noting that FP8 inference deployment is greatly simplified, as inference and training use the same datatype. This is in contrast to INT8 inference with networks trained in 32- or 16-bit floating point, which requires post-training quantization (PTQ) calibration and sometimes even quantization-aware training (QAT) in order to maintain model accuracy.

In this example, we will build a simple Transformer model and train it with both mixed precision ("mixed_bfloat16") and FP8. You will observe that the accuracy stays close to the mixed-precision baseline at the lower precision.

Note: You will need a decent GPU with FP8 Tensor Cores support for the expected performance improvement.


We will use the KerasNLP library to simplify the model implementation. Additionally, we use mixed precision training to reduce the training time.

Note: The dependency on TensorFlow is only required for data processing.

!pip install -q --upgrade keras-nlp
!pip install -q --upgrade keras  # Upgrade to Keras 3.
import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import re

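import keras
import keras_nlp
import tensorflow as tf

# Compute in bfloat16 while keeping variables in float32.
keras.mixed_precision.set_global_policy("mixed_bfloat16")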


Define some hyperparameters.

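EPOCHS = 3  # Assumed value
BATCH_SIZE = 32  # Assumed value
VOCABULARY_SIZE = 20000  # Consistent with the 646,400 embedding parameters in the summary below
MAX_SEQUENCE_LENGTH = 200  # Consistent with the 200-token padded sequences shown below
MODEL_KWARGS = dict(
    vocabulary_size=VOCABULARY_SIZE,
    max_sequence_length=MAX_SEQUENCE_LENGTH,
    hidden_dim=32,  # Hidden size for each token
    num_heads=2,  # Number of attention heads
    intermediate_dim=32,  # Intermediate size in feedforward network
    dropout=0.1,  # Dropout rate
)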


First, let's download the IMDB dataset and extract it.

!mkdir -p datasets
!wget -q -O datasets/aclImdb_v1.tar.gz
!mkdir -p datasets/aclImdb
!tar -xzf datasets/aclImdb_v1.tar.gz -C datasets
!rm -rf datasets/aclImdb/train/unsup

We'll use the keras.utils.text_dataset_from_directory utility to generate our labelled dataset from text files.

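# validation_split=0.2 matches the 20000/5000 split reported below; the seed value is an assumption.
train_ds = keras.utils.text_dataset_from_directory(
    "datasets/aclImdb/train", batch_size=BATCH_SIZE, validation_split=0.2, subset="training", seed=42
)
val_ds = keras.utils.text_dataset_from_directory(
    "datasets/aclImdb/train", batch_size=BATCH_SIZE, validation_split=0.2, subset="validation", seed=42
)
test_ds = keras.utils.text_dataset_from_directory("datasets/aclImdb/test", batch_size=BATCH_SIZE)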
Found 25000 files belonging to 2 classes.

Using 20000 files for training.

Found 25000 files belonging to 2 classes.

Using 5000 files for validation.

Found 25000 files belonging to 2 classes.

We will now convert the text to lowercase.

train_ds = train_ds.map(lambda x, y: (tf.strings.lower(x), y))
val_ds = val_ds.map(lambda x, y: (tf.strings.lower(x), y))
test_ds = test_ds.map(lambda x, y: (tf.strings.lower(x), y))

Let's print a few samples.

for text_batch, label_batch in train_ds.take(1):
    for i in range(3):
        print(f"Text: {text_batch.numpy()[i]}")
        print(f"Label: {label_batch.numpy()[i]}")
Text: b'"pandemonium" is a horror movie spoof that comes off more stupid than funny. believe me when i tell you, i love comedies. especially comedy spoofs. "airplane", "the naked gun" trilogy, "blazing saddles", "high anxiety", and "spaceballs" are some of my favorite comedies that spoof a particular genre. "pandemonium" is not up there with those films. most of the scenes in this movie had me sitting there in stunned silence because the movie wasn\'t all that funny. there are a few laughs in the film, but when you watch a comedy, you expect to laugh a lot more than a few times and that\'s all this film has going for it. geez, "scream" had more laughs than this film and that was more of a horror film. how bizarre is that?<br /><br />*1/2 (out of four)'
Label: 0
Text: b"david mamet is a very interesting and a very un-equal director. his first movie 'house of games' was the one i liked best, and it set a series of films with characters whose perspective of life changes as they get into complicated situations, and so does the perspective of the viewer.<br /><br />so is 'homicide' which from the title tries to set the mind of the viewer to the usual crime drama. the principal characters are two cops, one jewish and one irish who deal with a racially charged area. the murder of an old jewish shop owner who proves to be an ancient veteran of the israeli independence war triggers the jewish identity in the mind and heart of the jewish detective.<br /><br />this is were the flaws of the film are the more obvious. the process of awakening is theatrical and hard to believe, the group of jewish militants is operatic, and the way the detective eventually walks to the final violent confrontation is pathetic. the end of the film itself is mamet-like smart, but disappoints from a human emotional perspective.<br /><br />joe mantegna and william macy give strong performances, but the flaws of the story are too evident to be easily compensated."
Label: 0
Text: b'great documentary about the lives of ny firefighters during the worst terrorist attack of all time.. that reason alone is why this should be a must see collectors item.. what shocked me was not only the attacks, but the"high fat diet" and physical appearance of some of these firefighters. i think a lot of doctors would agree with me that,in the physical shape they were in, some of these firefighters would not of made it to the 79th floor carrying over 60 lbs of gear. having said that i now have a greater respect for firefighters and i realize becoming a firefighter is a life altering job. the french have a history of making great documentary\'s and that is what this is, a great documentary.....'
Label: 1

Tokenizing the data

We'll be using the keras_nlp.tokenizers.WordPieceTokenizer layer to tokenize the text. keras_nlp.tokenizers.WordPieceTokenizer takes a WordPiece vocabulary and has functions for tokenizing the text, and detokenizing sequences of tokens.

Before we define the tokenizer, we first need to train it on the dataset we have. The WordPiece tokenization algorithm is a subword tokenization algorithm; training it on a corpus gives us a vocabulary of subwords. A subword tokenizer is a compromise between word tokenizers (word tokenizers need very large vocabularies for good coverage of input words), and character tokenizers (characters don't really encode meaning like words do). Luckily, KerasNLP makes it very simple to train WordPiece on a corpus with the keras_nlp.tokenizers.compute_word_piece_vocabulary utility.

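def train_word_piece(ds, vocab_size, reserved_tokens):
    # Drop the labels and keep only the raw text.
    word_piece_ds = ds.unbatch().map(lambda x, y: x)
    vocab = keras_nlp.tokenizers.compute_word_piece_vocabulary(
        word_piece_ds.batch(1000).prefetch(2),
        vocabulary_size=vocab_size,
        reserved_tokens=reserved_tokens,
    )
    return vocab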
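
To make the subword idea concrete, here is a minimal pure-Python sketch of how a WordPiece vocabulary is applied to a single word using greedy longest-match-first splitting. It illustrates the lookup step only, not the vocabulary training that compute_word_piece_vocabulary performs, and the toy vocabulary is made up for this example. Continuation pieces carry the conventional "##" prefix.

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]", prefix="##"):
    """Greedy longest-match-first subword split of a single word."""
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = prefix + candidate  # mark a continuation piece
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matches: the whole word is unknown
        tokens.append(piece)
        start = end
    return tokens


# Hypothetical toy vocabulary, not the one trained in this example.
toy_vocab = {"[PAD]", "[UNK]", "great", "un", "##watch", "##able"}
print(wordpiece_tokenize("unwatchable", toy_vocab))  # ['un', '##watch', '##able']
print(wordpiece_tokenize("great", toy_vocab))  # ['great']
print(wordpiece_tokenize("zzz", toy_vocab))  # ['[UNK]']
```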

Every vocabulary has a few special, reserved tokens. We have two such tokens:

  • "[PAD]" - Padding token. Padding tokens are appended to the input sequence when it is shorter than the maximum sequence length.
  • "[UNK]" - Unknown token.
reserved_tokens = ["[PAD]", "[UNK]"]
train_sentences = [element[0] for element in train_ds]
vocab = train_word_piece(train_ds, VOCABULARY_SIZE, reserved_tokens)

Let's see some tokens!

print("Tokens: ", vocab[100:110])
Tokens:  ['à', 'á', 'â', 'ã', 'ä', 'å', 'æ', 'ç', 'è', 'é']

Now, let's define the tokenizer. We will configure the tokenizer with the vocabulary trained above. We will also define a maximum sequence length so that every sequence has the same length: shorter sequences are padded up to it, and longer ones are truncated.

tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=vocab, sequence_length=MAX_SEQUENCE_LENGTH
)
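
The padding and truncation behaviour described above can be sketched in plain Python. This is an illustration only, not the tokenizer's implementation; `pad_or_truncate` is a hypothetical helper, and a pad id of 0 is assumed because "[PAD]" is the first reserved token.

```python
def pad_or_truncate(token_ids, sequence_length, pad_id=0):
    """Return exactly `sequence_length` ids, truncating or right-padding."""
    clipped = token_ids[:sequence_length]
    return clipped + [pad_id] * (sequence_length - len(clipped))


print(pad_or_truncate([218, 150, 14], 5))  # [218, 150, 14, 0, 0]
print(pad_or_truncate([218, 150, 14, 393, 137, 356], 5))  # [218, 150, 14, 393, 137]
```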

Let's try and tokenize a sample from our dataset! To verify whether the text has been tokenized correctly, we can also detokenize the list of tokens back to the original text.

input_sentence_ex = train_ds.take(1).get_single_element()[0][0]
input_tokens_ex = tokenizer(input_sentence_ex)

print("Sentence: ", input_sentence_ex)
print("Tokens: ", input_tokens_ex)
print("Recovered text after detokenizing: ", tokenizer.detokenize(input_tokens_ex))
Sentence:  tf.Tensor(b'great movie - especially the music - etta james - "at last". this speaks volumes when you have finally found that special someone.', shape=(), dtype=string)

[  218   150    14   393   137   356    14  4917  2941   719    14     3
   164   370     3    15   145  2705 11670   186   155   160   557   391
   146   452   416    15     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0]
Recovered text after detokenizing:  tf.Tensor(b'great movie - especially the music - etta james - " at last " . this speaks volumes when you have finally found that special someone . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]', shape=(), dtype=string)

Formatting the dataset

Next, we'll format our datasets in the form that will be fed to the models. We need to tokenize the text.

def format_dataset(sentence, label):
    sentence = tokenizer(sentence)
    return ({"input_ids": sentence}, label)

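def make_dataset(dataset):
    # Tokenize in parallel, then shuffle and prefetch to keep the accelerator busy.
    dataset = dataset.map(format_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.shuffle(512).prefetch(tf.data.AUTOTUNE)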

train_ds = make_dataset(train_ds)
val_ds = make_dataset(val_ds)
test_ds = make_dataset(test_ds)


Let's build a simple Transformer model. We will use TokenAndPositionEmbedding and TransformerDecoder from the KerasNLP library. TokenAndPositionEmbedding represents words and their order in a sentence, while TransformerDecoder outputs one vector for each time step of our input sequence. Here, we take the mean across all time steps and use a feedforward network on top of it to classify text.

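def build_model(
    vocabulary_size, max_sequence_length, hidden_dim, num_heads, intermediate_dim, dropout
):
    token_id_input = keras.layers.Input(shape=(None,), dtype="int32", name="input_ids")
    x = keras_nlp.layers.TokenAndPositionEmbedding(
        vocabulary_size=vocabulary_size, sequence_length=max_sequence_length, embedding_dim=hidden_dim
    )(token_id_input)
    x = keras.layers.Dropout(rate=dropout)(x)
    x = keras_nlp.layers.TransformerDecoder(
        intermediate_dim=intermediate_dim, num_heads=num_heads, dropout=dropout
    )(x)
    # Average the per-token vectors into a single vector per review.
    x = keras.layers.GlobalAveragePooling1D()(x)
    x = keras.layers.Dropout(dropout)(x)
    x = keras.layers.Dense(intermediate_dim, activation="relu")(x)
    x = keras.layers.Dropout(dropout)(x)
    outputs = keras.layers.Dense(1, activation="sigmoid")(x)
    return keras.Model(inputs=token_id_input, outputs=outputs)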
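
As a small illustration of the "mean across all time steps" step, the sketch below averages a made-up (time_steps, features) sequence over the time axis in plain Python. It mirrors what the pooling layer computes for one unmasked example; the function name and input values are hypothetical.

```python
def mean_over_time(sequence):
    """Average a (time_steps, features) nested list over the time axis."""
    steps = len(sequence)
    return [sum(feature) / steps for feature in zip(*sequence)]


# Three time steps, two features each (made-up values).
decoder_outputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(mean_over_time(decoder_outputs))  # [3.0, 4.0]
```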

Training and evaluating our model

First, we train and evaluate the model with mixed precision ("mixed_bfloat16"). Afterward, we compare the results with FP8 training/inference.

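model = build_model(**MODEL_KWARGS)
model.summary()
# Binary cross-entropy suits the sigmoid output; Adam is an assumed optimizer choice.
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
result = model.evaluate(test_ds)
print(f"Accuracy (mixed_bfloat16): {result[1]:.2%}")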
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_ids (InputLayer)          │ (None, None)           │             0 │
│ token_and_position_embedding    │ (None, None, 32)       │       646,400 │
│ (TokenAndPositionEmbedding)     │                        │               │
│ dropout (Dropout)               │ (None, None, 32)       │             0 │
│ transformer_decoder             │ (None, None, 32)       │         6,464 │
│ (TransformerDecoder)            │                        │               │
│ global_average_pooling1d        │ (None, 32)             │             0 │
│ (GlobalAveragePooling1D)        │                        │               │
│ dropout_2 (Dropout)             │ (None, 32)             │             0 │
│ dense (Dense)                   │ (None, 32)             │         1,056 │
│ dropout_3 (Dropout)             │ (None, 32)             │             0 │
│ dense_1 (Dense)                 │ (None, 1)              │            33 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 653,953 (2.49 MB)
 Trainable params: 653,953 (2.49 MB)
 Non-trainable params: 0 (0.00 B)
Accuracy (mixed_bfloat16): 75.56%
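
The parameter counts in the summary above can be checked by hand. The sketch below assumes a vocabulary size of 20000 and a maximum sequence length of 200 (consistent with the 646,400 embedding parameters), and assumes the decoder block consists of query, key, value and output projections, two layer normalizations and a two-layer feedforward network.

```python
vocabulary_size, max_sequence_length = 20000, 200  # assumed; consistent with the summary
hidden_dim, intermediate_dim = 32, 32


def dense_params(inputs, outputs):
    return inputs * outputs + outputs  # kernel + bias


# Token table plus position table, each row of width hidden_dim.
embedding = (vocabulary_size + max_sequence_length) * hidden_dim
attention = 4 * dense_params(hidden_dim, hidden_dim)  # query, key, value, output
layer_norms = 2 * (2 * hidden_dim)  # gamma and beta for each of two layer norms
feedforward = dense_params(hidden_dim, intermediate_dim) + dense_params(intermediate_dim, hidden_dim)
decoder = attention + layer_norms + feedforward
head = dense_params(hidden_dim, intermediate_dim) + dense_params(intermediate_dim, 1)
total = embedding + decoder + head

print(embedding, decoder, head, total)  # 646400 6464 1089 653953
print(f"{total * 4 / 2**20:.2f} MB")  # float32 storage -> 2.49 MB
```

Almost all parameters sit in the embedding table, which is not one of the layers that the FP8 step below converts.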

We can enable FP8 training/inference with a one-line API: model.quantize("float8").

model = build_model(**MODEL_KWARGS)
model.quantize("float8")

To inspect that FP8 training takes place, we can print out some variables related to FP8 training:

  • *_scale: The scaling factor that shifts the distribution of inputs, weights and gradients into the representable range of FP8. Defaults to 1.0.
  • *_amax_history: The window of recent absolute-maximum (amax) values used to compute the scaling factor. Defaults to a window of length 1024 filled with 0.0.

pattern = r"(transformer).+(multi_head).+(query).+(scale|amax_history)"
for v in model.trainable_variables:
    if re.findall(pattern, v.path):
        print(v.path)
        print(keras.ops.convert_to_numpy(v.value))
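
To give an intuition for how a scale and an amax history interact, here is a simplified pure-Python sketch of delayed scaling for a single tensor. It is not the Keras implementation: the class name `DelayedScaler`, the choice of scale = amax / fp8_max, and the update order are assumptions made for illustration. Only the defaults (scale 1.0, a zero-filled history of length 1024) follow the list above.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite E4M3 value


class DelayedScaler:
    """Toy delayed scaling for one tensor (illustrative, not the Keras code)."""

    def __init__(self, fp8_max=E4M3_MAX, history_length=1024):
        self.fp8_max = fp8_max
        self.scale = 1.0  # default scale
        self.amax_history = deque([0.0] * history_length, maxlen=history_length)

    def scale_and_clip(self, values):
        # Use the scale derived from earlier steps, then clip to the FP8 range.
        scaled = [max(-self.fp8_max, min(self.fp8_max, v / self.scale)) for v in values]
        # Record this step's absolute maximum and refresh the scale for the next step.
        self.amax_history.append(max(abs(v) for v in values))
        amax = max(self.amax_history)
        if amax > 0:
            self.scale = amax / self.fp8_max
        return scaled


scaler = DelayedScaler(history_length=4)
print(scaler.scale_and_clip([1.0, -896.0]))  # [1.0, -448.0]: clipped, scale was still 1.0
print(scaler.scale)  # 2.0
print(scaler.scale_and_clip([896.0]))  # [448.0]: fits after scaling
```

The first step clips because the scale has not yet adapted, which is why the history-driven scale is expected to change once training starts.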

The dtype policies of FP8 layers have also been modified.

for layer in model._flatten_layers(recursive=True):
    if "float8" in str(layer.dtype_policy):
        print(f"{layer.name}: {layer.dtype_policy}")
feedforward_output_dense: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
feedforward_intermediate_dense: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
attention_output: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
value: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
key: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
query: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
dense_2: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
dense_3: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">

Let's train the model and see the results. We can verify that the accuracy stays close to the mixed-precision baseline with FP8 training and that the variables containing FP8 information change after fitting.

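# Same assumed optimizer, loss and metric as the mixed-precision run.
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
result = model.evaluate(test_ds)
print(f"Accuracy (float8): {result[1]:.2%}")

for v in model.trainable_variables:
    if re.findall(pattern, v.path):
        print(v.path)
        print(keras.ops.convert_to_numpy(v.value))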
Accuracy (float8): 74.16%


  • The improvements in training speed are relatively small if the model is not sufficiently large. The recommendation is to use FP8 training for models with more than 5B parameters.
  • You will need hardware such as NVIDIA H100 that supports FP8 Tensor Cores to gain the speedups.