PaliGemmaBackbone model


PaliGemmaBackbone class


PaliGemma core network with hyperparameters.

This backbone implements the mixed-modality PaliGemma architecture. It contains a Vision Transformer (ViT) network and a text token embedding layer, followed by a backend-agnostic concatenation operation that builds a single sequence of mixed-type (visual and textual) embeddings. The concatenated sequence is then passed through a series of mixed-modality decoder blocks. Calling this model returns the final hidden states of the decoder, not token probabilities; a task model such as keras_hub.models.PaliGemmaCausalLM maps these hidden states to vocabulary logits.
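The concatenation step can be sketched at the level of array shapes with NumPy. This is an illustration only, not the keras_hub implementation: the random arrays stand in for the ViT output and the token embedding output, and the toy sizes (patch size 14, hidden size 8, 12 text tokens) are assumptions. A square image cut into square patches contributes (image_size / vit_patch_size)² embeddings, which are placed before the text embeddings along the sequence axis.

```python
import numpy as np

# Assumed toy sizes, for illustration only.
batch_size = 1
image_size = 224
vit_patch_size = 14
num_text_tokens = 12
hidden_dim = 8

# A square image split into square patches yields (image_size // patch_size) ** 2 patches.
num_image_tokens = (image_size // vit_patch_size) ** 2

# Stand-ins for the vision encoder output and the text token embeddings,
# both assumed to already have the decoder's hidden_dim.
image_embeddings = np.random.normal(size=(batch_size, num_image_tokens, hidden_dim))
text_embeddings = np.random.normal(size=(batch_size, num_text_tokens, hidden_dim))

# Image embeddings first, then text, concatenated along the sequence axis.
mixed_sequence = np.concatenate([image_embeddings, text_embeddings], axis=1)

print(num_image_tokens)      # 256
print(mixed_sequence.shape)  # (1, 268, 8)
```

The decoder blocks therefore see one sequence whose length is the number of image patches plus the number of text tokens.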

For a higher-level object for text-generation, see keras_hub.models.PaliGemmaCausalLM.

The default constructor gives a fully customizable, randomly initialized PaliGemma model with any number of ViT layers, heads, and embedding dimensions, and an equivalent configuration for the PaliGemma decoder layers. To load preset architectures and weights, use the from_preset constructor.


  • vocabulary_size: int. The size of the token vocabulary.
  • image_size: int. The resolution of the image in both width and height. Note: input images must be square.
  • num_layers: int. The number of transformer mixed decoder layers.
  • num_query_heads: int. The number of heads for the query projections in the mixed decoder attention layer.
  • num_key_value_heads: int. The number of heads for the key and value projections in the mixed decoder attention layers.
  • hidden_dim: int. The size of the transformer hidden state at the end of each mixed transformer layer.
  • intermediate_dim: int. The output dimension of the first Dense layer in a two-layer feedforward network for each transformer decoder block.
  • head_dim: int. The size of each attention head in the mixed decoder.
  • vit_patch_size: int. The size of each square patch in the input image.
  • vit_num_heads: int. The number of attention heads for the vision (image) transformer encoder.
  • vit_hidden_dim: int. The size of the transformer hidden state at the end of each vision transformer layer.
  • vit_num_layers: int. The number of vision transformer layers.
  • vit_intermediate_dim: int. The output dimension of the first Dense layer in the two-layer feedforward network of each vision transformer layer. Defaults to 4304.
  • vit_pooling: None or string. The encoded vision embeddings are pooled using the specified pooling setting. The accepted values are "map", "gap", "0" or None. Defaults to None.
  • vit_classifier_activation: activation function. The activation that is used for final output classification in the vision transformer. Defaults to None.
  • vit_name: string. The name used for vision transformer layers.
  • query_head_dim_normalize: boolean. If True, the query is normalized by head_dim before attention. If False, it is normalized by hidden_dim / num_query_heads. Defaults to True.
  • use_post_ffw_norm: boolean. Whether to normalize after the feedforward block. Defaults to False.
  • use_post_attention_norm: boolean. Whether to normalize after the attention block. Defaults to False.
  • attention_logit_soft_cap: None or int. Soft cap for the attention logits. Defaults to None.
  • final_logit_soft_cap: None or int. Soft cap for the final logits. Defaults to None.
  • use_sliding_window_attention: boolean. Whether to use sliding local window attention. Defaults to False.
  • sliding_window_size: int. Size of the sliding local window. Defaults to 4096.
  • layer_norm_epsilon: float. The epsilon value used for every layer norm in all transformer blocks. Defaults to 1e-6.
  • dropout: float. Dropout probability for the Transformer decoder blocks. Defaults to 0.
  • dtype: string or keras.mixed_precision.DTypePolicy. The dtype to use for the model's computations and weights. Note that some computations, such as softmax and layer normalization, will always be done in float32 precision regardless of dtype.
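Three of the decoder options above can be illustrated numerically. The NumPy sketch below is not the keras_hub code and rests on stated assumptions: the query is assumed to be scaled by the inverse square root of the chosen divisor (head_dim, or hidden_dim / num_query_heads); soft capping is assumed to use the Gemma 2 style formula cap * tanh(logits / cap); and the sliding window is shown as a plain causal mask limited to the most recent sliding_window_size positions, ignoring PaliGemma's prefix handling. The function names are hypothetical.

```python
import math

import numpy as np


def query_scale(head_dim, hidden_dim, num_query_heads, query_head_dim_normalize=True):
    """Hypothetical helper: factor applied to queries before the attention dot product.

    Assumes inverse-square-root scaling; the divisor depends on the flag.
    """
    divisor = head_dim if query_head_dim_normalize else hidden_dim / num_query_heads
    return 1.0 / math.sqrt(divisor)


def soft_cap(logits, cap):
    """Hypothetical helper: smoothly squash logits into (-cap, cap); None disables capping."""
    if cap is None:
        return logits
    return cap * np.tanh(logits / cap)


def sliding_window_causal_mask(seq_len, window):
    """Hypothetical helper: position i may attend to j when i - window < j <= i."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)


# Usage with illustrative values.
print(query_scale(head_dim=256, hidden_dim=2048, num_query_heads=8))  # 0.0625
print(soft_cap(np.array([0.0, 100.0]), cap=50.0))  # second value stays below 50
print(sliding_window_causal_mask(4, window=2).astype(int))
```

Soft capping keeps very large logits bounded while leaving small logits almost unchanged, because tanh(x) is close to x near zero.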


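import keras_hub
import numpy as np

input_data = {
    "token_ids": np.ones(shape=(1, 12), dtype="int32"),
    "images": np.random.uniform(size=(1, 224, 224, 3)),
    "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    "response_mask": np.zeros(shape=(1, 12), dtype="int32"),
}

# Pretrained PaliGemma decoder.
model = keras_hub.models.PaliGemmaBackbone.from_preset("pali_gemma_3b_mix_224")

# Randomly initialized PaliGemma decoder with a small, illustrative config.
model = keras_hub.models.PaliGemmaBackbone(
    vocabulary_size=256, image_size=224, num_layers=2, num_query_heads=4,
    num_key_value_heads=1, hidden_dim=64, intermediate_dim=128, head_dim=16,
    vit_patch_size=14, vit_num_heads=4, vit_hidden_dim=64, vit_num_layers=2,
)
model(input_data)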


from_preset method

PaliGemmaBackbone.from_preset(preset, load_weights=True, **kwargs)

Instantiate a keras_hub.models.Backbone from a model preset.

A preset is a directory of configs, weights, and other file assets used to save and load a pre-trained model. The preset can be passed as one of:

  1. a built-in preset identifier like 'bert_base_en'
  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
  3. a Hugging Face handle like 'hf://user/bert_base_en'
  4. a path to a local preset directory like './bert_base_en'
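The four accepted forms can be told apart by their prefix. The helper below, describe_preset, is hypothetical and purely illustrative: it is a simplified reading of the list above, not keras_hub's actual preset-resolution logic.

```python
import os


def describe_preset(preset: str) -> str:
    """Hypothetical helper: name which of the four preset forms `preset` uses."""
    if preset.startswith("kaggle://"):
        return "Kaggle Models handle"
    if preset.startswith("hf://"):
        return "Hugging Face handle"
    # Treat existing directories and explicit relative/absolute paths as local.
    if os.path.isdir(preset) or preset.startswith((".", os.sep)):
        return "local preset directory"
    return "built-in preset identifier"


for p in [
    "bert_base_en",
    "kaggle://user/bert/keras/bert_base_en",
    "hf://user/bert_base_en",
    "./bert_base_en",
]:
    print(f"{p!r} -> {describe_preset(p)}")
```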

This constructor can be called in one of two ways: either from the base class, like keras_hub.models.Backbone.from_preset(), or from a model class, like keras_hub.models.GemmaBackbone.from_preset(). If calling from the base class, the subclass of the returned object will be inferred from the config in the preset directory.

For any Backbone subclass, you can run cls.presets.keys() to list all built-in presets available on the class.


  • preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
  • load_weights: bool. If True, the weights will be loaded into the model architecture. If False, the weights will be randomly initialized.


import keras_hub

# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset("gemma_2b_en")

# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset("bert_base_en", load_weights=False)

Preset | Parameters | Description
pali_gemma_3b_mix_224 | 2.92B | Image size 224, mix fine-tuned, text sequence length 256.
pali_gemma_3b_224 | 2.92B | Image size 224, pre-trained, text sequence length 128.
pali_gemma_3b_mix_448 | 2.92B | Image size 448, mix fine-tuned, text sequence length 512.
pali_gemma_3b_448 | 2.92B | Image size 448, pre-trained, text sequence length 512.
pali_gemma_3b_896 | 2.93B | Image size 896, pre-trained, text sequence length 512.
pali_gemma2_pt_3b_224 | 3.03B | 3 billion parameters, image size 224, 27-layer SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_3b_ft_docci_448 | 3.03B | 3 billion parameters, image size 448, 27-layer SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been fine-tuned on the DOCCI dataset for improved descriptions with fine-grained details.
pali_gemma2_pt_3b_448 | 3.03B | 3 billion parameters, image size 448, 27-layer SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_3b_896 | 3.04B | 3 billion parameters, image size 896, 27-layer SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_10b_224 | 9.66B | 10 billion parameters, image size 224, 27-layer SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_28b_224 | 9.66B | 28 billion parameters, image size 224, 27-layer SigLIP-So400m vision encoder and 46-layer Gemma2 27B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_10b_ft_docci_448 | 9.66B | 10 billion parameters, image size 448, 27-layer SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been fine-tuned on the DOCCI dataset for improved descriptions with fine-grained details.
pali_gemma2_pt_10b_448 | 9.66B | 10 billion parameters, image size 448, 27-layer SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_28b_448 | 9.66B | 28 billion parameters, image size 448, 27-layer SigLIP-So400m vision encoder and 46-layer Gemma2 27B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_10b_896 | 9.67B | 10 billion parameters, image size 896, 27-layer SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_28b_896 | 9.67B | 28 billion parameters, image size 896, 27-layer SigLIP-So400m vision encoder and 46-layer Gemma2 27B language model. This model has been pre-trained on a mixture of datasets.
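Because image embeddings are concatenated with text embeddings, the decoder's total sequence length for a preset is its number of image tokens plus its text sequence length. The arithmetic below covers the PaliGemma (version 1) rows of the table. It assumes a ViT patch size of 14, the patch size of the SigLIP-So400m/14 encoder; this is plain arithmetic, not a keras_hub API.

```python
# Assumed ViT patch size (SigLIP-So400m uses 14x14 patches).
VIT_PATCH_SIZE = 14

# (image size, text sequence length), copied from the PaliGemma rows of the table.
PRESETS = {
    "pali_gemma_3b_mix_224": (224, 256),
    "pali_gemma_3b_224": (224, 128),
    "pali_gemma_3b_mix_448": (448, 512),
    "pali_gemma_3b_448": (448, 512),
    "pali_gemma_3b_896": (896, 512),
}


def image_tokens(image_size, patch_size=VIT_PATCH_SIZE):
    """Number of patch embeddings produced for one square image."""
    return (image_size // patch_size) ** 2


for name, (image_size, text_len) in PRESETS.items():
    n_img = image_tokens(image_size)
    print(f"{name}: {n_img} image tokens + {text_len} text tokens = {n_img + text_len}")
```

Doubling the image resolution quadruples the number of image tokens, which is why the 448 and 896 presets are considerably more expensive to run than the 224 presets.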

token_embedding property


A keras.layers.Embedding instance for embedding token ids.

This layer maps integer token ids to embedding vectors of size hidden_dim.
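The lookup performed by an embedding layer is a row gather from a (vocabulary_size, hidden_dim) weight matrix. The NumPy sketch below shows that gather with toy sizes. Its second step, reusing the same matrix transposed to turn hidden states into one logit per vocabulary entry, is an assumption about how a task model such as PaliGemmaCausalLM uses tied embeddings; this page does not state it.

```python
import numpy as np

# Toy sizes; real presets use a much larger vocabulary and hidden_dim.
vocabulary_size, hidden_dim = 10, 4
rng = np.random.default_rng(0)
embedding_matrix = rng.normal(size=(vocabulary_size, hidden_dim))

# Forward lookup: each integer token id selects one row of the matrix.
token_ids = np.array([[1, 5, 5, 9]])
embedded = embedding_matrix[token_ids]  # shape (1, 4, hidden_dim)

# Assumed tied-weight output projection: the same matrix, transposed,
# maps hidden states back to one logit per vocabulary entry.
logits = embedded @ embedding_matrix.T  # shape (1, 4, vocabulary_size)
print(embedded.shape, logits.shape)  # (1, 4, 4) (1, 4, 10)
```

Repeated token ids (the two 5s above) receive identical embeddings, since the lookup depends only on the id.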