Author: Jyotinder Singh
Date created: 2025/10/09
Last modified: 2025/10/09
Description: Overview of quantization in Keras (int8, float8, int4, GPTQ).
Modern large models are often memory- and bandwidth-bound: most inference time is spent moving tensors between memory and compute units rather than doing math. Quantization reduces the number of bits used to represent the model's weights and (optionally) activations, which shrinks the memory footprint, reduces memory traffic, and can improve inference latency and throughput.
Keras provides first-class post-training quantization (PTQ) workflows which support pretrained models and expose a uniform API at both the model and layer level.
At a high level, Keras supports post-training quantization to int4, int8, and float8, as well as GPTQ weight-only quantization. These schemes map floating-point values x to integers q using a scale (and optionally a zero-point); symmetric schemes use only a scale.

Keras currently focuses on the following numeric formats. Each mode can be applied selectively to layers or to the whole model via the same API.
- int8 (8-bit integer): joint weight + activation PTQ.
- float8 (FP8: E4M3 / E5M2 variants): low-precision floating point useful for training and inference on FP8-capable hardware.
- int4: ultra-low-bit weights for aggressive compression; activations remain in higher precision (int8).
- GPTQ (weight-only, 2/3/4/8 bits): a second-order, post-training method that minimizes layer output error.
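To make the scale-based mapping concrete, here is a minimal NumPy sketch of symmetric AbsMax quantization. It is illustrative only: the exact ranges and kernels Keras uses internally may differ.

import numpy as np

# Symmetric AbsMax quantization: the scale maps the largest observed
# magnitude onto the edge of the signed 8-bit range.
x = np.random.randn(4, 8).astype("float32")
scale = np.max(np.abs(x)) / 127.0
q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)

# Dequantize to recover an approximation of the original tensor.
x_hat = q.astype("float32") * scale
print(np.max(np.abs(x - x_hat)))  # worst-case quantization error

# int4 follows the same idea with the narrower range [-8, 7];
# asymmetric schemes additionally shift by a zero-point.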
- In the int4 and int8 PTQ paths, activation scales are computed on-the-fly at runtime (per tensor and per batch) using an AbsMax estimator. This avoids maintaining a separate, fixed set of activation scales from a calibration pass and adapts to varying input ranges.
- For int4, Keras packs signed 4-bit values (range = [-8, 7]) and stores per-channel scales such as kernel_scale. Dequantization happens on the fly, and matmuls use 8-bit (unpacked) kernels; a rough packing sketch follows below.
- int4 / int8 / float8 use AbsMax calibration by default (the range is set by the maximum absolute value observed). Alternative calibration methods (e.g., percentile) may be added in future releases.

Quantization is applied explicitly after layers or models are built. The API is designed to be predictable: you call quantize, the graph is rewritten, the weights are replaced, and you can immediately run inference or save the model.
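As a rough illustration of the packing step mentioned above, the sketch below stores two signed 4-bit values per byte. The nibble layout and dtypes here are assumptions for illustration; Keras's actual storage format may differ.

import numpy as np

# Signed 4-bit weights in the range [-8, 7].
w = np.random.randint(-8, 8, size=(4, 8)).astype(np.int8)

# Pack two values per byte: even columns in the low nibble, odd columns
# in the high nibble (two's-complement nibbles, 0..15).
nibbles = (w & 0x0F).astype(np.uint8)
packed = (nibbles[..., 0::2] | (nibbles[..., 1::2] << 4)).astype(np.uint8)

# Unpack: extract each nibble and sign-extend back to [-8, 7].
lo = (packed & 0x0F).astype(np.int8)
hi = ((packed >> 4) & 0x0F).astype(np.int8)
lo = np.where(lo > 7, lo - 16, lo).astype(np.int8)
hi = np.where(hi > 7, hi - 16, hi).astype(np.int8)

restored = np.empty_like(w)
restored[..., 0::2] = lo
restored[..., 1::2] = hi
print(np.array_equal(restored, w))  # True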
Typical workflow:
1. Build or load the model, making sure build() or a forward pass has materialized the weights.
2. Call model.quantize("<mode>") or layer.quantize("<mode>") with "int8", "int4", "float8", or "gptq" (weight-only).
3. Run inference or save with model.save(...). Quantization state (packed weights, scales, metadata) is preserved on save/load.

import keras
import numpy as np
# Sample training data.
x_train = keras.ops.array(np.random.rand(100, 10))
y_train = keras.ops.array(np.random.rand(100, 1))
# Build the model.
model = keras.Sequential(
[
keras.Input(shape=(10,)),
keras.layers.Dense(32, activation="relu"),
keras.layers.Dense(1),
]
)
# Compile and fit the model.
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(x_train, y_train, epochs=1, verbose=0)
# Quantize the model.
model.quantize("int8")
What this does: quantizes the weights of the supported layers and re-wires their forward paths to work with the quantized kernels and their quantization scales.
Note: Throughput gains depend on backend/hardware kernels; in cases where kernels fall back to dequantized matmul, you still get memory savings but smaller speedups.
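Because quantization state is preserved on save/load, a quantized model can be serialized and restored like any other Keras model. A minimal sketch continuing the example above (the file name is arbitrary):

# Save the quantized model, then reload it and run inference.
model.save("quantized_model.keras")
reloaded = keras.models.load_model("quantized_model.keras")
preds = reloaded.predict(x_train[:4])
print(preds.shape)  # (4, 1)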
The Keras quantization framework also lets you quantize individual layers through the same unified API, without having to quantize the entire model.
from keras import layers
input_shape = (10,)
layer = layers.Dense(32, activation="relu")
layer.build(input_shape)
layer.quantize("int4") # Or "int8", "float8", etc.
Keras supports the following core layers in its quantization framework:
- Dense
- EinsumDense
- Embedding
- ReversibleEmbedding (available in KerasHub)

Any composite layers that are built from the above (for example, MultiHeadAttention, GroupedQueryAttention, feed-forward blocks in Transformers) inherit quantization support by construction. This covers the majority of modern encoder-only and decoder-only Transformer architectures.
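For example, a model containing MultiHeadAttention can be quantized as a whole, since the EinsumDense projections inside the attention layer pick up quantization automatically. A minimal sketch (shapes and sizes chosen arbitrarily):

import keras

# A tiny attention block: MultiHeadAttention is built from EinsumDense
# layers, so it inherits quantization support by construction.
inputs = keras.Input(shape=(4, 16))
x = keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)(inputs, inputs)
outputs = keras.layers.Dense(16)(x)
model = keras.Model(inputs, outputs)

# Quantize the whole block: the attention projections and the Dense
# layer are rewritten to int8.
model.quantize("int8")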
Since all KerasHub models subclass keras.Model, they automatically support the model.quantize(...) API. In practice, this means you can take a popular LLM preset, call a single function to obtain an int8/int4/GPTQ-quantized variant, and then save or serve it without changing your training code.
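A sketch of that workflow, assuming KerasHub is installed; the preset name ("gpt2_base_en") is just an illustrative choice, and any CausalLM preset follows the same pattern:

import keras_hub

# Load a pretrained causal LM from a preset, quantize it in place,
# and save the quantized variant.
causal_lm = keras_hub.models.CausalLM.from_preset("gpt2_base_en")
causal_lm.quantize("int8")
causal_lm.save("gpt2_base_en_int8.keras")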