CausalLM

CausalLM class

keras_hub.models.CausalLM()

Base class for generative language modeling tasks.

CausalLM tasks wrap a keras_hub.models.Backbone and a keras_hub.models.Preprocessor to create a model that can be used for generation and generative fine-tuning.

CausalLM tasks provide an additional, high-level generate() function which can be used to auto-regressively sample a model token by token with a string in, string out signature. The compile() method of all CausalLM classes contains an additional sampler argument, which can be used to pass a keras_hub.samplers.Sampler to control how the predicted distribution will be sampled.
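
To illustrate what a sampler controls, here is a minimal pure-Python sketch of two common strategies, greedy and top-k, applied to a single next-token probability distribution. It is not keras_hub's implementation: the function names greedy_sample and top_k_sample are hypothetical, and in a real model the distribution would come from the logits produced at each auto-regressive step.

```python
import random


def greedy_sample(probs):
    """Return the index of the most probable token."""
    return max(range(len(probs)), key=lambda i: probs[i])


def top_k_sample(probs, k, rng):
    """Sample a token index from the k most probable tokens.

    random.choices accepts unnormalized weights, so the kept
    probabilities do not need to be renormalized explicitly.
    """
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    weights = [probs[i] for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


# Toy next-token distribution over a 5-token vocabulary.
probs = [0.05, 0.50, 0.30, 0.10, 0.05]
rng = random.Random(0)

print(greedy_sample(probs))  # always 1
print([top_k_sample(probs, k=2, rng=rng) for _ in range(5)])  # only 1s and 2s
```

Greedy decoding is deterministic, while top-k trades determinism for diversity. In keras_hub, passing sampler="greedy" or sampler="top_k" to compile() selects built-in versions of these strategies, and a keras_hub.samplers.Sampler instance can be passed instead for finer control.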

When calling fit(), the model learns to predict the tokenized input token by token, with a causal mask applied so that each prediction can only depend on earlier tokens. The same setup serves for both pre-training and supervised fine-tuning, which can be used to control inference-time generation.
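
As a rough illustration of this training setup, the sketch below pairs each input position with the following token as its label, and builds a lower-triangular mask so that position i attends only to positions 0 through i. The helper names make_causal_lm_pairs and causal_mask are hypothetical; this is plain Python, not keras_hub's preprocessing or attention code.

```python
def make_causal_lm_pairs(token_ids):
    """Shift a token sequence into (inputs, labels) for next-token prediction."""
    return token_ids[:-1], token_ids[1:]


def causal_mask(length):
    """Lower-triangular attention mask: row i may attend to columns 0..i."""
    return [[1 if j <= i else 0 for j in range(length)] for i in range(length)]


tokens = [101, 7, 42, 9, 102]  # toy token IDs
x, y = make_causal_lm_pairs(tokens)
print(x)  # [101, 7, 42, 9]
print(y)  # [7, 42, 9, 102]
for row in causal_mask(len(x)):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```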

All CausalLM tasks include a from_preset() constructor which can be used to load a pre-trained config and weights.

Example

import keras_hub

# Load a GPT-2 causal LM with pre-trained weights.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "gpt2_base_en",
)
causal_lm.compile(sampler="top_k")
causal_lm.generate("Keras is a", max_length=64)

# Load a Mistral instruction tuned checkpoint at bfloat16 precision.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "mistral_instruct_7b_en",
    dtype="bfloat16",
)
causal_lm.compile(sampler="greedy")
causal_lm.generate("Keras is a", max_length=64)

from_preset method

CausalLM.from_preset(preset, load_weights=True, **kwargs)

Instantiate a keras_hub.models.Task from a model preset.

A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:

  1. a built-in preset identifier like 'bert_base_en'
  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
  3. a Hugging Face handle like 'hf://user/bert_base_en'
  4. a path to a local preset directory like './bert_base_en'
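
The four accepted forms can be told apart by their prefix. The hypothetical helper preset_source below is a minimal sketch of that distinction, assuming only the prefixes listed above; it is not how keras_hub resolves presets internally.

```python
import os


def preset_source(preset):
    """Classify a preset string by its form (illustrative, not keras_hub's logic)."""
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "huggingface"
    # Treat explicit relative/absolute paths and existing directories as local.
    if preset.startswith(("./", "../", "/")) or os.path.isdir(preset):
        return "local"
    return "builtin"


for p in [
    "bert_base_en",
    "kaggle://user/bert/keras/bert_base_en",
    "hf://user/bert_base_en",
    "./bert_base_en",
]:
    print(p, "->", preset_source(p))
```

The precedence between a built-in name and a same-named local directory is a choice made in this sketch, not documented library behavior.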

For any Task subclass, you can run cls.presets.keys() to list all built-in presets available on the class.

This constructor can be called in one of two ways: either from a task-specific base class like keras_hub.models.CausalLM.from_preset(), or from a model class like keras_hub.models.BertTextClassifier.from_preset(). If calling from a base class, the subclass of the returned object will be inferred from the config in the preset directory.
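
The inference step can be pictured as a class registry keyed by a name stored in the preset's config. The toy classes below (ToyTask, ToyCausalLM, ToyGPT2CausalLM, ToyTextClassifier), the from_config_dict method, and the "class_name" config key are all hypothetical stand-ins chosen for illustration; this is a sketch of the dispatch pattern, not keras_hub's implementation.

```python
class ToyTask:
    """Toy base class: every subclass registers itself by name."""

    _registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        ToyTask._registry[cls.__name__] = cls

    @classmethod
    def from_config_dict(cls, config):
        # Look up the concrete class named in the (hypothetical) config.
        subclass = ToyTask._registry[config["class_name"]]
        # A call on a base class must resolve to one of its own subclasses.
        if not issubclass(subclass, cls):
            raise ValueError(
                f"{subclass.__name__} is not a subclass of {cls.__name__}"
            )
        return subclass()


class ToyCausalLM(ToyTask):
    pass


class ToyGPT2CausalLM(ToyCausalLM):
    pass


class ToyTextClassifier(ToyTask):
    pass


config = {"class_name": "ToyGPT2CausalLM"}
model = ToyCausalLM.from_config_dict(config)
print(type(model).__name__)  # ToyGPT2CausalLM
```

In this sketch, calling ToyTextClassifier.from_config_dict(config) raises ValueError, since the configured class is not a text classifier.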

Arguments

  • preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
  • load_weights: bool. If True, saved weights will be loaded into the model architecture. If False, all weights will be randomly initialized.

Examples

import keras_hub

# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "gemma_2b_en",
)

# Load a BERT classification task.
model = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)

Preset Parameters Description
bart_base_en 139.42M 6-layer BART model where case is maintained. Trained on BookCorpus, English Wikipedia and CommonCrawl.
bart_large_en 406.29M 12-layer BART model where case is maintained. Trained on BookCorpus, English Wikipedia and CommonCrawl.
bart_large_en_cnn 406.29M The bart_large_en backbone model fine-tuned on the CNN+DM summarization dataset.
bloom_560m_multi 559.21M 24-layer Bloom model with hidden dimension of 1024. Trained on 45 natural languages and 12 programming languages.
bloomz_560m_multi 559.21M 24-layer Bloom model with hidden dimension of 1024. Fine-tuned on the cross-lingual task mixture (xP3) dataset.
bloom_1.1b_multi 1.07B 24-layer Bloom model with hidden dimension of 1536. Trained on 45 natural languages and 12 programming languages.
bloomz_1.1b_multi 1.07B 24-layer Bloom model with hidden dimension of 1536. Fine-tuned on the cross-lingual task mixture (xP3) dataset.
bloom_1.7b_multi 1.72B 24-layer Bloom model with hidden dimension of 2048. Trained on 45 natural languages and 12 programming languages.
bloomz_1.7b_multi 1.72B 24-layer Bloom model with hidden dimension of 2048. Fine-tuned on the cross-lingual task mixture (xP3) dataset.
bloom_3b_multi 3.00B 30-layer Bloom model with hidden dimension of 2560. Trained on 45 natural languages and 12 programming languages.
bloomz_3b_multi 3.00B 30-layer Bloom model with hidden dimension of 2560. Fine-tuned on the cross-lingual task mixture (xP3) dataset.
falcon_refinedweb_1b_en 1.31B 24-layer Falcon model (Falcon with 1B parameters), trained on 350B tokens of the RefinedWeb dataset.
gemma_2b_en 2.51B 2 billion parameter, 18-layer, base Gemma model.
gemma_instruct_2b_en 2.51B 2 billion parameter, 18-layer, instruction tuned Gemma model.
gemma_1.1_instruct_2b_en 2.51B 2 billion parameter, 18-layer, instruction tuned Gemma model. The 1.1 update improves model quality.
code_gemma_1.1_2b_en 2.51B 2 billion parameter, 18-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion. The 1.1 update improves model quality.
code_gemma_2b_en 2.51B 2 billion parameter, 18-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion.
gemma2_2b_en 2.61B 2 billion parameter, 26-layer, base Gemma model.
gemma2_instruct_2b_en 2.61B 2 billion parameter, 26-layer, instruction tuned Gemma model.
shieldgemma_2b_en 2.61B 2 billion parameter, 26-layer, ShieldGemma model.
gemma_7b_en 8.54B 7 billion parameter, 28-layer, base Gemma model.
gemma_instruct_7b_en 8.54B 7 billion parameter, 28-layer, instruction tuned Gemma model.
gemma_1.1_instruct_7b_en 8.54B 7 billion parameter, 28-layer, instruction tuned Gemma model. The 1.1 update improves model quality.
code_gemma_7b_en 8.54B 7 billion parameter, 28-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion.
code_gemma_instruct_7b_en 8.54B 7 billion parameter, 28-layer, instruction tuned CodeGemma model. This model has been trained for chat use cases related to code.
code_gemma_1.1_instruct_7b_en 8.54B 7 billion parameter, 28-layer, instruction tuned CodeGemma model. This model has been trained for chat use cases related to code. The 1.1 update improves model quality.
gemma2_9b_en 9.24B 9 billion parameter, 42-layer, base Gemma model.
gemma2_instruct_9b_en 9.24B 9 billion parameter, 42-layer, instruction tuned Gemma model.
shieldgemma_9b_en 9.24B 9 billion parameter, 42-layer, ShieldGemma model.
gemma2_27b_en 27.23B 27 billion parameter, 46-layer, base Gemma model.
gemma2_instruct_27b_en 27.23B 27 billion parameter, 46-layer, instruction tuned Gemma model.
shieldgemma_27b_en 27.23B 27 billion parameter, 46-layer, ShieldGemma model.
gpt2_base_en 124.44M 12-layer GPT-2 model where case is maintained. Trained on WebText.
gpt2_base_en_cnn_dailymail 124.44M 12-layer GPT-2 model where case is maintained. Finetuned on the CNN/DailyMail summarization dataset.
gpt2_medium_en 354.82M 24-layer GPT-2 model where case is maintained. Trained on WebText.
gpt2_large_en 774.03M 36-layer GPT-2 model where case is maintained. Trained on WebText.
gpt2_extra_large_en 1.56B 48-layer GPT-2 model where case is maintained. Trained on WebText.
llama2_7b_en 6.74B 7 billion parameter, 32-layer, base LLaMA 2 model.
llama2_instruct_7b_en 6.74B 7 billion parameter, 32-layer, instruction tuned LLaMA 2 model.
vicuna_1.5_7b_en 6.74B 7 billion parameter, 32-layer, instruction tuned Vicuna v1.5 model.
llama2_7b_en_int8 6.74B 7 billion parameter, 32-layer, base LLaMA 2 model with activation and weights quantized to int8.
llama2_instruct_7b_en_int8 6.74B 7 billion parameter, 32-layer, instruction tuned LLaMA 2 model with activation and weights quantized to int8.
llama3_8b_en 8.03B 8 billion parameter, 32-layer, base LLaMA 3 model.
llama3_instruct_8b_en 8.03B 8 billion parameter, 32-layer, instruction tuned LLaMA 3 model.
llama3_8b_en_int8 8.03B 8 billion parameter, 32-layer, base LLaMA 3 model with activation and weights quantized to int8.
llama3_instruct_8b_en_int8 8.03B 8 billion parameter, 32-layer, instruction tuned LLaMA 3 model with activation and weights quantized to int8.
mistral_7b_en 7.24B Mistral 7B base model.
mistral_instruct_7b_en 7.24B Mistral 7B instruct model.
mistral_0.2_instruct_7b_en 7.24B Mistral 7B instruct version 0.2 model.
opt_125m_en 125.24M 12-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora.
opt_1.3b_en 1.32B 24-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora.
opt_2.7b_en 2.70B 32-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora.
opt_6.7b_en 6.70B 32-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora.
pali_gemma_3b_mix_224 2.92B Image size 224, mix fine-tuned, text sequence length is 256.
pali_gemma_3b_224 2.92B Image size 224, pre-trained, text sequence length is 128.
pali_gemma_3b_mix_448 2.92B Image size 448, mix fine-tuned, text sequence length is 512.
pali_gemma_3b_448 2.92B Image size 448, pre-trained, text sequence length is 512.
pali_gemma_3b_896 2.93B Image size 896, pre-trained, text sequence length is 512.
pali_gemma2_pt_3b_224 3.03B 3 billion parameter, image size 224, 27-layer for SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_3b_ft_docci_448 3.03B 3 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been fine-tuned on the DOCCI dataset for improved descriptions with fine-grained details.
pali_gemma2_pt_3b_448 3.03B 3 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_3b_896 3.04B 3 billion parameter, image size 896, 27-layer for SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_10b_224 9.66B 10 billion parameter, image size 224, 27-layer for SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_28b_224 9.66B 28 billion parameter, image size 224, 27-layer for SigLIP-So400m vision encoder and 46-layer Gemma2 27B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_10b_ft_docci_448 9.66B 10 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been fine-tuned on the DOCCI dataset for improved descriptions with fine-grained details.
pali_gemma2_pt_10b_448 9.66B 10 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_28b_448 9.66B 28 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 46-layer Gemma2 27B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_10b_896 9.67B 10 billion parameter, image size 896, 27-layer for SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been pre-trained on a mixture of datasets.
pali_gemma2_pt_28b_896 9.67B 28 billion parameter, image size 896, 27-layer for SigLIP-So400m vision encoder and 46-layer Gemma2 27B language model. This model has been pre-trained on a mixture of datasets.
phi3_mini_4k_instruct_en 3.82B 3.8 billion parameters, 32 layers, 4k context length, Phi-3 model. The model was trained using the Phi-3 datasets. These datasets include both synthetic data and filtered publicly available website data, with an emphasis on high-quality and reasoning-dense properties.
phi3_mini_128k_instruct_en 3.82B 3.8 billion parameters, 32 layers, 128k context length, Phi-3 model. The model was trained using the Phi-3 datasets. These datasets include both synthetic data and filtered publicly available website data, with an emphasis on high-quality and reasoning-dense properties.

compile method

CausalLM.compile(
    optimizer="auto", loss="auto", weighted_metrics="auto", sampler="top_k", **kwargs
)

Configures the CausalLM task for training and generation.

The CausalLM task extends the default compilation signature of keras.Model.compile with defaults for optimizer, loss, and weighted_metrics. To override these defaults, pass any value to these arguments during compilation.

The CausalLM task adds a new sampler argument to compile, which can be used to control the sampling strategy used by the generate() function.

Note that because training inputs include padded tokens which are excluded from the loss, it is almost always a good idea to compile with weighted_metrics and not metrics.
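
A small numeric illustration of why this matters: if padded positions count toward accuracy, correctly "predicting" padding inflates the score. The sketch below assumes, for illustration, that padding positions receive a sample weight of 0 and real tokens a weight of 1; it is plain Python, not keras.metrics, and the accuracy helper is hypothetical.

```python
def accuracy(y_true, y_pred, sample_weight=None):
    """Token accuracy, optionally weighted per token (like weighted_metrics)."""
    if sample_weight is None:
        sample_weight = [1] * len(y_true)
    correct = sum(w for t, p, w in zip(y_true, y_pred, sample_weight) if t == p)
    return correct / sum(sample_weight)


# Labels padded with token ID 0; the padding mask marks real tokens with 1.
y_true = [5, 8, 2, 0, 0, 0]
y_pred = [5, 3, 2, 0, 0, 0]  # the padding positions "match" for free
padding_mask = [1, 1, 1, 0, 0, 0]

print(accuracy(y_true, y_pred))                # 5/6: inflated by padding
print(accuracy(y_true, y_pred, padding_mask))  # 2/3: real tokens only
```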

Arguments

  • optimizer: "auto", an optimizer name, or a keras.Optimizer instance. Defaults to "auto", which uses the default optimizer for the given model and task. See keras.Model.compile and keras.optimizers for more info on possible optimizer values.
  • loss: "auto", a loss name, or a keras.losses.Loss instance. Defaults to "auto", where a keras.losses.SparseCategoricalCrossentropy loss will be applied for the token classification CausalLM task. See keras.Model.compile and keras.losses for more info on possible loss values.
  • weighted_metrics: "auto", or a list of metrics to be evaluated by the model during training and testing. Defaults to "auto", where a keras.metrics.SparseCategoricalAccuracy will be applied to track the accuracy of the model at guessing masked token values. See keras.Model.compile and keras.metrics for more info on possible weighted_metrics values.
  • sampler: A sampler name, or a keras_hub.samplers.Sampler instance. Configures the sampling method used during generate() calls. See keras_hub.samplers for a full list of built-in sampling strategies.
  • **kwargs: See keras.Model.compile for a full list of arguments supported by the compile method.
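
For intuition about the default loss, the sketch below computes a token-level sparse categorical crossentropy from raw logits, weighting each token so that padding contributes nothing. It is a pure-Python illustration, not keras.losses.SparseCategoricalCrossentropy; in particular, it averages over the sum of the weights, which is one convention among several and may differ from the reduction Keras applies.

```python
import math


def sparse_categorical_crossentropy(labels, logits, sample_weight):
    """Weighted mean of -log(softmax(logits)[label]) over tokens."""
    total, weight_sum = 0.0, 0.0
    for label, row, w in zip(labels, logits, sample_weight):
        # Numerically stable log-sum-exp: subtract the max logit first.
        m = max(row)
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        # -log softmax(row)[label] == log_norm - row[label]
        total += w * (log_norm - row[label])
        weight_sum += w
    return total / weight_sum


labels = [1, 2, 0]
logits = [[0.0, 2.0, 0.0], [1.0, 1.0, 1.0], [5.0, 0.0, 0.0]]
mask = [1.0, 1.0, 0.0]  # the last position is padding and contributes nothing
print(sparse_categorical_crossentropy(labels, logits, mask))
```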

generate method

CausalLM.generate(inputs, max_length=None, stop_token_ids="auto", strip_prompt=False)

Generate text given prompt inputs.

This method generates text based on given inputs. The sampling method used for generation can be set via the compile() method.

If inputs are a tf.data.Dataset, outputs will be generated "batch-by-batch" and concatenated. Otherwise, all inputs will be handled as a single batch.

If a preprocessor is attached to the model, inputs will be preprocessed inside the generate() function and should match the structure expected by the preprocessor layer (usually raw strings). If a preprocessor is not attached, inputs should match the structure expected by the backbone. See the example usage above for a demonstration of each.

Arguments

  • inputs: Python data, tensor data, or a tf.data.Dataset. If a preprocessor is attached to the model, inputs should match the structure expected by the preprocessor layer. If a preprocessor is not attached, inputs should match the structure expected by the backbone model.
  • max_length: Optional. int. The maximum length of the generated sequence. Defaults to the maximum configured sequence_length of the preprocessor. If the preprocessor is None, inputs should be padded to the desired maximum length and this argument will be ignored.
  • stop_token_ids: Optional. None, "auto", or tuple of token IDs. Defaults to "auto", which uses preprocessor.tokenizer.end_token_id; using "auto" without a preprocessor will produce an error. None stops generation only after max_length tokens have been generated. You may also specify a tuple of token IDs the model should stop on. Note that each ID is interpreted as a separate stop token; multi-token stop sequences are not supported.
  • strip_prompt: Optional. By default, generate() returns the full prompt followed by its completion generated by the model. If this option is set to True, only the newly generated text is returned.
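
The interaction of max_length, stop_token_ids, and strip_prompt can be sketched with a toy loop over token IDs. Everything here is hypothetical: toy_generate and count_up stand in for a real model and sampler, max_length is assumed to count prompt tokens, and the stop token is kept in the returned IDs, which may differ from what keras_hub's detokenized output shows. As in the description above, each stop ID is checked individually, so multi-token stop sequences are not handled.

```python
def toy_generate(prompt_ids, next_token, max_length,
                 stop_token_ids=None, strip_prompt=False):
    """Toy decoding loop mimicking max_length, stop_token_ids and strip_prompt.

    next_token is any callable mapping the current sequence to one token ID.
    """
    sequence = list(prompt_ids)
    stop = set(stop_token_ids or ())
    while len(sequence) < max_length:
        token = next_token(sequence)
        sequence.append(token)
        if token in stop:  # single-token stop check only
            break
    return sequence[len(prompt_ids):] if strip_prompt else sequence


def count_up(seq):
    """Hypothetical 'model': emits the previous token ID plus one."""
    return seq[-1] + 1


print(toy_generate([1, 2], count_up, max_length=10, stop_token_ids=(5,)))
# [1, 2, 3, 4, 5]
print(toy_generate([1, 2], count_up, max_length=10, stop_token_ids=None))
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(toy_generate([1, 2], count_up, max_length=10,
                   stop_token_ids=(5,), strip_prompt=True))
# [3, 4, 5]
```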

save_to_preset method

CausalLM.save_to_preset(preset_dir)

Save task to a preset directory.

Arguments

  • preset_dir: The path to the local model preset directory.

preprocessor property

keras_hub.models.CausalLM.preprocessor

A keras_hub.models.Preprocessor layer used to preprocess input.


backbone property

keras_hub.models.CausalLM.backbone

A keras_hub.models.Backbone model with the core architecture.