Gemma4Backbone model

[source]

Gemma4Backbone class

keras_hub.models.Gemma4Backbone(
    vocabulary_size,
    image_size,
    num_layers,
    num_query_heads,
    num_key_value_heads,
    hidden_dim,
    intermediate_dim,
    head_dim,
    query_head_dim_normalize=True,
    attention_logit_soft_cap=None,
    final_logit_soft_cap=None,
    use_sliding_window_attention=True,
    sliding_window_size=512,
    sliding_window_pattern=6,
    global_head_dim=None,
    local_rope_scaling_factor=1.0,
    global_rope_scaling_factor=1.0,
    vision_encoder=None,
    audio_encoder=None,
    num_audio_tokens_per_clip=None,
    layer_norm_epsilon=1e-06,
    use_bidirectional_attention=False,
    use_vision_bidirectional_attention=False,
    dropout=0,
    is_embedding_model=False,
    pooling_intermediate_dim=None,
    embedding_dim=None,
    num_kv_shared_layers=0,
    num_global_key_value_heads=None,
    hidden_size_per_layer_input=0,
    vocab_size_per_layer_input=None,
    global_rope_wavelength=None,
    local_rope_wavelength=None,
    global_rope_partial_rotary_factor=1.0,
    use_double_wide_mlp=False,
    enable_moe_block=False,
    num_experts=None,
    expert_intermediate_dim=None,
    num_experts_per_token=8,
    dtype=None,
    **kwargs
)

Gemma4 core network with hyperparameters.

This backbone implements the Gemma4 model architecture. Gemma4 is a multimodal model (image, audio, and text in; text out). The text input is encoded with a scaled embedding layer; images are encoded by a separate vision transformer (Gemma4VisionEncoder), and audio clips by an optional conformer encoder (Gemma4AudioEncoder). After encoding, image and audio embeddings are placed at the correct positions in the text-embedding sequence, and the combined sequence is processed by transformer decoder layers.
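The scaled-embedding step mentioned above can be sketched in plain NumPy. This is a minimal illustration only, assuming the Gemma-style multiplier of sqrt(hidden_dim); the table size here is a toy value, not the real vocabulary:

```python
import numpy as np

hidden_dim = 2304  # matches the text-only example below
toy_vocab = 1000   # toy size for illustration; the real vocab is much larger

# Toy embedding table and a short token sequence.
rng = np.random.default_rng(0)
table = rng.normal(size=(toy_vocab, hidden_dim)).astype("float32")
token_ids = np.array([[5, 17, 42]])

# Look up embeddings, then scale by sqrt(hidden_dim) as Gemma models do.
embeddings = table[token_ids] * np.sqrt(hidden_dim)
print(embeddings.shape)  # (1, 3, 2304)
```

The scaling keeps the embedding magnitudes comparable to the residual-stream activations, which matters because the embedding table is shared with the output projection.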

Compared to Gemma3, Gemma4 introduces:

  • Four norms per decoder block — pre + post for both attention and FFW, always enabled (no use_post_*_norm flags).
  • Q / K / V normalisation in attention always on.
  • Attention scaling = 1.0 — Q/K normalisation provides stability instead of the classic 1/sqrt(head_dim) scaling.
  • New vision encoder — uses the same Gemma4 decoder blocks with bidirectional attention, 2D learnable position embeddings, and spatial average-pooling.
  • Smaller default sliding window — 512 tokens (vs. 1024 in Gemma3).
  • Audio encoder — an optional Universal Speech Model (USM) conformer that encodes mel spectrograms into audio token embeddings.

For a higher-level object for text generation, see keras_hub.models.Gemma4CausalLM.

The default constructor gives a fully customised, randomly initialised Gemma4 model. To load preset weights, use the from_preset constructor.

Arguments

  • vocabulary_size: int. The size of the token vocabulary.
  • image_size: int. The spatial resolution of images fed to the vision encoder (height = width). Must be divisible by patch_size * pool_size when a vision_encoder is provided.
  • num_layers: int. Number of transformer decoder layers.
  • num_query_heads: int. Number of query heads per attention layer.
  • num_key_value_heads: int. Number of key/value heads (GQA).
  • hidden_dim: int. Hidden state dimension at the end of each layer.
  • intermediate_dim: int. First dense layer output dimension in each FFW sub-block.
  • head_dim: int. Per-head dimension in the decoder attention.
  • query_head_dim_normalize: bool. If True normalise query pre-attention using head_dim; otherwise use hidden_dim / num_query_heads. Unused in Gemma4 (always Q-normalised via q_norm). Kept for API compatibility. Defaults to True.
  • attention_logit_soft_cap: None or float. Tanh soft-cap on attention logits. Defaults to None.
  • final_logit_soft_cap: None or float. Tanh soft-cap on output logits. Defaults to None.
  • use_sliding_window_attention: bool. Whether to use sliding-window attention on the local layers. Defaults to True.
  • sliding_window_size: int. Size of the local attention window. Defaults to 512.
  • sliding_window_pattern: int. Repeat period of the local/global attention pattern. The last layer in each group of this many consecutive layers uses global attention; all others use local (sliding-window) attention. Defaults to 6.
  • global_head_dim: int or None. Per-head dimension used specifically for global attention layers. When None, head_dim is used for all layers. Defaults to None.
  • local_rope_scaling_factor: float. RoPE scaling factor for local layers. Defaults to 1.0.
  • global_rope_scaling_factor: float. RoPE scaling factor for global layers. Defaults to 1.0.
  • global_rope_partial_rotary_factor: float. Fraction of each head dimension that receives rotary position embeddings in global attention layers. Only the first int(factor * head_dim) dimensions are rotated; the remainder are left unchanged (NoPE). Local layers always use full RoPE (factor = 1.0). Defaults to 1.0.
  • vision_encoder: keras_hub.models.Gemma4VisionEncoder or None. When None the model processes no images.
  • audio_encoder: keras_hub.models.Gemma4AudioEncoder or None. When None the model processes no audio.
  • num_audio_tokens_per_clip: int or None. Number of audio soft tokens produced per audio clip (including zero-padded positions). Must be provided when audio_encoder is not None.
  • layer_norm_epsilon: float. Epsilon for all RMS norms. Defaults to 1e-6.
  • use_bidirectional_attention: bool. When True the model uses fully bidirectional attention for ALL tokens, e.g. for embedding models. This is distinct from use_vision_bidirectional_attention which only affects vision token attention. Defaults to False.
  • use_vision_bidirectional_attention: bool. When True, vision tokens within the same image attend to each other bidirectionally while text tokens remain causal. Corresponds to HF use_bidirectional_attention: "vision" (present in 26B and 31B models; null for the 2B and 4B models). Defaults to False.
  • dropout: float. Dropout probability. Defaults to 0.
  • is_embedding_model: bool. When True add mean-pooling and dense projection heads for embedding models. Defaults to False.
  • pooling_intermediate_dim: int or None. Intermediate dimension of the first projection in the pooling head. Required when is_embedding_model=True.
  • embedding_dim: int or None. Final embedding dimension. Required when is_embedding_model=True.
  • num_kv_shared_layers: int. Number of trailing decoder layers that share K/V projections with the most recent non-shared layer of the same attention type. Defaults to 0.
  • num_global_key_value_heads: int or None. When set, global attention layers use this many K/V heads instead of num_key_value_heads and enable the K=V projection. Defaults to None.
  • hidden_size_per_layer_input: int. Size of the per-token, per-layer conditioning vector that gates each decoder layer's output. Set to 0 to disable. Defaults to 0.
  • vocab_size_per_layer_input: int or None. Vocabulary size for the per-layer token embedding table. When None falls back to vocabulary_size. Defaults to None.
  • use_double_wide_mlp: bool. When True, KV-shared layers (is_kv_shared_layer=True) use 2 × intermediate_dim for their FFW sub-block. Defaults to False.
  • enable_moe_block: bool. When True, every decoder layer runs a parallel Mixture-of-Experts path alongside the dense FFW path. The two outputs are summed before the shared post-FFW norm. Requires num_experts and expert_intermediate_dim to be set. Defaults to False.
  • num_experts: int or None. Total number of expert MLPs in the MoE bank. Required when enable_moe_block=True. Defaults to None.
  • expert_intermediate_dim: int or None. Intermediate dimension of each expert MLP. Required when enable_moe_block=True. Defaults to None.
  • num_experts_per_token: int. Top-k experts selected per token by the MoE router. Defaults to 8.
  • dtype: string or keras.mixed_precision.DTypePolicy. Compute dtype. Defaults to None.
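To make the sliding_window_pattern and global_rope_partial_rotary_factor arguments concrete, here is a small plain-Python sketch of how the per-layer attention types and rotated dimensions follow from the rules described above (an illustration, not the library's internal code; the factor value is illustrative):

```python
num_layers = 26
sliding_window_pattern = 6
head_dim = 256
global_rope_partial_rotary_factor = 0.5  # illustrative value; default is 1.0

# The last layer in each group of `sliding_window_pattern` consecutive
# layers uses global attention; all others use local (sliding-window).
layer_types = [
    "global" if (i + 1) % sliding_window_pattern == 0 else "local"
    for i in range(num_layers)
]
print(layer_types[:6])  # ['local', 'local', 'local', 'local', 'local', 'global']

# Global layers rotate only the first int(factor * head_dim) dimensions
# (the rest get no positional rotation, i.e. NoPE); local layers always
# use full RoPE.
rotary_dims_global = int(global_rope_partial_rotary_factor * head_dim)
print(rotary_dims_global)  # 128
```

With 26 layers and a pattern of 6, layers 6, 12, 18, and 24 (1-indexed) are global and the trailing partial group stays local.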

Example

import keras_hub
import numpy as np

# Text-only input.
model = keras_hub.models.Gemma4Backbone(
    vocabulary_size=262144,
    image_size=768,
    num_layers=26,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=2304,
    intermediate_dim=9216,
    head_dim=256,
    sliding_window_size=512,
    vision_encoder=None,
    dtype="bfloat16",
)
inputs = {
    "token_ids": np.ones((1, 128), dtype="int32"),
    "padding_mask": np.ones((1, 128), dtype="int32"),
}
model(inputs)
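The image_size divisibility rule above determines a fixed number of vision tokens per image. A rough sketch, assuming patch embedding followed by spatial average pooling (the patch_size and pool_size values here are illustrative, not preset defaults):

```python
image_size = 768  # matches the constructor example above
patch_size = 16   # illustrative
pool_size = 4     # illustrative

# image_size must be divisible by patch_size * pool_size.
assert image_size % (patch_size * pool_size) == 0

# Tokens per side after patching and pooling, squared for the 2D grid.
tokens_per_side = image_size // (patch_size * pool_size)
vision_tokens_per_image = tokens_per_side ** 2
print(vision_tokens_per_image)  # 144
```

Each image therefore contributes a fixed-length block of embeddings that is spliced into the text sequence at its placeholder positions.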

[source]

from_preset method

Gemma4Backbone.from_preset(preset, load_weights=True, **kwargs)

Instantiate a keras_hub.models.Backbone from a model preset.

A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:

  1. a built-in preset identifier like 'bert_base_en'
  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
  3. a Hugging Face handle like 'hf://user/bert_base_en'
  4. a ModelScope handle like 'modelscope://user/bert_base_en'
  5. a path to a local preset directory like './bert_base_en'

This constructor can be called in one of two ways: either from the base class, like keras_hub.models.Backbone.from_preset(), or from a model class, like keras_hub.models.GemmaBackbone.from_preset(). If calling from the base class, the subclass of the returned object will be inferred from the config in the preset directory.

For any Backbone subclass, you can run cls.presets.keys() to list all built-in presets available on the class.

Arguments

  • preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
  • load_weights: bool. If True, the weights will be loaded into the model architecture. If False, the weights will be randomly initialized.

Examples

# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
    "gemma_2b_en",
)

# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
    "bert_base_en",
    load_weights=False,
)
Presets

  • gemma4_2b (5.10B): Gemma 4 E2B base model: 2.3B effective parameters (5.1B total with Per-Layer Embeddings), 35-layer, audio+vision+text pretrained Gemma4 model. The 'E' denotes effective parameters — PLE gives each decoder layer its own token embedding table, maximizing parameter efficiency for on-device deployment.
  • gemma4_instruct_2b (5.10B): Gemma 4 E2B instruction-tuned model: 2.3B effective parameters (5.1B total with Per-Layer Embeddings), 35-layer, audio+vision+text instruction-tuned Gemma4 model.
  • gemma4_4b (7.90B): Gemma 4 E4B base model: 4.5B effective parameters (7.9B total with Per-Layer Embeddings), 42-layer, audio+vision+text pretrained Gemma4 model.
  • gemma4_instruct_4b (7.90B): Gemma 4 E4B instruction-tuned model: 4.5B effective parameters (7.9B total with Per-Layer Embeddings), 42-layer, audio+vision+text instruction-tuned Gemma4 model.
  • gemma4_26b_a4b (26.00B): Gemma 4 26B A4B base model: Mixture-of-Experts (MoE) model with 26B total parameters and only 4B active parameters per forward pass, 30-layer, vision+text pretrained Gemma4 model. The 'A' denotes active parameters — by activating only a 4B subset during inference, this MoE model runs nearly as fast as a dense 4B model.
  • gemma4_instruct_26b_a4b (26.00B): Gemma 4 26B A4B instruction-tuned model: Mixture-of-Experts (MoE) model with 26B total parameters and only 4B active parameters per forward pass, 30-layer, vision+text instruction-tuned Gemma4 model.
  • gemma4_31b (31.00B): Gemma 4 31B base model: 31B parameter, 60-layer, dense vision+text pretrained Gemma4 model. The largest dense model in the Gemma 4 family, offering maximum quality for deployments where inference speed is less of a constraint.
  • gemma4_instruct_31b (31.00B): Gemma 4 31B instruction-tuned model: 31B parameter, 60-layer, dense vision+text instruction-tuned Gemma4 model.
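The A4B presets activate only a top-k subset of experts per token, as governed by num_experts_per_token. A minimal NumPy sketch of top-k routing with softmax gating over the selected experts, illustrating the general MoE mechanism (not necessarily the exact Gemma4 router; the expert count here is illustrative):

```python
import numpy as np

num_experts = 32           # illustrative value
num_experts_per_token = 8  # the constructor default

rng = np.random.default_rng(0)
router_logits = rng.normal(size=(num_experts,))  # one token's router scores

# Select the k highest-scoring experts for this token.
top_k = np.argsort(router_logits)[-num_experts_per_token:]

# Softmax gate over the selected experts only; gates sum to 1.
gates = np.exp(router_logits[top_k] - router_logits[top_k].max())
gates /= gates.sum()
```

Only the selected experts' MLPs run for that token, and their outputs are combined weighted by the gates, which is why a 26B-total model can have a roughly 4B-dense inference cost.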

token_embedding property

keras_hub.models.Gemma4Backbone.token_embedding

A keras.layers.Embedding instance for embedding token ids.

This layer embeds integer token ids to the hidden dim of the model.


[source]

enable_lora method

Gemma4Backbone.enable_lora(rank, target_layer_names=None)

Enable LoRA on the backbone.

Calling this method will freeze all weights on the backbone, while enabling LoRA on the query and value EinsumDense layers of the attention layers.

Arguments

  • rank: The rank of the LoRA factorization.
  • target_layer_names: A list of strings, the names of the layers to apply LoRA to. If None, this will be populated with the default LoRA layer names as returned by backbone.default_lora_layer_names().
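The parameter savings from enable_lora can be estimated with a quick back-of-the-envelope calculation: for an adapted projection with input dim d_in and output dim d_out, LoRA trains rank * (d_in + d_out) parameters in place of the frozen d_in * d_out. The dimensions below are illustrative (a hypothetical query projection of hidden_dim -> num_query_heads * head_dim), not a specific preset's:

```python
# Hypothetical query-projection shape: 2304 -> 8 heads * 256 head_dim.
d_in, d_out = 2304, 8 * 256
rank = 4

full_params = d_in * d_out          # frozen weights in the base projection
lora_params = rank * (d_in + d_out)  # trainable low-rank adapter weights

print(full_params, lora_params)  # 4718592 17408
```

At rank 4 the adapter trains well under 1% of the projection's parameters, which is the point of LoRA fine-tuning.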