Author: Jyotinder Singh
Date created: 2025/10/14
Last modified: 2025/10/14
Description: Complete guide to using INT4 quantization in Keras and KerasHub.
Quantization lowers the numerical precision of weights and activations to reduce memory use and often speed up inference, at the cost of a small accuracy drop. INT4 post-training quantization (PTQ) stores model weights as 4-bit signed integers and dynamically quantizes activations to 8-bit at runtime (a W4A8 scheme). Compared with FP32, this shrinks weight storage by roughly 8x (2x versus INT8) while retaining acceptable accuracy for many encoder models and some decoder models. Compute still leverages widely available NVIDIA INT8 Tensor Cores.
4-bit quantization is more aggressive compression than 8-bit and may cause larger quality regressions, especially for large autoregressive language models.
Quantization maps real values to 4-bit signed integers using a per-output-channel scale: weights are quantized to the range [-8, 7] (4 bits) and packed two-per-byte, while activations are dynamically quantized to INT8 at runtime. The INT8 matmul result is then rescaled by (input_scale * per_channel_kernel_scale). This mirrors the INT8 path described in the INT8 guide, with some added unpack overhead in exchange for stronger compression.
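To make the weight path concrete, here is a minimal NumPy sketch of symmetric per-channel quantization to [-8, 7], two-per-byte packing, and dequantization. It only illustrates the scheme, it is not Keras's internal implementation, and all names in it are made up for the example.
import numpy as np
# Toy example: a 4x2 weight matrix (4 inputs, 2 output channels).
w = np.array(
    [[0.40, -1.20],
     [-0.05, 0.88],
     [0.31, -0.64],
     [-0.27, 0.10]],
    dtype="float32",
)
# Symmetric per-output-channel scale: the largest magnitude maps to 7.
scale = np.abs(w).max(axis=0) / 7.0
# Quantize to signed 4-bit values in [-8, 7].
q = np.clip(np.round(w / scale), -8, 7).astype(np.int8)
# Pack two 4-bit values into each byte (low nibble first).
flat = q.reshape(-1)
packed = ((flat[1::2].astype(np.uint8) & 0xF) << 4) | (flat[0::2].astype(np.uint8) & 0xF)
# Unpack and dequantize to verify the round trip.
low = (packed & 0xF).astype(np.int8)
high = ((packed >> 4) & 0xF).astype(np.int8)
low[low > 7] -= 16   # restore negative 4-bit values
high[high > 7] -= 16
unpacked = np.empty_like(flat)
unpacked[0::2], unpacked[1::2] = low, high
w_dequant = unpacked.astype("float32").reshape(w.shape) * scale
print("Max abs quantization error:", np.abs(w - w_dequant).max())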
This guide shows how to use 4-bit (W4A8) post-training quantization in Keras:
- A quick-start example on a small functional model.
- Quantizing KerasHub models (Gemma3) and saving/reloading presets.
- Benchmark results for several model families.
- Guidance on when to prefer INT4 over INT8.
Below we build a small functional model, capture a baseline output, quantize to INT4 in place, and compare outputs with an MSE metric. (For real evaluation use your validation metric.)
import numpy as np
import keras
from keras import layers
# Create a random number generator.
rng = np.random.default_rng()
# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)
# Baseline output with full-precision weights.
x_eval = rng.random((32, 10)).astype("float32")
y_fp32 = model(x_eval)
# Quantize the model in-place to INT4 (W4A8).
model.quantize("int4")
# Compare outputs (MSE).
y_int4 = model(x_eval)
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4))
print("Full-Precision vs INT4 MSE:", float(mse))
Full-Precision vs INT4 MSE: 0.00028205406852066517
The INT4-quantized model usually produces outputs close enough for many downstream tasks. Expect larger deltas than with INT8, so always validate on your own data.
You can use standard Keras saving / loading APIs. Quantization metadata (including scales and packed weights) is preserved.
# Save the quantized model and reload to verify round-trip.
model.save("int4.keras")
int4_reloaded = keras.saving.load_model("int4.keras")
y_int4_reloaded = int4_reloaded(x_eval)
# Compare outputs (MSE).
roundtrip_mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4_reloaded))
print("MSE (INT4 vs reloaded INT4):", float(roundtrip_mse))
MSE (INT4 vs reloaded INT4): 0.00028205406852066517
All KerasHub models support the .quantize(...) API for post-training quantization and follow the same workflow as above.
In this example, we will:
- Load a Gemma3 preset from KerasHub.
- Generate text with the full-precision weights.
- Quantize the model to INT4 and generate again.
- Save the quantized model to a new preset, reload it, and verify the output.
- Compare the on-disk size of the FP32 and INT4 presets.
import os
from keras_hub.models import Gemma3CausalLM
# Load a Gemma3 preset from KerasHub.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")
# Generate with full-precision weights.
fp_output = gemma3.generate("Keras is a", max_length=30)
print("Full-precision output:", fp_output)
# Save the full-precision model to a preset.
gemma3.save_to_preset("gemma3_fp32")
# Quantize to INT4.
gemma3.quantize("int4")
# Generate with INT4 weights.
output = gemma3.generate("Keras is a", max_length=30)
print("Quantized output:", output)
# Save INT4 model to a new preset.
gemma3.save_to_preset("gemma3_int4")
# Reload and compare outputs
gemma3_int4 = Gemma3CausalLM.from_preset("gemma3_int4")
output = gemma3_int4.generate("Keras is a", max_length=30)
print("Quantized reloaded output:", output)
# Compute storage savings
def bytes_to_mib(n):
    return n / (1024**2)
gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int4_size = os.path.getsize("gemma3_int4/model.weights.h5")
gemma_reduction = 100.0 * (1.0 - (gemma_int4_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT4 file size: {bytes_to_mib(gemma_int4_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")
Full-precision output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning
Quantized output: Keras is a python-based, open-source, and free-to-use, open-source, and a popular, and a
Quantized reloaded output: Keras is a python-based, open-source, and free-to-use, open-source, and a popular, and a
Gemma3: FP32 file size: 3815.32 MiB
Gemma3: INT4 file size: 1488.10 MiB
Gemma3: Size reduction: 61.0%
Micro-benchmarks were collected on a single NVIDIA L4 GPU (22.5 GB); all baselines are FP32.
Metric | FP32 | INT4 | Change |
---|---|---|---|
Accuracy (↑) | 91.06% | 90.14% | -0.92pp |
Model Size (MB, ↓) | 255.86 | 159.49 | -37.67% |
Peak GPU Memory (MiB, ↓) | 1554.00 | 1243.26 | -20.00% |
Latency (ms/sample, ↓) | 6.43 | 5.73 | -10.83% |
Throughput (samples/s, ↑) | 155.60 | 174.50 | +12.15% |
Analysis: Accuracy drop is modest (<1pp) with notable speed and memory gains; encoder-only models tend to retain fidelity under heavier weight compression.
Metric | FP32 | INT4 | Change |
---|---|---|---|
Perplexity (↓) | 7.44 | 9.98 | +34.15% |
Model Size (GB, ↓) | 4.8884 | 0.9526 | -80.51% |
Peak GPU Memory (MiB, ↓) | 8021.12 | 5483.46 | -31.64% |
First Token Latency (ms, ↓) | 128.87 | 122.50 | -4.95% |
Sequence Latency (ms, ↓) | 338.29 | 181.93 | -46.22% |
Token Throughput (tokens/s, ↑) | 174.41 | 256.96 | +47.33% |
Analysis: INT4 gives large size (-80%) and memory (-32%) reductions. Perplexity increases (expected for aggressive compression) yet sequence latency drops and throughput rises ~50%.
Metric | FP32 | INT4 | Change |
---|---|---|---|
Perplexity (↓) | 6.17 | 10.46 | +69.61% |
Model Size (GB, ↓) | 3.7303 | 1.4576 | -60.92% |
Peak GPU Memory (MiB, ↓) | 6844.67 | 5008.14 | -26.83% |
First Token Latency (ms, ↓) | 57.42 | 64.21 | +11.83% |
Sequence Latency (ms, ↓) | 239.78 | 161.18 | -32.78% |
Token Throughput (tokens/s, ↑) | 246.06 | 366.05 | +48.76% |
Analysis: INT4 gives large size (-61%) and memory (-27%) reductions. Perplexity increases (expected for aggressive compression) yet sequence latency drops and throughput rises ~50%.
Metric | FP32 | INT4 | Change |
---|---|---|---|
Perplexity (↓) | 6.38 | 14.16 | +121.78% |
Model Size (GB, ↓) | 5.5890 | 2.4186 | -56.73% |
Peak GPU Memory (MiB, ↓) | 9509.49 | 6810.26 | -28.38% |
First Token Latency (ms, ↓) | 209.41 | 219.09 | +4.62% |
Sequence Latency (ms, ↓) | 322.33 | 262.15 | -18.67% |
Token Throughput (tokens/s, ↑) | 183.82 | 230.78 | +25.55% |
Analysis: INT4 gives large size (-57%) and memory (-28%) reductions. Perplexity increases (expected for aggressive compression) yet sequence latency drops and throughput rises ~25%.
Metric | FP32 | INT4 | Change |
---|---|---|---|
Perplexity (↓) | 13.85 | 21.02 | +51.79% |
Model Size (MB, ↓) | 468.3 | 284.0 | -39.37% |
Peak GPU Memory (MiB, ↓) | 1007.23 | 659.28 | -34.54% |
First Token Latency (ms/sample, ↓) | 95.79 | 97.87 | +2.18% |
Sequence Latency (ms/sample, ↓) | 60.35 | 54.64 | -9.46% |
Throughput (samples/s, ↑) | 973.41 | 1075.15 | +10.45% |
Analysis: INT4 gives large size (-39%) and memory (-35%) reductions. Perplexity increases (expected for aggressive compression) yet sequence latency drops and throughput rises ~10%.
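The latency and throughput figures above can be approximated with a simple timing harness. The helper below is an illustrative sketch, not the exact script used to produce these tables: it reuses the INT4 model and input shape from the quick start and does no explicit device synchronization, so treat its numbers as rough.
import time
import numpy as np
def benchmark(model, batch, warmup=5, iters=50):
    """Measure rough per-sample latency and throughput for a Keras model."""
    for _ in range(warmup):  # warm up kernels / tracing before timing
        model(batch, training=False)
    start = time.perf_counter()
    for _ in range(iters):
        model(batch, training=False)
    elapsed = time.perf_counter() - start
    n_samples = iters * len(batch)
    return {
        "latency_ms_per_sample": 1000.0 * elapsed / n_samples,
        "throughput_samples_per_s": n_samples / elapsed,
    }
# Example: time the INT4 model from the quick start above.
x_bench = np.random.default_rng().random((32, 10)).astype("float32")
print(benchmark(model, x_bench))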
Goal / Constraint | Prefer INT8 | Prefer INT4 (W4A8) |
---|---|---|
Minimal accuracy drop critical | ✔︎ | |
Maximum compression (disk / RAM) | | ✔︎ |
Bandwidth-bound inference | Possible | Often better |
Decoder LLM | ✔︎ | Try with eval first |
Encoder / classification models | ✔︎ | ✔︎ |
Available kernels / tooling maturity | ✔︎ | Emerging |
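Whichever column the table points you to, the workflow is the same one-line call. The snippet below is a minimal sketch that assumes the keras and layers imports from the quick start and a freshly built full-precision model, since quantization is applied in place.
# Build a fresh full-precision model; quantization modifies it in place.
fp_inputs = keras.Input(shape=(10,))
fp_outputs = layers.Dense(1)(layers.Dense(32, activation="relu")(fp_inputs))
fp_model = keras.Model(fp_inputs, fp_outputs)
fp_model.quantize("int8")    # prioritize accuracy
# fp_model.quantize("int4")  # prioritize compression and bandwidth (W4A8)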
Note: quantization needs materialized weights, so make sure the model has been built before calling quantize() (for example, via build() or a forward pass).
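For example, a Sequential model defined without an input shape has no weights until it is built. A minimal sketch (the model here is purely illustrative):
import keras
from keras import layers
# A Sequential model defined without an input shape has no weights yet.
seq = keras.Sequential([layers.Dense(16, activation="relu"), layers.Dense(1)])
seq.build(input_shape=(None, 10))  # or run a forward pass on dummy data
seq.quantize("int4")  # weights now exist, so they can be quantized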