Author: Jyotinder Singh
Date created: 2025/10/14
Last modified: 2025/10/14
Description: Complete guide to using INT8 quantization in Keras and KerasHub.
Quantization lowers the numerical precision of weights and activations to reduce memory use
and often speed up inference, at the cost of a small accuracy drop. Moving from float32 to
float16 halves the memory requirements; float32 to INT8 is ~4x smaller (and ~2x vs float16).
On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also improve
throughput and latency. Actual gains depend on your backend and device.
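As a rough back-of-the-envelope check of those ratios (the 1-billion-parameter count below is just an illustrative assumption, not a specific model):

params = 1_000_000_000  # hypothetical model with 1B parameters
print(f"float32: {params * 4 / 1024**3:.2f} GiB")  # 4 bytes per parameter
print(f"float16: {params * 2 / 1024**3:.2f} GiB")  # 2 bytes per parameter
print(f"int8:    {params * 1 / 1024**3:.2f} GiB")  # 1 byte per parameter (plus small scale tensors)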
Quantization maps real values to 8-bit integers in the range [-128, 127] (256 levels) using a scale:

- For a weight tensor w: a_max = max(abs(w)), s = (2 * a_max) / 256, and q = clip(round(w / s), -128, 127) (stored as INT8); the scale s is kept alongside q (a short NumPy sketch follows this list).
- At inference time, Keras uses q and s to reconstruct effective weights on the fly (w ≈ s · q) or folds s into the matmul/conv for efficiency.
- INT8 weights take ~4x less memory than float32, improving cache behavior and reducing memory stalls; this often helps more than increasing raw FLOPs.
- Accuracy typically stays close to float16; INT8 may introduce a modest drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.
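Here is a minimal NumPy sketch of those formulas on a random tensor; the names (w, a_max, s, q) mirror the math above and are purely illustrative, not Keras internals.

import numpy as np

w = np.random.default_rng(0).normal(size=(64, 32)).astype("float32")  # toy weight tensor

a_max = np.max(np.abs(w))                                # absolute max of the tensor
s = (2 * a_max) / 256                                    # scale: 256 levels across [-a_max, a_max]
q = np.clip(np.round(w / s), -128, 127).astype("int8")   # values stored as INT8
w_hat = s * q.astype("float32")                          # reconstructed weights, w ≈ s · q

print("max reconstruction error:", float(np.max(np.abs(w - w_hat))))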
This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras: we build a small functional model, capture a baseline output, quantize it to INT8 in place, and then compare outputs with an MSE metric.
import os
import numpy as np
import keras
from keras import layers
# Create a random number generator.
rng = np.random.default_rng()
# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)
# Compile and train briefly to materialize meaningful weights.
model.compile(optimizer="adam", loss="mse")
x_train = rng.random((256, 10)).astype("float32")
y_train = rng.random((256, 1)).astype("float32")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)
# Sample inputs for evaluation.
x_eval = rng.random((32, 10)).astype("float32")
# Baseline (FP) outputs.
y_fp32 = model(x_eval)
# Quantize the model in-place to INT8.
model.quantize("int8")
# INT8 outputs after quantization.
y_int8 = model(x_eval)
# Compute a simple MSE between FP and INT8 outputs.
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))
print("Full-Precision vs INT8 MSE:", float(mse))
Full-Precision vs INT8 MSE: 4.982496648153756e-06
The low MSE confirms that the INT8 quantized model produces outputs close to those of the original FP32 model.
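To see what changed inside the model, you can peek at the quantized layer's variables; after INT8 quantization the Dense kernel is stored as INT8 together with its scale (exact variable names and dtypes can vary with the Keras version and backend):

# Inspect the variables of the quantized output layer (named "target" above).
for v in model.get_layer("target").weights:
    print(v.name, v.dtype, v.shape)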
You can use the standard Keras saving and loading APIs with quantized models. Quantization
is preserved when saving to .keras
and loading back.
# Save the quantized model and reload to verify round-trip.
model.save("int8.keras")
int8_reloaded = keras.saving.load_model("int8.keras")
y_int8_reloaded = int8_reloaded(x_eval)
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))
MSE (INT8 vs reloaded-INT8): 0.0
All KerasHub models support the .quantize(...)
API for post-training quantization,
and follow the same workflow as above.
In this example, we will:

- Load a Gemma3 model from a KerasHub preset and generate text at full precision.
- Save the full-precision model for a size comparison.
- Quantize the model in place to INT8 and generate again.
- Save the INT8 model, reload it, and check the generated output.
- Compare the file sizes on disk.
from keras_hub.models import Gemma3CausalLM
# Load from Gemma3 preset
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")
# Generate text for a single prompt
output = gemma3.generate("Keras is a", max_length=50)
print("Full-precision output:", output)
# Save FP32 Gemma3 model for size comparison.
gemma3.save_to_preset("gemma3_fp32")
# Quantize in-place to INT8 and generate again
gemma3.quantize("int8")
output = gemma3.generate("Keras is a", max_length=50)
print("Quantized output:", output)
# Save INT8 Gemma3 model
gemma3.save_to_preset("gemma3_int8")
# Reload and compare outputs
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")
output = gemma3_int8.generate("Keras is a", max_length=50)
print("Quantized reloaded output:", output)
# Compute storage savings
def bytes_to_mib(n):
    return n / (1024**2)
gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5")
gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")
Full-precision output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning. It is a library for deep learning. It is a library for deep learning. It is a
Quantized output: Keras is a Python library for deep learning. It is a high-level API for building deep learning models. It is designed to be easy
Quantized reloaded output: Keras is a Python library for deep learning. It is a high-level API for building deep learning models. It is designed to be easy
Gemma3: FP32 file size: 3815.32 MiB
Gemma3: INT8 file size: 957.81 MiB
Gemma3: Size reduction: 74.9%
Finally, note that quantization rewrites a layer's weights in place, so it can only be applied once those weights exist, i.e., after the model has been built (via build()
or a forward pass).
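As a minimal illustration of that requirement (the tiny Sequential model here is only an assumption for demonstration):

# quantize() rewrites existing weights, so the model must be built first.
fresh = keras.Sequential([keras.layers.Dense(4)])
fresh.build((None, 8))   # creates the kernel and bias (a forward pass would do the same)
fresh.quantize("int8")   # weights now exist and can be quantized in place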