Gemma4VisionEncoder model

[source]

Gemma4VisionEncoder class

keras_hub.models.Gemma4VisionEncoder(
    image_size,
    patch_size,
    num_heads,
    hidden_dim,
    num_layers,
    intermediate_dim,
    head_dim,
    output_dim,
    num_key_value_heads=None,
    pool_size=3,
    position_embedding_size=1024,
    rope_max_wavelength=100.0,
    layer_norm_epsilon=1e-06,
    dropout=0,
    use_clipped_linears=True,
    standardize=False,
    dtype=None,
    **kwargs
)

Vision Transformer (ViT) encoder for Gemma4.

This encoder is architecturally different from the Gemma3 vision encoder. Rather than using a separate CLIP-style ViT, Gemma4 reuses the transformer block style of the text decoder (four normalization layers per block and Q/K/V normalization), but with bidirectional (non-causal) attention.

Position information is encoded via two separate learnable position-embedding tables — one for the x-axis and one for the y-axis — whose outputs are added to the patch features. This 2D decomposed embedding can represent any image height and width independently.
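The 2D decomposed lookup described above can be sketched in NumPy. This is an illustration of the idea, not the real layer: the table names, sizes, and raster ordering here are assumptions.

```python
import numpy as np

# Hypothetical sketch of 2D decomposed position embeddings: two
# learnable tables, one per axis, indexed by each patch's (x, y)
# grid coordinate and summed into the patch features.
grid = 48          # patches per side, i.e. image_size // patch_size
hidden_dim = 768
table_size = 1024  # position_embedding_size

rng = np.random.default_rng(0)
x_table = rng.normal(size=(table_size, hidden_dim)).astype("float32")
y_table = rng.normal(size=(table_size, hidden_dim)).astype("float32")

# One (x, y) coordinate per patch, in raster order.
ys, xs = np.divmod(np.arange(grid * grid), grid)
pos_embedding = x_table[xs] + y_table[ys]  # (grid * grid, hidden_dim)

patch_features = np.zeros((grid * grid, hidden_dim), dtype="float32")
patch_features = patch_features + pos_embedding
```

Because each axis is looked up independently, the same two tables cover any grid whose side length is at most `position_embedding_size`.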

After encoding, the patch sequence is spatially pooled by pool_size along each axis and then projected to output_dim, the text backbone's hidden dimension.
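The pooling-then-projection step can be followed as a shape walk-through. This is a minimal NumPy sketch, assuming mean pooling over the patch grid and a plain linear projection; the real layer may differ.

```python
import numpy as np

# Shape walk-through of the post-encoder step (illustrative only):
# pool the 48x48 patch grid by 3 along each axis, then project.
grid, hidden_dim, pool_size, output_dim = 48, 768, 3, 2304

x = np.ones((grid * grid, hidden_dim), dtype="float32")
# Reshape the flat patch sequence back into a 2D grid of pooling windows.
x = x.reshape(grid // pool_size, pool_size, grid // pool_size, pool_size, hidden_dim)
x = x.mean(axis=(1, 3)).reshape(-1, hidden_dim)  # (256, hidden_dim)

# Hypothetical projection weights, chosen so the output stays at 1.0.
projection = np.ones((hidden_dim, output_dim), dtype="float32") / hidden_dim
out = x @ projection  # (256, output_dim)
```

The pooled sequence length, (48 // 3) ** 2 = 256, matches the output shape shown in the example below.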

Arguments

  • image_size: int. The height/width of the (square) image. Must be divisible by patch_size * pool_size.
  • patch_size: int. Size of each square patch in pixels.
  • num_heads: int. Number of attention heads in each vision transformer layer.
  • hidden_dim: int. Hidden dimension of the vision transformer blocks.
  • num_layers: int. Number of transformer layers in the vision encoder.
  • intermediate_dim: int. Intermediate FFW dimension in each block.
  • head_dim: int. Dimension of each attention head.
  • output_dim: int. Dimension to project encoded patches to (should equal the text backbone's hidden_dim).
  • num_key_value_heads: int. Number of key/value heads for grouped-query attention. Defaults to num_heads (standard multi-head attention).
  • pool_size: int. Spatial pooling factor applied after the transformer. The output sequence length equals (image_size // patch_size // pool_size) ** 2. Must evenly divide image_size // patch_size. Defaults to 3.
  • position_embedding_size: int. Number of learnable entries in each (x or y) position embedding table. Should be at least image_size // patch_size. Defaults to 1024.
  • layer_norm_epsilon: float. Epsilon for layer normalization. Defaults to 1e-6.
  • dropout: float. Dropout probability. Defaults to 0.
  • dtype: Compute dtype. Defaults to None (uses Keras global policy).
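The divisibility constraints above can be checked before building the encoder. The helper below is hypothetical (not part of the API); it just encodes the rules stated in the argument descriptions.

```python
def check_vision_config(image_size, patch_size, pool_size, position_embedding_size):
    """Hypothetical up-front validation of the constraints listed above."""
    if image_size % (patch_size * pool_size) != 0:
        raise ValueError(
            f"image_size={image_size} must be divisible by "
            f"patch_size * pool_size = {patch_size * pool_size}."
        )
    grid = image_size // patch_size
    if position_embedding_size < grid:
        raise ValueError(
            f"position_embedding_size={position_embedding_size} must be at "
            f"least image_size // patch_size = {grid}."
        )
    # Derived shapes: patches per side, and pooled sequence length.
    return grid, (grid // pool_size) ** 2

# For the configuration used in the example below:
print(check_vision_config(768, 16, 3, 1024))  # (48, 256)
```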

Example

import keras_hub
import numpy as np

vision_encoder = keras_hub.models.Gemma4VisionEncoder(
    image_size=768,
    patch_size=16,
    num_heads=12,
    hidden_dim=768,
    num_layers=12,
    intermediate_dim=3072,
    head_dim=64,
    output_dim=2304,
    pool_size=3,
)
# pixel_values: (batch, num_images, num_patches, patch_dim)
# For a 768x768 image with patch_size=16: num_patches=48*48=2304,
# patch_dim=16*16*3=768.
pixel_values = np.ones((1, 1, 2304, 768), dtype="float32")
# pixel_position_ids: (batch, num_images, num_patches, 2)
pixel_position_ids = np.zeros((1, 1, 2304, 2), dtype="int32")
output = vision_encoder(
    {"pixel_values": pixel_values, "pixel_position_ids": pixel_position_ids}
)
# output.shape == (1, 1, 256, 2304)
# (batch, num_images, pooled_patches, output_dim)
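The example above feeds all-ones placeholders. A real image must first be split into the (num_patches, patch_dim) layout; the sketch below shows one plausible patchification in NumPy. The raster ordering and (x, y) convention here are assumptions — the actual preprocessor may differ.

```python
import numpy as np

# Hypothetical patchification of a raw (768, 768, 3) image into the
# (num_patches, patch_dim) layout used by the example above.
image = np.ones((768, 768, 3), dtype="float32")
patch_size = 16
grid = 768 // patch_size  # 48 patches per side

patches = image.reshape(grid, patch_size, grid, patch_size, 3)
patches = patches.transpose(0, 2, 1, 3, 4).reshape(grid * grid, -1)
pixel_values = patches[None, None]  # (1, 1, 2304, 768)

# One (x, y) grid coordinate per patch, in the same raster order.
ys, xs = np.divmod(np.arange(grid * grid), grid)
pixel_position_ids = np.stack([xs, ys], axis=-1)[None, None].astype("int32")
# pixel_position_ids.shape == (1, 1, 2304, 2)
```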