Gemma4Backbone class

keras_hub.models.Gemma4Backbone(
vocabulary_size,
image_size,
num_layers,
num_query_heads,
num_key_value_heads,
hidden_dim,
intermediate_dim,
head_dim,
query_head_dim_normalize=True,
attention_logit_soft_cap=None,
final_logit_soft_cap=None,
use_sliding_window_attention=True,
sliding_window_size=512,
sliding_window_pattern=6,
global_head_dim=None,
local_rope_scaling_factor=1.0,
global_rope_scaling_factor=1.0,
vision_encoder=None,
audio_encoder=None,
num_audio_tokens_per_clip=None,
layer_norm_epsilon=1e-06,
use_bidirectional_attention=False,
use_vision_bidirectional_attention=False,
dropout=0,
is_embedding_model=False,
pooling_intermediate_dim=None,
embedding_dim=None,
num_kv_shared_layers=0,
num_global_key_value_heads=None,
hidden_size_per_layer_input=0,
vocab_size_per_layer_input=None,
global_rope_wavelength=None,
local_rope_wavelength=None,
global_rope_partial_rotary_factor=1.0,
use_double_wide_mlp=False,
enable_moe_block=False,
num_experts=None,
expert_intermediate_dim=None,
num_experts_per_token=8,
dtype=None,
**kwargs
)
Gemma4 core network with hyperparameters.
This backbone implements the Gemma4 model architecture. Gemma4 is a
multimodal model (image, audio, and text in; text out). The text
input is encoded with a scaled embedding layer; images are encoded by a
separate vision transformer (Gemma4VisionEncoder), and audio clips by a
separate audio encoder (Gemma4AudioEncoder). After encoding, image and
audio embeddings are placed at the correct positions in the text-embedding
sequence, and the combined sequence is processed by transformer decoder
layers.
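The interleaving step above can be sketched in plain Python. The placeholder id and helper function below are illustrative only, not the actual KerasHub implementation:

```python
# Illustrative sketch: splice image embeddings into a text embedding
# sequence at positions marked by a placeholder token id. The
# placeholder id (-1) and the function are hypothetical, not KerasHub API.

def interleave(text_embeddings, token_ids, image_embeddings, placeholder=-1):
    """Replace each placeholder position with the next image embedding."""
    image_iter = iter(image_embeddings)
    combined = []
    for token_id, embedding in zip(token_ids, text_embeddings):
        if token_id == placeholder:
            combined.append(next(image_iter))  # image soft token
        else:
            combined.append(embedding)  # ordinary text embedding
    return combined

# Toy 1-d "embeddings" for readability.
token_ids = [5, -1, -1, 9]
text_emb = [0.1, 0.0, 0.0, 0.9]
image_emb = [7.0, 8.0]
print(interleave(text_emb, token_ids, image_emb))  # [0.1, 7.0, 8.0, 0.9]
```

In the real model the entries are hidden_dim-sized vectors rather than scalars, but the positional splicing works the same way.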
Compared to Gemma3, Gemma4 introduces several architectural changes, including
always-applied post-attention and post-FFW norms (replacing the Gemma3
use_post_*_norm flags), fixed 1/sqrt(head_dim) attention scaling, optional
audio input, per-layer token embeddings, KV sharing across layers, and an
optional Mixture-of-Experts block.

For a higher-level object for text generation, see
keras_hub.models.Gemma4CausalLM.

The default constructor gives a fully customized, randomly initialized
Gemma4 model. To load preset weights, use the from_preset constructor.
Arguments

- vocabulary_size: int. The size of the token vocabulary.
- image_size: int. The resolution of the input images. Must be divisible by
  the vision encoder's patch_size * pool_size when a vision_encoder is
  provided.
- num_layers: int. The number of transformer decoder layers.
- num_query_heads: int. The number of heads for the query projections in
  the attention layers.
- num_key_value_heads: int. The number of heads for the key and value
  projections in the attention layers.
- hidden_dim: int. The size of the transformer hidden state at the end of
  each decoder layer.
- intermediate_dim: int. The output dimension of the first Dense layer in
  the two-layer feedforward network of each decoder layer.
- head_dim: int. The size of each attention head.
- query_head_dim_normalize: bool. If True, normalize the query
  pre-attention using head_dim; otherwise use hidden_dim / num_query_heads.
  Unused in Gemma4 (queries are always normalized via q_norm); kept for
  API compatibility. Defaults to True.
- attention_logit_soft_cap: None or float. Tanh soft-cap on attention
  logits. Defaults to None.
- final_logit_soft_cap: None or float. Tanh soft-cap on output logits.
  Defaults to None.
- use_sliding_window_attention: bool. Whether to use sliding local window
  attention. Defaults to True.
- sliding_window_size: int. Size of the sliding local window. Defaults
  to 512.
- sliding_window_pattern: int. The pattern of local (sliding-window) and
  global attention layers. Defaults to 6.
- global_head_dim: None or int. Per-head dimension used specifically
  for global attention layers. When None, head_dim is used
  for all layers. Defaults to None.
- local_rope_scaling_factor: float. RoPE scaling factor for local
  (sliding-window) attention layers. Defaults to 1.0.
- global_rope_scaling_factor: float. RoPE scaling factor for global
  attention layers. Defaults to 1.0.
- vision_encoder: keras_hub.models.Gemma4VisionEncoder or None. When
  None, the model processes no images. Defaults to None.
- audio_encoder: keras_hub.models.Gemma4AudioEncoder or None. When
  None, the model processes no audio. Defaults to None.
- num_audio_tokens_per_clip: None or int. Number of audio soft tokens
  produced per audio clip (including zero-padded positions). Must be
  provided when audio_encoder is not None. Defaults to None.
- layer_norm_epsilon: float. Epsilon used by the layer normalization
  layers. Defaults to 1e-6.
- use_bidirectional_attention: bool. If True, the model uses fully
  bidirectional attention for ALL tokens, e.g. for embedding
  models. This is distinct from use_vision_bidirectional_attention,
  which only affects vision token attention. Defaults to False.
- use_vision_bidirectional_attention: bool. If True, vision tokens
  within the same image attend to each other bidirectionally while
  text tokens remain causal. Corresponds to HF
  use_bidirectional_attention: "vision" (present in the 26B and 31B
  models; null for the 2B and 4B models). Defaults to False.
- dropout: float. Dropout probability for the transformer decoder blocks.
  Defaults to 0.
- is_embedding_model: bool. If True, add mean-pooling and dense
  projection heads for embedding models. Defaults to False.
- pooling_intermediate_dim: None or int. Intermediate dimension of the
  first projection in the pooling head. Required when
  is_embedding_model=True. Defaults to None.
- embedding_dim: None or int. Final embedding dimension. Required when
  is_embedding_model=True. Defaults to None.
- num_kv_shared_layers: int. Number of decoder layers that share key/value
  projections with earlier layers. Defaults to 0.
- num_global_key_value_heads: None or int. When set, global attention
  layers use this many K/V heads instead of num_key_value_heads
  and enable the K=V projection. Defaults to None.
- hidden_size_per_layer_input: int. Hidden size of the per-layer token
  embedding input; 0 to disable. Defaults to 0.
- vocab_size_per_layer_input: None or int. Vocabulary size for the
  per-layer token embedding table. When None, falls back to
  vocabulary_size. Defaults to None.
- global_rope_wavelength: None or float. RoPE wavelength (base frequency)
  for global attention layers. Defaults to None.
- local_rope_wavelength: None or float. RoPE wavelength (base frequency)
  for local (sliding-window) attention layers. Defaults to None.
- global_rope_partial_rotary_factor: float. Fraction of each
  global-attention head dimension that receives RoPE: the first
  int(factor * head_dim) dimensions are rotated; the remainder are
  left unchanged (NoPE). Local layers always use full RoPE
  (factor = 1.0). Defaults to 1.0.
- use_double_wide_mlp: bool. If True, KV-shared layers
  (is_kv_shared_layer=True) use 2 * intermediate_dim for their
  FFW sub-block. Defaults to False.
- enable_moe_block: bool. If True, every decoder layer runs a
  parallel Mixture-of-Experts path alongside the dense FFW path.
  The two outputs are summed before the shared post-FFW norm.
  Requires num_experts and expert_intermediate_dim to be set.
  Defaults to False.
- num_experts: None or int. Total number of expert MLPs in the MoE
  bank. Required when enable_moe_block=True. Defaults to None.
- expert_intermediate_dim: None or int. Intermediate dimension of each
  expert MLP. Required when enable_moe_block=True.
  Defaults to None.
- num_experts_per_token: int. Number of experts routed to each token when
  the MoE block is enabled. Defaults to 8.
- dtype: None, string, or keras.mixed_precision.DTypePolicy. Compute
  dtype for the model. Defaults to None.

Example
import numpy as np
import keras_hub
# Text-only input.
model = keras_hub.models.Gemma4Backbone(
vocabulary_size=262144,
image_size=768,
num_layers=26,
num_query_heads=8,
num_key_value_heads=4,
hidden_dim=2304,
intermediate_dim=9216,
head_dim=256,
sliding_window_size=512,
vision_encoder=None,
dtype="bfloat16",
)
inputs = {
"token_ids": np.ones((1, 128), dtype="int32"),
"padding_mask": np.ones((1, 128), dtype="int32"),
}
model(inputs)
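The interaction between sliding_window_size and sliding_window_pattern can be illustrated with a short sketch. The layout below assumes every sliding_window_pattern-th layer uses global attention and the rest use sliding-window attention; this reading is an assumption for illustration, not the KerasHub implementation:

```python
# Illustrative sketch (assumption: every `sliding_window_pattern`-th
# layer is a global-attention layer, the rest are local).

def layer_types(num_layers, sliding_window_pattern):
    return [
        "global" if (i + 1) % sliding_window_pattern == 0 else "local"
        for i in range(num_layers)
    ]

types = layer_types(num_layers=12, sliding_window_pattern=6)
print(types)
# Under this assumption, layers 6 and 12 (1-indexed) attend globally,
# while the other layers attend only within a sliding_window_size-token
# window (512 tokens by default).
```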
from_preset method

Gemma4Backbone.from_preset(preset, load_weights=True, **kwargs)
Instantiate a keras_hub.models.Backbone from a model preset.
A preset is a directory of configs, weights and other file assets used
to save and load a pre-trained model. The preset can be passed as
one of:

- a built-in preset identifier like 'bert_base_en'
- a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
- a Hugging Face handle like 'hf://user/bert_base_en'
- a ModelScope handle like 'modelscope://user/bert_base_en'
- a path to a local preset directory like './bert_base_en'

This constructor can be called in one of two ways: either from the base
class, like keras_hub.models.Backbone.from_preset(), or from
a model class, like keras_hub.models.GemmaBackbone.from_preset().
If calling from the base class, the subclass of the returning object
will be inferred from the config in the preset directory.
For any Backbone subclass, you can run cls.presets.keys() to list
all built-in presets available on the class.
Arguments

- preset: string. A built-in preset identifier, a Kaggle Models handle, a
  Hugging Face handle, or a path to a local preset directory.
- load_weights: bool. If True, the weights will be loaded into the
  model architecture. If False, the weights will be randomly
  initialized.

Examples
# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
"gemma_2b_en",
)
# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
"bert_base_en",
load_weights=False,
)
| Preset | Parameters | Description |
|---|---|---|
| gemma4_2b | 5.10B | Gemma 4 E2B base model: 2.3B effective parameters (5.1B total with Per-Layer Embeddings), 35-layer, audio+vision+text pretrained Gemma4 model. The 'E' denotes effective parameters — PLE gives each decoder layer its own token embedding table, maximizing parameter efficiency for on-device deployment. |
| gemma4_instruct_2b | 5.10B | Gemma 4 E2B instruction-tuned model: 2.3B effective parameters (5.1B total with Per-Layer Embeddings), 35-layer, audio+vision+text instruction-tuned Gemma4 model. The 'E' denotes effective parameters — PLE gives each decoder layer its own token embedding table, maximizing parameter efficiency for on-device deployment. |
| gemma4_4b | 7.90B | Gemma 4 E4B base model: 4.5B effective parameters (7.9B total with Per-Layer Embeddings), 42-layer, audio+vision+text pretrained Gemma4 model. The 'E' denotes effective parameters — PLE gives each decoder layer its own token embedding table, maximizing parameter efficiency for on-device deployment. |
| gemma4_instruct_4b | 7.90B | Gemma 4 E4B instruction-tuned model: 4.5B effective parameters (7.9B total with Per-Layer Embeddings), 42-layer, audio+vision+text instruction-tuned Gemma4 model. The 'E' denotes effective parameters — PLE gives each decoder layer its own token embedding table, maximizing parameter efficiency for on-device deployment. |
| gemma4_26b_a4b | 26.00B | Gemma 4 26B A4B base model: Mixture-of-Experts (MoE) model with 26B total parameters and only 4B active parameters per forward pass, 30-layer, vision+text pretrained Gemma4 model. The 'A' denotes active parameters — by activating only a 4B subset during inference, this MoE model runs nearly as fast as a dense 4B model. |
| gemma4_instruct_26b_a4b | 26.00B | Gemma 4 26B A4B instruction-tuned model: Mixture-of-Experts (MoE) model with 26B total parameters and only 4B active parameters per forward pass, 30-layer, vision+text instruction-tuned Gemma4 model. The 'A' denotes active parameters — by activating only a 4B subset during inference, this MoE model runs nearly as fast as a dense 4B model. |
| gemma4_31b | 31.00B | Gemma 4 31B base model: 31B parameter, 60-layer, dense vision+text pretrained Gemma4 model. The dense model in the Gemma 4 family, offering maximum quality for deployments where inference speed is less of a constraint. |
| gemma4_instruct_31b | 31.00B | Gemma 4 31B instruction-tuned model: 31B parameter, 60-layer, dense vision+text instruction-tuned Gemma4 model. The dense model in the Gemma 4 family, offering maximum quality for deployments where inference speed is less of a constraint. |
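The "effective vs total" distinction in the table above can be made concrete with back-of-the-envelope arithmetic: the Per-Layer Embedding tables contribute num_layers * vocab_size_per_layer_input * hidden_size_per_layer_input parameters, and the "effective" count excludes them. The hidden_size_per_layer_input value below is an illustrative placeholder, not the real Gemma4 configuration:

```python
# Back-of-the-envelope sketch of Per-Layer Embedding (PLE) parameter
# counts. hidden_size_per_layer_input=256 is a made-up value for
# illustration only.

num_layers = 35
vocab_size_per_layer_input = 262_144
hidden_size_per_layer_input = 256

# Each decoder layer owns its own (vocab x hidden) embedding table.
ple_params = num_layers * vocab_size_per_layer_input * hidden_size_per_layer_input
print(f"PLE parameters: {ple_params / 1e9:.2f}B")  # PLE parameters: 2.35B
```

The total parameter count includes these tables, while the effective count that gives the "E2B"/"E4B" names does not, which is why the two figures in the table differ.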
token_embedding property

keras_hub.models.Gemma4Backbone.token_embedding
A keras.layers.Embedding instance for embedding token ids.
This layer embeds integer token ids to the hidden dim of the model.
enable_lora method

Gemma4Backbone.enable_lora(rank, target_layer_names=None)

Enable LoRA on the backbone.

Calling this method will freeze all weights on the backbone,
while enabling LoRA on the query & value EinsumDense layers
of the attention layers.

Arguments

- rank: int. The rank of the LoRA factorization.
- target_layer_names: list of strings or None. If None, this will be
  populated with the default LoRA layer names as returned by
  backbone.default_lora_layer_names().
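The point of the rank argument is parameter efficiency: each adapted EinsumDense layer gains two small trainable factors while the original weights stay frozen. A minimal sketch of the bookkeeping, with illustrative (not Gemma4's actual) projection shapes:

```python
# Illustrative LoRA parameter arithmetic. The shapes are placeholders,
# not the real Gemma4 query/value projection shapes.

def lora_params(d_in, d_out, rank):
    # LoRA adds A (d_in x rank) and B (rank x d_out) per adapted layer.
    return rank * (d_in + d_out)

d_in, d_out, rank = 2304, 2048, 8
frozen = d_in * d_out          # original weight matrix, kept frozen
trainable = lora_params(d_in, d_out, rank)
print(f"frozen weights: {frozen:,}, trainable LoRA weights: {trainable:,}")
```

Because only the low-rank factors train, the trainable count grows linearly in rank rather than with the full d_in * d_out product.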