Gemma3CausalLMPreprocessor class

keras_hub.models.Gemma3CausalLMPreprocessor(
    tokenizer,
    image_converter=None,
    sequence_length=1024,
    add_start_token=True,
    add_end_token=True,
    max_images_per_prompt=2,
    num_vision_tokens_per_image=256,
    **kwargs
)
Gemma3 Causal LM preprocessor.

This preprocessing layer is meant for use with `keras_hub.models.Gemma3CausalLM`. It can be configured in two ways: text-only and text + vision, based on whether the passed value of `image_converter` is `None`. For the former, it takes in batches of strings, whereas for the latter, it takes in batches of images and strings. It returns outputs in a `(x, y, sample_weight)` format, where the `y` label is the next token id in the `x` sequence. `sample_weight` is 0 for "prompt" tokens and 1 for "response" tokens, so that the loss is computed only on the "response" tokens.
For the text + vision case, this layer replaces each instance of the `<start_of_image>` token in the prompt with `num_vision_tokens_per_image` placeholder tokens. It also returns the indices of where these vision tokens are present, so that image embeddings can be placed at the right positions in the sequence of text embeddings inside the model. Note that if `max_images_per_prompt` is 2, you can pass either 0, 1 or 2 images per sample; the value 0 corresponds to text-only input.
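For instance, with the `gemma3_instruct_4b` vision preset used in the examples below, a prompt containing one `<start_of_image>` token has that token expanded into `num_vision_tokens_per_image` (256 by default) placeholder positions (a minimal sketch):

preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
    "gemma3_instruct_4b"
)
x, y, sample_weight = preprocessor(
    {
        "prompts": "describe this flower: <start_of_image>",
        "responses": "a lily",
        "images": np.ones((896, 896, 3), dtype="float32"),
    }
)
# The single `<start_of_image>` token is replaced with placeholder tokens,
# and `x` also carries the indices of those positions so the model can
# splice image embeddings into the text embedding sequence.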
For use with generation, the layer also exposes two methods: `generate_preprocess()` and `generate_postprocess()`. When this preprocessor is attached to a `keras_hub.models.Gemma3CausalLM` instance, these methods will be called implicitly in `generate()`. They can also be called standalone (e.g. to precompute preprocessing inputs for generation in a separate process).
Arguments

- tokenizer: A `keras_hub.models.Gemma3Tokenizer` instance.
- image_converter: A `keras_hub.layers.ImageConverter` instance. Defaults to `None`.
- sequence_length: The length of the packed inputs. Defaults to 1024.
- add_start_token: If `True`, the preprocessor will prepend the tokenizer start token to each input sequence. Defaults to `True`.
- add_end_token: If `True`, the preprocessor will append the tokenizer end token to each input sequence. Defaults to `True`.
- max_images_per_prompt: The maximum number of images allowed per sample. Defaults to 2.
- num_vision_tokens_per_image: The number of placeholder tokens each `<start_of_image>` token is expanded into. Defaults to 256.
Call arguments

- x: A string, `tf.Tensor` or list of python strings.
- y: Label data. Should always be `None` as the layer generates labels.
- sample_weight: Label weights. Should always be `None` as the layer generates label weights.
- sequence_length: Pass to override the configured `sequence_length` of the layer (see the override example below).

Examples
# === Language Gemma3 model ===
# Load the preprocessor from a preset.
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
"gemma3_instruct_1b"
)
# Unbatched inputs.
preprocessor(
{
"prompts": "What is the capital of India?",
"responses": "New Delhi",
}
)
# Batched inputs.
preprocessor(
{
"prompts": [
"What is the capital of India?",
"What is the capital of Spain?"
],
"responses": ["New Delhi", "Madrid"],
}
)
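# The configured `sequence_length` can be overridden per call; a sketch
# using the call argument described above, with an illustrative value of 128.
preprocessor(
    {
        "prompts": "What is the capital of India?",
        "responses": "New Delhi",
    },
    sequence_length=128,
)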
# Apply preprocessing to a `tf.data.Dataset`.
features = {
"prompts": [
"What is the capital of India?",
"What is the capital of Spain?"
],
"responses": ["New Delhi", "Madrid"],
}
ds = tf.data.Dataset.from_tensor_slices(features)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
# Prepare tokens for generation (no end token).
preprocessor.generate_preprocess(["The quick brown fox jumped."])
# Map generation outputs back to strings.
preprocessor.generate_postprocess({
'token_ids': np.array([[2, 818, 3823, 8864, 37423, 32694, 236761, 0]]),
'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]),
})
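# When attached to a `keras_hub.models.Gemma3CausalLM`, the two methods
# above are called implicitly by `generate()`; a sketch:
causal_lm = keras_hub.models.Gemma3CausalLM.from_preset(
    "gemma3_instruct_1b", preprocessor=preprocessor
)
causal_lm.generate("The quick brown fox jumped.", max_length=30)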
# === Vision and Language Gemma3 model ===
# Load the preprocessor from a preset.
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
"gemma3_instruct_4b"
)
# text-only inputs (unbatched)
preprocessor(
{
"prompts": "What is the capital of India?",
"responses": "New Delhi",
}
)
# text-only inputs (batched)
preprocessor(
{
"prompts": [
"What is the capital of India?",
"What is the capital of Spain?"
],
"responses": ["New Delhi", "Madrid"],
}
)
# Unbatched inputs, with one image.
preprocessor(
{
"prompts": "this is a lily <start_of_image>",
"responses": "pristine!",
"images": np.ones((896, 896, 3), dtype="float32")
}
)
# Unbatched inputs, with two images.
preprocessor(
{
"prompts": "lily: <start_of_image>, sunflower: <start_of_image>",
"responses": "pristine!",
"images": [
np.ones((896, 896, 3), dtype="float32"),
np.ones((896, 896, 3), dtype="float32")
],
}
)
# Batched inputs, one image per prompt.
preprocessor(
{
"prompts": [
"this is a lily: <start_of_image>",
"this is a sunflower: <start_of_image>"
],
"responses": ["pristine!", "radiant!"],
"images": [
np.ones((896, 896, 3), dtype="float32"),
np.ones((896, 896, 3), dtype="float32")
]
}
)
# Can also be written this way.
preprocessor(
{
"prompts": [
"this is a lily: <start_of_image>",
"this is a sunflower: <start_of_image>"
],
"responses": ["pristine!", "radiant!"],
"images": [
[np.ones((896, 896, 3), dtype="float32")],
[np.ones((896, 896, 3), dtype="float32")]
]
}
)
# Different number of images in every sample.
preprocessor(
{
"prompts": [
"Who is this singer: <start_of_image>?",
"Who are these musicians <start_of_image>, <start_of_image>?"
],
"responses": ["Arijit Singh", "John Lennon, Paul Mccartney"],
"images": [
[
np.ones((896, 896, 3), dtype="float32"),
np.ones((896, 896, 3), dtype="float32")
],
[np.ones((896, 896, 3), dtype="float32")]
]
}
)
# Apply preprocessing to a `tf.data.Dataset`.
inputs = {
"prompts": [
"Who are these two: <start_of_image>, <start_of_image>",
"Who is this: <start_of_image>?",
"What is the capital of India?"
],
"responses": [
"John Lennon, Paul Mccartney",
"Arijit Singh",
"New Delhi"
],
"images": (
tf.ragged.constant(
[
[np.ones((10, 10, 3)), np.ones((10, 10, 3))],
[np.ones((10, 10, 3))],
[],
]
)
)
}
ds = tf.data.Dataset.from_tensor_slices(inputs)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
from_preset method

Gemma3CausalLMPreprocessor.from_preset(
    preset, config_file="preprocessor.json", **kwargs
)
Instantiate a `keras_hub.models.Preprocessor` from a model preset.

A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The `preset` can be passed as one of:

- a built-in preset identifier like `'bert_base_en'`
- a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
- a Hugging Face handle like `'hf://user/bert_base_en'`
- a path to a local preset directory like `'./bert_base_en'`
For any `Preprocessor` subclass, you can run `cls.presets.keys()` to list all built-in presets available on the class.
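For example, listing the built-in presets for this preprocessor class:

keras_hub.models.Gemma3CausalLMPreprocessor.presets.keys()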
As there are usually multiple preprocessing classes for a given model, this method should be called on a specific subclass like `keras_hub.models.BertTextClassifierPreprocessor.from_preset()`.
Arguments

- preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
Examples
# Load a preprocessor for Gemma generation.
preprocessor = keras_hub.models.CausalLMPreprocessor.from_preset(
"gemma_2b_en",
)
# Load a preprocessor for Bert classification.
preprocessor = keras_hub.models.TextClassifierPreprocessor.from_preset(
"bert_base_en",
)
| Preset | Parameters | Description |
|---|---|---|
| gemma3_1b | 999.89M | 1 billion parameter, 26-layer, text-only pretrained Gemma3 model. |
| gemma3_instruct_1b | 999.89M | 1 billion parameter, 26-layer, text-only instruction-tuned Gemma3 model. |
| gemma3_4b_text | 3.88B | 4 billion parameter, 34-layer, text-only pretrained Gemma3 model. |
| gemma3_instruct_4b_text | 3.88B | 4 billion parameter, 34-layer, text-only instruction-tuned Gemma3 model. |
| gemma3_4b | 4.30B | 4 billion parameter, 34-layer, vision+text pretrained Gemma3 model. |
| gemma3_instruct_4b | 4.30B | 4 billion parameter, 34-layer, vision+text instruction-tuned Gemma3 model. |
| gemma3_12b_text | 11.77B | 12 billion parameter, 48-layer, text-only pretrained Gemma3 model. |
| gemma3_instruct_12b_text | 11.77B | 12 billion parameter, 48-layer, text-only instruction-tuned Gemma3 model. |
| gemma3_12b | 12.19B | 12 billion parameter, 48-layer, vision+text pretrained Gemma3 model. |
| gemma3_instruct_12b | 12.19B | 12 billion parameter, 48-layer, vision+text instruction-tuned Gemma3 model. |
| gemma3_27b_text | 27.01B | 27 billion parameter, 62-layer, text-only pretrained Gemma3 model. |
| gemma3_instruct_27b_text | 27.01B | 27 billion parameter, 62-layer, text-only instruction-tuned Gemma3 model. |
| gemma3_27b | 27.43B | 27 billion parameter, 62-layer, vision+text pretrained Gemma3 model. |
| gemma3_instruct_27b | 27.43B | 27 billion parameter, 62-layer, vision+text instruction-tuned Gemma3 model. |
tokenizer property

keras_hub.models.Gemma3CausalLMPreprocessor.tokenizer

The tokenizer used to tokenize strings.
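For instance, the tokenizer can be called directly on strings (a sketch, reusing a preprocessor loaded from one of the presets above):

preprocessor.tokenizer("The quick brown fox jumped.")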