Gemma3Backbone class
keras_hub.models.Gemma3Backbone(
vocabulary_size,
image_size,
num_layers,
num_query_heads,
num_key_value_heads,
hidden_dim,
intermediate_dim,
head_dim,
query_head_dim_normalize=True,
use_query_key_norm=True,
use_post_ffw_norm=False,
use_post_attention_norm=False,
attention_logit_soft_cap=None,
final_logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=1024,
local_rope_scaling_factor=1.0,
global_rope_scaling_factor=1.0,
vision_encoder=None,
layer_norm_epsilon=1e-06,
dropout=0,
dtype=None,
**kwargs
)
Gemma3 core network with hyperparameters.
This backbone implements the Gemma3 model architecture. Gemma3 is a vision-language model (image-text in, text out). The text input is encoded using an embedding layer; images are encoded using a vision transformer (ViT). After encoding these two modalities, the image embeddings are placed at the correct positions in the text embedding sequence. The mixed sequence of embeddings is then passed through the transformer decoder layers.
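The interleaving step can be pictured with a small NumPy sketch (illustrative only; the shapes and variable names are assumptions for this sketch, not the backbone's internals):

import numpy as np

hidden_dim = 4
text_embeddings = np.zeros((300, hidden_dim))   # output of the token embedding layer
image_embeddings = np.ones((256, hidden_dim))   # output of the vision encoder
vision_indices = np.arange(30, 286)             # positions of the image placeholder tokens

mixed = text_embeddings.copy()
mixed[vision_indices] = image_embeddings        # scatter image embeddings into the text sequence
# `mixed` is what the transformer decoder layers consume.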
For a higher-level object for text generation, see keras_hub.models.Gemma3CausalLM.
The default constructor gives a fully customizable, randomly initialized
Gemma3 model with any vision encoder, number of heads, embedding dimensions,
and equivalent configuration for the decoder layers. To load preset
architectures and weights, use the from_preset
constructor.
Arguments

vocabulary_size: int. The size of the token vocabulary.
image_size: int. The resolution of the image in both width and height. The input images must be square.
num_layers: int. The number of transformer mixed decoder layers.
num_query_heads: int. The number of heads for the query projections in the mixed decoder attention layers.
num_key_value_heads: int. The number of heads for the key and value projections in the mixed decoder attention layers.
hidden_dim: int. The size of the transformer hidden state at the end of each mixed transformer layer.
intermediate_dim: int. The output dimension of the first Dense layer in the two-layer feedforward network of each transformer decoder block.
head_dim: int. The size of each attention head in the mixed decoder.
query_head_dim_normalize: bool. If True, normalize the query before attention with head_dim. If False, normalize the query with hidden_dim / num_query_heads. Defaults to True.
use_query_key_norm: bool. If True, apply an RMS Norm layer to the query and key before projecting them. Defaults to True.
use_post_ffw_norm: bool. Whether to normalize after the feedforward block. Defaults to False.
use_post_attention_norm: bool. Whether to normalize after the attention block. Defaults to False.
attention_logit_soft_cap: None or int. Soft cap for the attention logits. Defaults to None.
final_logit_soft_cap: None or int. Soft cap for the final logits. Defaults to None.
use_sliding_window_attention: bool. Whether to use sliding local window attention. Defaults to False.
sliding_window_size: int. Size of the sliding local window. Defaults to 1024.
local_rope_scaling_factor: float. The scaling factor for the RoPE layer of the local attention layers. Defaults to 1.0.
global_rope_scaling_factor: float. The scaling factor for the RoPE layer of the global attention layers. Defaults to 1.0.
vision_encoder: A Gemma3VisionEncoder instance. Its call() takes in images and returns the corresponding sequence of embeddings. If None, the model is text-only. Defaults to None.
layer_norm_epsilon: float. The epsilon value used for every layer norm in all transformer blocks. Defaults to 1e-6.
dropout: float. Dropout probability for the transformer decoder blocks. Defaults to 0.
dtype: string or keras.mixed_precision.DTypePolicy. The dtype to use for the model's computations and weights. Note that some computations, such as softmax and layer normalization, will always be done in float32 precision regardless of dtype. Defaults to bfloat16.
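Logit soft capping, as used by attention_logit_soft_cap and final_logit_soft_cap, smoothly bounds logits to the range (-cap, cap). A minimal sketch of the standard tanh-based formulation used across the Gemma family (an assumption here, not stated on this page):

import numpy as np

def soft_cap(logits, cap):
    # Squash logits into (-cap, cap) while staying roughly linear near zero.
    return cap * np.tanh(logits / cap)

print(soft_cap(np.array([-100.0, 0.0, 100.0]), cap=50.0))
# Output values are bounded by +/-50.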
Example
import numpy as np

import keras_hub

# === Language Gemma3 model ===
input_data = {}
input_data["token_ids"] = np.ones(shape=(1, 300), dtype="int32")
input_data["padding_mask"] = (
np.expand_dims(np.array([1] * 280 + [0] * (300 - 280)), axis=0)
.astype(bool)
)
# Pretrained Gemma3 decoder.
model = keras_hub.models.Gemma3Backbone.from_preset(
"gemma3_instruct_4b_text"
)
model(input_data)
# Randomly initialized Gemma3 decoder with a custom config.
model = keras_hub.models.Gemma3Backbone(
vocabulary_size=262144,
image_size=896,
num_layers=34,
num_query_heads=8,
num_key_value_heads=4,
hidden_dim=2560,
intermediate_dim=10240,
head_dim=256,
query_head_dim_normalize=True,
use_post_ffw_norm=True,
use_post_attention_norm=True,
final_logit_soft_cap=None,
attention_logit_soft_cap=None,
sliding_window_size=1024,
use_sliding_window_attention=True,
vision_encoder=None,
layer_norm_epsilon=1e-06,
dtype="bfloat16",
)
model(input_data)
# === Vision + Language Gemma3 model ===
input_data = {}
input_data["images"] = np.ones(shape=(1, 1, 896, 896, 3))
input_data["token_ids"] = np.ones(shape=(1, 300), dtype="int32")
# 256 image placeholder tokens follow the first 30 text tokens.
input_data["vision_mask"] = np.expand_dims(
np.array([0] * 30 + [1] * 256 + [0] * 14),
axis=0,
).astype(bool)
input_data["vision_indices"] = (
np.expand_dims(np.arange(30, 286), axis=0)
)
input_data["padding_mask"] = (
np.expand_dims(np.array([1] * 286 + [0] * (300 - 286)), axis=0)
.astype(bool)
)
# Pretrained Gemma3 decoder.
model = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_4b")
model(input_data)
# Randomly initialized Gemma3 decoder with a custom config.
vision_encoder = keras_hub.models.Gemma3VisionEncoder(
image_size=896,
patch_size=14,
num_heads=16,
hidden_dim=1152,
num_layers=27,
intermediate_dim=4304,
output_dim=2560,
pool_size=4,
layer_norm_epsilon=1e-6,
dtype="float32",
)
model = keras_hub.models.Gemma3Backbone(
vocabulary_size=262144,
image_size=896,
num_layers=34,
num_query_heads=8,
num_key_value_heads=4,
hidden_dim=2560,
intermediate_dim=10240,
head_dim=256,
query_head_dim_normalize=True,
use_post_ffw_norm=True,
use_post_attention_norm=True,
final_logit_soft_cap=None,
attention_logit_soft_cap=None,
sliding_window_size=1024,
use_sliding_window_attention=True,
vision_encoder=vision_encoder,
layer_norm_epsilon=1e-06,
dtype="bfloat16"
)
model(input_data)
from_preset method
Gemma3Backbone.from_preset(preset, load_weights=True, **kwargs)
Instantiate a keras_hub.models.Backbone from a model preset.
A preset is a directory of configs, weights and other file assets used
to save and load a pre-trained model. The preset can be passed as
one of:
1. a built-in preset identifier like 'bert_base_en'
2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
3. a Hugging Face handle like 'hf://user/bert_base_en'
4. a path to a local preset directory like './bert_base_en'
This constructor can be called in one of two ways. Either from the base
class like keras_hub.models.Backbone.from_preset()
, or from
a model class like keras_hub.models.GemmaBackbone.from_preset()
.
If calling from the base class, the subclass of the returning object
will be inferred from the config in the preset directory.
For any Backbone
subclass, you can run cls.presets.keys()
to list
all built-in presets available on the class.
Arguments

preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
load_weights: bool. If True, the weights will be loaded into the model architecture. If False, the weights will be randomly initialized.

Examples
# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
"gemma_2b_en",
)
# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
"bert_base_en",
load_weights=False,
)
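The built-in presets in the table below can also be listed programmatically via the cls.presets.keys() call mentioned above; a minimal sketch:

import keras_hub

# Prints the preset names registered on Gemma3Backbone,
# e.g. "gemma3_1b", "gemma3_instruct_4b", ...
print(list(keras_hub.models.Gemma3Backbone.presets.keys()))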
Preset | Parameters | Description |
---|---|---|
gemma3_1b | 999.89M | 1 billion parameter, 26-layer, text-only pretrained Gemma3 model. |
gemma3_instruct_1b | 999.89M | 1 billion parameter, 26-layer, text-only instruction-tuned Gemma3 model. |
gemma3_4b_text | 3.88B | 4 billion parameter, 34-layer, text-only pretrained Gemma3 model. |
gemma3_instruct_4b_text | 3.88B | 4 billion parameter, 34-layer, text-only instruction-tuned Gemma3 model. |
gemma3_4b | 4.30B | 4 billion parameter, 34-layer, vision+text pretrained Gemma3 model. |
gemma3_instruct_4b | 4.30B | 4 billion parameter, 34-layer, vision+text instruction-tuned Gemma3 model. |
gemma3_12b_text | 11.77B | 12 billion parameter, 48-layer, text-only pretrained Gemma3 model. |
gemma3_instruct_12b_text | 11.77B | 12 billion parameter, 48-layer, text-only instruction-tuned Gemma3 model. |
gemma3_12b | 12.19B | 12 billion parameter, 48-layer, vision+text pretrained Gemma3 model. |
gemma3_instruct_12b | 12.19B | 12 billion parameter, 48-layer, vision+text instruction-tuned Gemma3 model. |
gemma3_27b_text | 27.01B | 27 billion parameter, 62-layer, text-only pretrained Gemma3 model. |
gemma3_instruct_27b_text | 27.01B | 27 billion parameter, 62-layer, text-only instruction-tuned Gemma3 model. |
gemma3_27b | 27.43B | 27 billion parameter, 62-layer, vision+text pretrained Gemma3 model. |
gemma3_instruct_27b | 27.43B | 27 billion parameter, 62-layer, vision+text instruction-tuned Gemma3 model. |
token_embedding property
keras_hub.models.Gemma3Backbone.token_embedding
A keras.layers.Embedding
instance for embedding token ids.
This layer embeds integer token ids to the hidden dim of the model.
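A minimal usage sketch, reusing model and input_data from the examples above:

# Maps integer token ids of shape (batch, seq_len) to embeddings of
# shape (batch, seq_len, hidden_dim), e.g. (1, 300, 2560) here.
embeddings = model.token_embedding(input_data["token_ids"])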
enable_lora method
Gemma3Backbone.enable_lora(rank, target_names=None)
Enable LoRA on the backbone.
Calling this method will freeze all weights on the backbone, while
enabling LoRA on the query & value EinsumDense layers of the attention
layers.
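A minimal usage sketch, with the preset name taken from the examples above:

import keras_hub

backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_4b_text")
# Freezes all base weights and adds rank-4 LoRA weights to the
# query & value projections of the attention layers.
backbone.enable_lora(rank=4)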