Gemma3Backbone model

[source]

Gemma3Backbone class

keras_hub.models.Gemma3Backbone(
    vocabulary_size,
    image_size,
    num_layers,
    num_query_heads,
    num_key_value_heads,
    hidden_dim,
    intermediate_dim,
    head_dim,
    query_head_dim_normalize=True,
    use_query_key_norm=True,
    use_post_ffw_norm=False,
    use_post_attention_norm=False,
    attention_logit_soft_cap=None,
    final_logit_soft_cap=None,
    use_sliding_window_attention=False,
    sliding_window_size=1024,
    local_rope_scaling_factor=1.0,
    global_rope_scaling_factor=1.0,
    vision_encoder=None,
    layer_norm_epsilon=1e-06,
    dropout=0,
    dtype=None,
    **kwargs
)

Gemma3 core network with hyperparameters.

This backbone implements the Gemma3 model architecture. Gemma3 is a vision-language model (image-text in, text out). The text input is encoded using an embedding layer; images are encoded using a vision transformer (ViT). After encoding these two modalities, the image embeddings are placed in the correct position in the text embedding sequence. The mixed sequence of embeddings is then passed through transformer decoder layers.

For a higher-level object for text generation, see keras_hub.models.Gemma3CausalLM.
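
For example, a minimal sketch of generating text with that task model (assuming the gemma3_instruct_1b preset listed further below is available):

import keras_hub

# Load the instruction-tuned, text-only 1B preset and generate a reply.
causal_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
causal_lm.generate("What is Keras?", max_length=64)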

The default constructor gives a fully customizable, randomly initialized Gemma3 model with any vision encoder, number of heads, embedding dimensions, and equivalent configuration for the decoder layers. To load preset architectures and weights, use the from_preset constructor.

Arguments

  • vocabulary_size: int. The size of the token vocabulary.
  • image_size: int. The resolution of the image in both width and height. The input images must be square.
  • num_layers: int. The number of transformer mixed decoder layers.
  • num_query_heads: int. The number of heads for the query projections in the mixed decoder attention layer.
  • num_key_value_heads: int. The number of heads for the key and value projections in the mixed decoder attention layers.
  • hidden_dim: int. The size of the transformer hidden state at the end of each mixed transformer layer.
  • intermediate_dim: int. The output dimension of the first Dense layer in a two-layer feedforward network for each transformer decoder block.
  • head_dim: int. The size of each attention head in the mixed decoder.
  • query_head_dim_normalize: boolean. If True, normalize the query before attention with head_dim. If False, normalize the query with hidden_dim / num_query_heads. Defaults to True.
  • use_query_key_norm: bool. If True, apply a RMS Norm layer to query and key before projecting them. Defaults to True.
  • use_post_ffw_norm: boolean. Whether to normalize after the feedforward block. Defaults to False.
  • use_post_attention_norm: boolean. Whether to normalize after the attention block. Defaults to False.
  • attention_logit_soft_cap: None or int. Soft cap for the attention logits. Defaults to None.
  • final_logit_soft_cap: None or int. Soft cap for the final logits. Defaults to None.
  • use_sliding_window_attention: boolean. Whether to use sliding local window attention. Defaults to False.
  • sliding_window_size: int. Size of the sliding local window. Defaults to 1024.
  • local_rope_scaling_factor: float. The scaling factor for the rotary embedding of the local (sliding window) attention layers. Defaults to 1.0.
  • global_rope_scaling_factor: float. The scaling factor for the rotary embedding of the global attention layers. Defaults to 1.0.
  • vision_encoder: A Gemma3VisionEncoder instance. Its call() takes in images and returns the corresponding sequence of embeddings. If None, the model is a text-only model.
  • layer_norm_epsilon: float. The epsilon value used for every layer norm in all transformer blocks. Defaults to 1e-6.
  • dropout: float. Dropout probability for the Transformer decoder blocks. Defaults to 0.
  • dtype: string or keras.mixed_precision.DTypePolicy. The dtype to use for the model's computations and weights. Note that some computations, such as softmax and layer normalization, will always be done in float32 precision regardless of dtype. Defaults to bfloat16.
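
As a rough sanity check of how many soft tokens each image contributes to the sequence, the count can be derived from the vision encoder's patch and pooling sizes. The sketch below assumes the encoder configuration used in the vision + language example that follows; it is an illustration, not an API guarantee.

image_size = 896
patch_size = 14
pool_size = 4

patches_per_side = image_size // patch_size      # 64
pooled_per_side = patches_per_side // pool_size  # 16
tokens_per_image = pooled_per_side ** 2          # 256

The resulting 256 matches the run of ones in vision_mask and the length of vision_indices in the vision + language example below.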

Example

import numpy as np

import keras_hub

# === Language Gemma3 model ===
input_data = {}
input_data["token_ids"] = np.ones(shape=(1, 300), dtype="int32")
input_data["padding_mask"] = (
    np.expand_dims(np.array([1] * 280 + [0] * (300 - 280)), axis=0)
    .astype(bool)
)

# Pretrained Gemma3 decoder.
model = keras_hub.models.Gemma3Backbone.from_preset(
    "gemma3_instruct_4b_text"
)
model(input_data)

# Randomly initialized Gemma3 decoder with a custom config.
model = keras_hub.models.Gemma3Backbone(
    vocabulary_size=262144,
    image_size=896,
    num_layers=34,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=2560,
    intermediate_dim=10240,
    head_dim=256,
    query_head_dim_normalize=True,
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    final_logit_soft_cap=None,
    attention_logit_soft_cap=None,
    sliding_window_size=1024,
    use_sliding_window_attention=True,
    vision_encoder=None,
    layer_norm_epsilon=1e-06,
    dtype="bfloat16",
)
model(input_data)

# === Vision + Language Gemma3 model ===
input_data = {}
input_data["images"] = np.ones(shape=(1, 1, 896, 896, 3))
input_data["token_ids"] = np.ones(shape=(1, 300), dtype="int32")
# Image soft tokens occupy positions 30 to 285, after the first 30 text tokens.
input_data["vision_mask"] = np.expand_dims(
    np.array([0] * 30 + [1] * 256 + [0] * 14),
    axis=0,
).astype(bool)
input_data["vision_indices"] = (
    np.expand_dims(np.arange(30, 286), axis=0)
)
input_data["padding_mask"] = (
    np.expand_dims(np.array([1] * 286 + [0] * (300 - 286)), axis=0)
    .astype(bool)
)

# Pretrained Gemma3 decoder.
model = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_4b")
model(input_data)

# Randomly initialized Gemma3 decoder with a custom config.
vision_encoder = keras_hub.models.Gemma3VisionEncoder(
    image_size=896,
    patch_size=14,
    num_heads=16,
    hidden_dim=1152,
    num_layers=27,
    intermediate_dim=4304,
    output_dim=2560,
    pool_size=4,
    layer_norm_epsilon=1e-6,
    dtype="float32",
)

model = keras_hub.models.Gemma3Backbone(
    vocabulary_size=262144,
    image_size=896,
    num_layers=34,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=2560,
    intermediate_dim=10240,
    head_dim=256,
    query_head_dim_normalize=True,
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    final_logit_soft_cap=None,
    attention_logit_soft_cap=None,
    sliding_window_size=1024,
    use_sliding_window_attention=True,
    vision_encoder=vision_encoder,
    layer_norm_epsilon=1e-06,
    dtype="bfloat16"
)
model(input_data)
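
The backbone returns the final sequence of hidden states. With the configuration above (hidden_dim=2560 and a sequence length of 300), the output should have shape (1, 300, 2560); a quick check:

output = model(input_data)
print(output.shape)  # Expected: (1, 300, 2560)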

[source]

from_preset method

Gemma3Backbone.from_preset(preset, load_weights=True, **kwargs)

Instantiate a keras_hub.models.Backbone from a model preset.

A preset is a directory of configs, weights, and other file assets used to save and load a pre-trained model. The preset can be passed as one of:

  1. a built-in preset identifier like 'bert_base_en'
  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
  3. a Hugging Face handle like 'hf://user/bert_base_en'
  4. a path to a local preset directory like './bert_base_en'

This constructor can be called in one of two ways. Either from the base class like keras_hub.models.Backbone.from_preset(), or from a model class like keras_hub.models.GemmaBackbone.from_preset(). If calling from the base class, the subclass of the returned object will be inferred from the config in the preset directory.

For any Backbone subclass, you can run cls.presets.keys() to list all built-in presets available on the class.
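
For example:

# List every built-in preset registered on Gemma3Backbone.
keras_hub.models.Gemma3Backbone.presets.keys()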

Arguments

  • preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
  • load_weights: bool. If True, the weights will be loaded into the model architecture. If False, the weights will be randomly initialized.

Examples

# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
    "gemma_2b_en",
)

# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
    "bert_base_en",
    load_weights=False,
)

Preset                      Parameters   Description
gemma3_1b                   999.89M      1 billion parameter, 26-layer, text-only pretrained Gemma3 model.
gemma3_instruct_1b          999.89M      1 billion parameter, 26-layer, text-only instruction-tuned Gemma3 model.
gemma3_4b_text              3.88B        4 billion parameter, 34-layer, text-only pretrained Gemma3 model.
gemma3_instruct_4b_text     3.88B        4 billion parameter, 34-layer, text-only instruction-tuned Gemma3 model.
gemma3_4b                   4.30B        4 billion parameter, 34-layer, vision+text pretrained Gemma3 model.
gemma3_instruct_4b          4.30B        4 billion parameter, 34-layer, vision+text instruction-tuned Gemma3 model.
gemma3_12b_text             11.77B       12 billion parameter, 48-layer, text-only pretrained Gemma3 model.
gemma3_instruct_12b_text    11.77B       12 billion parameter, 48-layer, text-only instruction-tuned Gemma3 model.
gemma3_12b                  12.19B       12 billion parameter, 48-layer, vision+text pretrained Gemma3 model.
gemma3_instruct_12b         12.19B       12 billion parameter, 48-layer, vision+text instruction-tuned Gemma3 model.
gemma3_27b_text             27.01B       27 billion parameter, 62-layer, text-only pretrained Gemma3 model.
gemma3_instruct_27b_text    27.01B       27 billion parameter, 62-layer, text-only instruction-tuned Gemma3 model.
gemma3_27b                  27.43B       27 billion parameter, 62-layer, vision+text pretrained Gemma3 model.
gemma3_instruct_27b         27.43B       27 billion parameter, 62-layer, vision+text instruction-tuned Gemma3 model.

token_embedding property

keras_hub.models.Gemma3Backbone.token_embedding

A keras.layers.Embedding instance for embedding token ids.

This layer embeds integer token ids to the hidden dim of the model.
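
A minimal sketch of calling the layer directly (assuming a preset backbone has been loaded as in the examples above):

import numpy as np
import keras_hub

backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_1b")
# Map a small batch of token ids to embeddings of size hidden_dim.
embedded = backbone.token_embedding(np.array([[5, 10, 15]]))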


[source]

enable_lora method

Gemma3Backbone.enable_lora(rank, target_names=None)

Enable LoRA on the backbone.

Calling this method will freeze all weights on the backbone, while enabling LoRA on the query & value EinsumDense layers of the attention layers.
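
A minimal sketch (assuming a preset backbone has been loaded):

backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_1b")
# Freeze the base weights and add rank-4 LoRA adapters to the attention
# query and value projections.
backbone.enable_lora(rank=4)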