Gemma4VisionEncoder class

keras_hub.models.Gemma4VisionEncoder(
    image_size,
    patch_size,
    num_heads,
    hidden_dim,
    num_layers,
    intermediate_dim,
    head_dim,
    output_dim,
    num_key_value_heads=None,
    pool_size=3,
    position_embedding_size=1024,
    rope_max_wavelength=100.0,
    layer_norm_epsilon=1e-06,
    dropout=0,
    use_clipped_linears=True,
    standardize=False,
    dtype=None,
    **kwargs
)
Vision Transformer (ViT) encoder for Gemma4.
This encoder is architecturally different from the Gemma3 vision encoder. Rather than using a separate CLIP-style ViT, Gemma4 reuses the transformer block style of the text decoder (four norms per block and Q/K/V normalization), but with bidirectional (non-causal) attention.
Position information is encoded via two separate learnable position-embedding tables, one for the x-axis and one for the y-axis, whose outputs are added to the patch features. Because the two axes are embedded independently, this 2D decomposed embedding can represent any image height and width.
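The decomposed embedding described above can be sketched with plain NumPy. The table names and sizes here are illustrative, not the layer's actual variable names: each patch's (x, y) grid coordinates index the two tables, and the two looked-up vectors are summed into the patch features.

```python
import numpy as np

# Illustrative sketch of a 2D decomposed position embedding
# (table names and sizes are hypothetical, not the real layer's).
rng = np.random.default_rng(0)
hidden_dim, table_size = 8, 1024
x_table = rng.normal(size=(table_size, hidden_dim)).astype("float32")
y_table = rng.normal(size=(table_size, hidden_dim)).astype("float32")

grid = 4  # patches per side
ys, xs = np.meshgrid(np.arange(grid), np.arange(grid), indexing="ij")
patch_features = np.zeros((grid * grid, hidden_dim), dtype="float32")

# Each patch receives x_table[x] + y_table[y] added to its features.
pos = x_table[xs.ravel()] + y_table[ys.ravel()]
features = patch_features + pos
# features.shape == (16, 8)
```

Because the x and y tables are indexed independently, a non-square patch grid only requires that each axis length fit within its table.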
After encoding, the patch sequence is spatially pooled down to a fixed number of tokens and then projected to output_dim, the text model's hidden dimension.
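The pool-then-project step above amounts to simple shape arithmetic. This is a minimal NumPy sketch (random-free, weights zeroed) of the shapes only, assuming non-overlapping average pooling over pool_size x pool_size windows of the patch grid:

```python
import numpy as np

# Shape sketch of pooling then projection (weights here are dummies).
hidden_dim, output_dim, pool_size, grid = 768, 2304, 3, 48  # 768 // 16 = 48
x = np.ones((grid * grid, hidden_dim), dtype="float32")

# Average-pool non-overlapping pool_size x pool_size windows of patches.
pooled_grid = grid // pool_size  # 48 // 3 = 16
x = x.reshape(pooled_grid, pool_size, pooled_grid, pool_size, hidden_dim)
x = x.mean(axis=(1, 3)).reshape(pooled_grid * pooled_grid, hidden_dim)

# Project the pooled tokens into the text hidden dimension.
w = np.zeros((hidden_dim, output_dim), dtype="float32")
out = x @ w
# out.shape == (256, 2304)
```

With the defaults in the example below (image_size=768, patch_size=16, pool_size=3), this yields the 256 pooled tokens of width 2304 seen in the output shape.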
Arguments

- image_size: int. Height and width of the input images, in pixels.
- patch_size: int. Side length of each square image patch. Each patch is flattened and projected to hidden_dim.
- num_heads: int. Number of attention heads in each transformer block.
- hidden_dim: int. Hidden dimension of the transformer.
- num_layers: int. Number of transformer blocks.
- intermediate_dim: int. Dimension of the feedforward layer in each block.
- head_dim: int. Dimension of each attention head.
- output_dim: int. Dimension of the final projected output, typically the text model's hidden dimension.
- num_key_value_heads: int. Number of key/value heads for grouped-query attention. Defaults to num_heads (MHA).
- pool_size: int. Window size of the spatial pooling applied after encoding, so each pooled token covers patch_size * pool_size pixels per side and the pooled sequence has (image_size // patch_size // pool_size) ** 2 tokens. Must evenly divide image_size // patch_size. Defaults to 3.
- position_embedding_size: int. Size of each per-axis position-embedding table. Must be at least image_size // patch_size. Defaults to 1024.
- rope_max_wavelength: float. Maximum wavelength of the rotary position embedding. Defaults to 100.0.
- layer_norm_epsilon: float. Epsilon used by the layer normalization layers. Defaults to 1e-6.
- dropout: float. Dropout rate. Defaults to 0.
- use_clipped_linears: bool. Whether to use clipped linear layers. Defaults to True.
- standardize: bool. Whether to standardize the input pixel values. Defaults to False.
- dtype: str or keras.mixed_precision.DTypePolicy. The dtype to use for the model's computations and weights. Defaults to None, which uses the Keras global dtype policy.

Example
import keras_hub
import numpy as np
vision_encoder = keras_hub.models.Gemma4VisionEncoder(
    image_size=768,
    patch_size=16,
    num_heads=12,
    hidden_dim=768,
    num_layers=12,
    intermediate_dim=3072,
    head_dim=64,
    output_dim=2304,
    pool_size=3,
)
# pixel_values: (batch, num_images, num_patches, patch_dim)
# For a 768x768 image with patch_size=16: num_patches=48*48=2304,
# patch_dim=16*16*3=768.
pixel_values = np.ones((1, 1, 2304, 768), dtype="float32")
# pixel_position_ids: (batch, num_images, num_patches, 2)
pixel_position_ids = np.zeros((1, 1, 2304, 2), dtype="int32")
output = vision_encoder(
    {"pixel_values": pixel_values, "pixel_position_ids": pixel_position_ids}
)
# output.shape == (1, 1, 256, 2304)
# (batch, num_images, pooled_patches, output_dim)
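The example above fills pixel_position_ids with zeros for brevity. One way to build real per-patch (x, y) coordinates for the 48x48 patch grid is a meshgrid; note that the channel ordering (x first, then y) is an assumption here, so check it against the actual preprocessor before relying on it:

```python
import numpy as np

# Build per-patch grid coordinates for a 48x48 patch grid.
# The [x, y] channel order is an assumption for illustration.
grid = 48
ys, xs = np.meshgrid(np.arange(grid), np.arange(grid), indexing="ij")
ids = np.stack([xs.ravel(), ys.ravel()], axis=-1).astype("int32")
pixel_position_ids = ids[None, None]  # (1, 1, 2304, 2)
```

In practice the model's preprocessor is expected to produce these IDs, so this sketch is mainly useful for testing the encoder in isolation.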