KerasNLP contains end-to-end implementations of popular model architectures. These models can be created in two ways:

- Through the from_preset() constructor, which instantiates an object with pre-trained configurations, vocabularies, and (optionally) weights.
- Through a custom configuration controlled by the user.

Below, we list all presets available in the library. For more detailed usage, browse the docstring for a particular class. For an in-depth introduction to our API, see the getting started guide.

The following preset names correspond to a config and weights for a pretrained model. The from_preset() constructor of any task, preprocessor, backbone, or tokenizer can be used to create a model from the saved preset.

```python
import keras_nlp

# Each call loads the matching component from the same saved preset.
backbone = keras_nlp.models.Backbone.from_preset("bert_base_en")
tokenizer = keras_nlp.models.Tokenizer.from_preset("bert_base_en")
classifier = keras_nlp.models.TextClassifier.from_preset("bert_base_en", num_classes=2)
preprocessor = keras_nlp.models.TextClassifierPreprocessor.from_preset("bert_base_en")
```

Preset name | Model | Parameters | Description |
---|---|---|---|
albert_base_en_uncased | ALBERT | 11.68M | 12-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
albert_large_en_uncased | ALBERT | 17.68M | 24-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
albert_extra_large_en_uncased | ALBERT | 58.72M | 24-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
albert_extra_extra_large_en_uncased | ALBERT | 222.60M | 12-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
bart_base_en | BART | 139.42M | 6-layer BART model where case is maintained. Trained on BookCorpus, English Wikipedia and CommonCrawl. Model Card |
bart_large_en | BART | 406.29M | 12-layer BART model where case is maintained. Trained on BookCorpus, English Wikipedia and CommonCrawl. Model Card |
bart_large_en_cnn | BART | 406.29M | The bart_large_en backbone model fine-tuned on the CNN+DM summarization dataset. Model Card |
bert_tiny_en_uncased | BERT | 4.39M | 2-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
bert_small_en_uncased | BERT | 28.76M | 4-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
bert_medium_en_uncased | BERT | 41.37M | 8-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
bert_base_en_uncased | BERT | 109.48M | 12-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
bert_base_en | BERT | 108.31M | 12-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus. Model Card |
bert_base_zh | BERT | 102.27M | 12-layer BERT model. Trained on Chinese Wikipedia. Model Card |
bert_base_multi | BERT | 177.85M | 12-layer BERT model where case is maintained. Trained on Wikipedias of 104 languages. Model Card |
bert_large_en_uncased | BERT | 335.14M | 24-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
bert_large_en | BERT | 333.58M | 24-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus. Model Card |
bert_tiny_en_uncased_sst2 | BERT | 4.39M | The bert_tiny_en_uncased backbone model fine-tuned on the SST-2 sentiment analysis dataset. Model Card |
bloom_560m_multi | BLOOM | 559.21M | 24-layer Bloom model with hidden dimension of 1024. Trained on 45 natural languages and 12 programming languages. Model Card |
bloom_1.1b_multi | BLOOM | 1.07B | 24-layer Bloom model with hidden dimension of 1536. Trained on 45 natural languages and 12 programming languages. Model Card |
bloom_1.7b_multi | BLOOM | 1.72B | 24-layer Bloom model with hidden dimension of 2048. Trained on 45 natural languages and 12 programming languages. Model Card |
bloom_3b_multi | BLOOM | 3.00B | 30-layer Bloom model with hidden dimension of 2560. Trained on 45 natural languages and 12 programming languages. Model Card |
bloomz_560m_multi | BLOOMZ | 559.21M | 24-layer Bloom model with hidden dimension of 1024. Fine-tuned on the crosslingual task mixture (xP3) dataset. Model Card |
bloomz_1.1b_multi | BLOOMZ | 1.07B | 24-layer Bloom model with hidden dimension of 1536. Fine-tuned on the crosslingual task mixture (xP3) dataset. Model Card |
bloomz_1.7b_multi | BLOOMZ | 1.72B | 24-layer Bloom model with hidden dimension of 2048. Fine-tuned on the crosslingual task mixture (xP3) dataset. Model Card |
bloomz_3b_multi | BLOOMZ | 3.00B | 30-layer Bloom model with hidden dimension of 2560. Fine-tuned on the crosslingual task mixture (xP3) dataset. Model Card |
deberta_v3_extra_small_en | DeBERTaV3 | 70.68M | 12-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. Model Card |
deberta_v3_small_en | DeBERTaV3 | 141.30M | 6-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. Model Card |
deberta_v3_base_en | DeBERTaV3 | 183.83M | 12-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. Model Card |
deberta_v3_large_en | DeBERTaV3 | 434.01M | 24-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. Model Card |
deberta_v3_base_multi | DeBERTaV3 | 278.22M | 12-layer DeBERTaV3 model where case is maintained. Trained on the 2.5TB multilingual CC100 dataset. Model Card |
distil_bert_base_en_uncased | DistilBERT | 66.36M | 6-layer DistilBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus using BERT as the teacher model. Model Card |
distil_bert_base_en | DistilBERT | 65.19M | 6-layer DistilBERT model where case is maintained. Trained on English Wikipedia + BooksCorpus using BERT as the teacher model. Model Card |
distil_bert_base_multi | DistilBERT | 134.73M | 6-layer DistilBERT model where case is maintained. Trained on Wikipedias of 104 languages. Model Card |
electra_small_discriminator_uncased_en | ELECTRA | 13.55M | 12-layer small ELECTRA discriminator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
electra_small_generator_uncased_en | ELECTRA | 13.55M | 12-layer small ELECTRA generator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
electra_base_discriminator_uncased_en | ELECTRA | 109.48M | 12-layer base ELECTRA discriminator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
electra_base_generator_uncased_en | ELECTRA | 33.58M | 12-layer base ELECTRA generator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
electra_large_discriminator_uncased_en | ELECTRA | 335.14M | 24-layer large ELECTRA discriminator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
electra_large_generator_uncased_en | ELECTRA | 51.07M | 24-layer large ELECTRA generator model. All inputs are lowercased. Trained on English Wikipedia + BooksCorpus. Model Card |
f_net_base_en | FNet | 82.86M | 12-layer FNet model where case is maintained. Trained on the C4 dataset. Model Card |
f_net_large_en | FNet | 236.95M | 24-layer FNet model where case is maintained. Trained on the C4 dataset. Model Card |
falcon_refinedweb_1b_en | Falcon | 1.31B | 24-layer Falcon model (Falcon with 1B parameters), trained on 350B tokens of RefinedWeb dataset. Model Card |
gemma_2b_en | Gemma | 2.51B | 2 billion parameter, 18-layer, base Gemma model. Model Card |
gemma_instruct_2b_en | Gemma | 2.51B | 2 billion parameter, 18-layer, instruction tuned Gemma model. Model Card |
gemma_1.1_instruct_2b_en | Gemma | 2.51B | 2 billion parameter, 18-layer, instruction tuned Gemma model. The 1.1 update improves model quality. Model Card |
code_gemma_1.1_2b_en | Gemma | 2.51B | 2 billion parameter, 18-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion. The 1.1 update improves model quality. Model Card |
code_gemma_2b_en | Gemma | 2.51B | 2 billion parameter, 18-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion. Model Card |
gemma_7b_en | Gemma | 8.54B | 7 billion parameter, 28-layer, base Gemma model. Model Card |
gemma_instruct_7b_en | Gemma | 8.54B | 7 billion parameter, 28-layer, instruction tuned Gemma model. Model Card |
gemma_1.1_instruct_7b_en | Gemma | 8.54B | 7 billion parameter, 28-layer, instruction tuned Gemma model. The 1.1 update improves model quality. Model Card |
code_gemma_7b_en | Gemma | 8.54B | 7 billion parameter, 28-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion. Model Card |
code_gemma_instruct_7b_en | Gemma | 8.54B | 7 billion parameter, 28-layer, instruction tuned CodeGemma model. This model has been trained for chat use cases related to code. Model Card |
code_gemma_1.1_instruct_7b_en | Gemma | 8.54B | 7 billion parameter, 28-layer, instruction tuned CodeGemma model. This model has been trained for chat use cases related to code. The 1.1 update improves model quality. Model Card |
gemma2_2b_en | Gemma | 2.61B | 2 billion parameter, 26-layer, base Gemma model. Model Card |
gemma2_instruct_2b_en | Gemma | 2.61B | 2 billion parameter, 26-layer, instruction tuned Gemma model. Model Card |
gemma2_9b_en | Gemma | 9.24B | 9 billion parameter, 42-layer, base Gemma model. Model Card |
gemma2_instruct_9b_en | Gemma | 9.24B | 9 billion parameter, 42-layer, instruction tuned Gemma model. Model Card |
gemma2_27b_en | Gemma | 27.23B | 27 billion parameter, 42-layer, base Gemma model. Model Card |
gemma2_instruct_27b_en | Gemma | 27.23B | 27 billion parameter, 42-layer, instruction tuned Gemma model. Model Card |
shieldgemma_2b_en | Gemma | 2.61B | 2 billion parameter, 26-layer, ShieldGemma model. Model Card |
shieldgemma_9b_en | Gemma | 9.24B | 9 billion parameter, 42-layer, ShieldGemma model. Model Card |
shieldgemma_27b_en | Gemma | 27.23B | 27 billion parameter, 42-layer, ShieldGemma model. Model Card |
gpt2_base_en | GPT-2 | 124.44M | 12-layer GPT-2 model where case is maintained. Trained on WebText. Model Card |
gpt2_medium_en | GPT-2 | 354.82M | 24-layer GPT-2 model where case is maintained. Trained on WebText. Model Card |
gpt2_large_en | GPT-2 | 774.03M | 36-layer GPT-2 model where case is maintained. Trained on WebText. Model Card |
gpt2_extra_large_en | GPT-2 | 1.56B | 48-layer GPT-2 model where case is maintained. Trained on WebText. Model Card |
gpt2_base_en_cnn_dailymail | GPT-2 | 124.44M | 12-layer GPT-2 model where case is maintained. Finetuned on the CNN/DailyMail summarization dataset. |
llama3_8b_en | LLaMA 3 | 8.03B | 8 billion parameter, 32-layer, base LLaMA 3 model. Model Card |
llama3_8b_en_int8 | LLaMA 3 | 8.03B | 8 billion parameter, 32-layer, base LLaMA 3 model with activation and weights quantized to int8. Model Card |
llama3_instruct_8b_en | LLaMA 3 | 8.03B | 8 billion parameter, 32-layer, instruction tuned LLaMA 3 model. Model Card |
llama3_instruct_8b_en_int8 | LLaMA 3 | 8.03B | 8 billion parameter, 32-layer, instruction tuned LLaMA 3 model with activation and weights quantized to int8. Model Card |
llama2_7b_en | LLaMA 2 | 6.74B | 7 billion parameter, 32-layer, base LLaMA 2 model. Model Card |
llama2_7b_en_int8 | LLaMA 2 | 6.74B | 7 billion parameter, 32-layer, base LLaMA 2 model with activation and weights quantized to int8. Model Card |
llama2_instruct_7b_en | LLaMA 2 | 6.74B | 7 billion parameter, 32-layer, instruction tuned LLaMA 2 model. Model Card |
llama2_instruct_7b_en_int8 | LLaMA 2 | 6.74B | 7 billion parameter, 32-layer, instruction tuned LLaMA 2 model with activation and weights quantized to int8. Model Card |
vicuna_1.5_7b_en | Vicuna | 6.74B | 7 billion parameter, 32-layer, instruction tuned Vicuna v1.5 model. Model Card |
mistral_7b_en | Mistral | 7.24B | Mistral 7B base model. Model Card |
mistral_instruct_7b_en | Mistral | 7.24B | Mistral 7B instruct model. Model Card |
mistral_0.2_instruct_7b_en | Mistral | 7.24B | Mistral 7B instruct version 0.2 model. Model Card |
opt_125m_en | OPT | 125.24M | 12-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. Model Card |
opt_1.3b_en | OPT | 1.32B | 24-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. Model Card |
opt_2.7b_en | OPT | 2.70B | 32-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. Model Card |
opt_6.7b_en | OPT | 6.70B | 32-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. Model Card |
pali_gemma_3b_mix_224 | PaliGemma | 2.92B | Image size 224, mix fine-tuned, text sequence length 256. Model Card |
pali_gemma_3b_mix_448 | PaliGemma | 2.92B | Image size 448, mix fine-tuned, text sequence length 512. Model Card |
pali_gemma_3b_224 | PaliGemma | 2.92B | Image size 224, pre-trained, text sequence length 128. Model Card |
pali_gemma_3b_448 | PaliGemma | 2.92B | Image size 448, pre-trained, text sequence length 512. Model Card |
pali_gemma_3b_896 | PaliGemma | 2.93B | Image size 896, pre-trained, text sequence length 512. Model Card |
phi3_mini_4k_instruct_en | Phi-3 | 3.82B | 3.8 billion parameters, 32 layers, 4k context length, Phi-3 model. The model was trained using the Phi-3 datasets. This dataset includes both synthetic data and filtered publicly available website data, with an emphasis on high-quality and reasoning-dense properties. Model Card |
phi3_mini_128k_instruct_en | Phi-3 | 3.82B | 3.8 billion parameters, 32 layers, 128k context length, Phi-3 model. The model was trained using the Phi-3 datasets. This dataset includes both synthetic data and filtered publicly available website data, with an emphasis on high-quality and reasoning-dense properties. Model Card |
roberta_base_en | RoBERTa | 124.05M | 12-layer RoBERTa model where case is maintained. Trained on English Wikipedia, BooksCorpus, CommonCrawl, and OpenWebText. Model Card |
roberta_large_en | RoBERTa | 354.31M | 24-layer RoBERTa model where case is maintained. Trained on English Wikipedia, BooksCorpus, CommonCrawl, and OpenWebText. Model Card |
t5_small_multi | T5 | 0 | 8-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card |
t5_base_multi | T5 | 0 | 12-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card |
t5_large_multi | T5 | 0 | 24-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card |
flan_small_multi | T5 | 0 | 8-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card |
flan_base_multi | T5 | 0 | 12-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card |
flan_large_multi | T5 | 0 | 24-layer T5 model. Trained on the Colossal Clean Crawled Corpus (C4). Model Card |
whisper_tiny_en | Whisper | 37.18M | 4-layer Whisper model. Trained on 438,000 hours of labelled English speech data. Model Card |
whisper_base_en | Whisper | 124.44M | 6-layer Whisper model. Trained on 438,000 hours of labelled English speech data. Model Card |
whisper_small_en | Whisper | 241.73M | 12-layer Whisper model. Trained on 438,000 hours of labelled English speech data. Model Card |
whisper_medium_en | Whisper | 763.86M | 24-layer Whisper model. Trained on 438,000 hours of labelled English speech data. Model Card |
whisper_tiny_multi | Whisper | 37.76M | 4-layer Whisper model. Trained on 680,000 hours of labelled multilingual speech data. Model Card |
whisper_base_multi | Whisper | 72.59M | 6-layer Whisper model. Trained on 680,000 hours of labelled multilingual speech data. Model Card |
whisper_small_multi | Whisper | 241.73M | 12-layer Whisper model. Trained on 680,000 hours of labelled multilingual speech data. Model Card |
whisper_medium_multi | Whisper | 763.86M | 24-layer Whisper model. Trained on 680,000 hours of labelled multilingual speech data. Model Card |
whisper_large_multi | Whisper | 1.54B | 32-layer Whisper model. Trained on 680,000 hours of labelled multilingual speech data. Model Card |
whisper_large_multi_v2 | Whisper | 1.54B | 32-layer Whisper model. Trained for 2.5 epochs on 680,000 hours of labelled multilingual speech data. An improved version of whisper_large_multi. Model Card |
xlm_roberta_base_multi | XLM-RoBERTa | 277.45M | 12-layer XLM-RoBERTa model where case is maintained. Trained on CommonCrawl in 100 languages. Model Card |
xlm_roberta_large_multi | XLM-RoBERTa | 558.84M | 24-layer XLM-RoBERTa model where case is maintained. Trained on CommonCrawl in 100 languages. Model Card |
Note: The links provided will lead to the model card or to the official README, if no model card has been provided by the author.
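
The Parameters column can be turned into a rough estimate of how much memory a preset's weights occupy. The sketch below is plain Python and not part of KerasNLP. The helper names `parse_parameter_count` and `estimate_weight_bytes` are hypothetical, and the default of 4 bytes per parameter assumes float32 weights. It covers the weights only, not activations, optimizer state, or framework overhead.

```python
def parse_parameter_count(text: str) -> int:
    """Convert a table entry such as "109.48M" or "1.07B" to an integer."""
    multipliers = {"K": 10**3, "M": 10**6, "B": 10**9}
    text = text.strip()
    suffix = text[-1].upper()
    if suffix in multipliers:
        # round() absorbs the small error from binary floating point.
        return round(float(text[:-1]) * multipliers[suffix])
    # No suffix: the entry is already a plain number.
    return int(float(text))


def estimate_weight_bytes(parameter_count: int, bytes_per_parameter: int = 4) -> int:
    """Estimate weight storage; 4 bytes per parameter assumes float32."""
    return parameter_count * bytes_per_parameter


# Example: the bert_base_en_uncased row lists 109.48M parameters.
params = parse_parameter_count("109.48M")
size_gib = estimate_weight_bytes(params) / 2**30
print(f"{params:,} parameters, about {size_gib:.2f} GiB in float32")
```

Passing `bytes_per_parameter=2` gives the corresponding estimate for half-precision weights, and `1` is a rough lower bound for the int8 presets.
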