TextClassifier

TextClassifier class

keras_hub.models.TextClassifier(*args, compile=True, **kwargs)

Base class for all classification tasks.

TextClassifier tasks wrap a keras_hub.models.Backbone and a keras_hub.models.Preprocessor to create a model that can be used for sequence classification. TextClassifier tasks take an additional num_classes argument, controlling the number of predicted output classes.

To fine-tune with fit(), pass a dataset of (x, y) tuples, where x is a string and y is an integer label in the range [0, num_classes).

All TextClassifier tasks include a from_preset() constructor which can be used to load a pre-trained config and weights.

Some, but not all, classification presets include classification head weights in a task.weights.h5 file. For these presets, you can omit num_classes to restore the saved classification head. For all presets, if num_classes is passed as a keyword argument to from_preset(), the classification head will be randomly initialized.

Example

```python
import keras_hub
import tensorflow_datasets as tfds

# Load a BERT classifier with pre-trained weights.
classifier = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
# Fine-tune on IMDb movie reviews (or any dataset of (text, label) pairs).
imdb_train, imdb_test = tfds.load(
    "imdb_reviews",
    split=["train", "test"],
    as_supervised=True,
    batch_size=16,
)
classifier.fit(imdb_train, validation_data=imdb_test)
# Predict two new examples.
classifier.predict(["What an amazing movie!", "A total waste of my time."])
```

from_preset method

TextClassifier.from_preset(preset, load_weights=True, **kwargs)

Instantiate a keras_hub.models.Task from a model preset.

A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:

  1. a built-in preset identifier like 'bert_base_en'
  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
  3. a Hugging Face handle like 'hf://user/bert_base_en'
  4. a path to a local preset directory like './bert_base_en'

For any Task subclass, you can run cls.presets.keys() to list all built-in presets available on the class.

This constructor can be called in one of two ways: either from a task-specific base class like keras_hub.models.CausalLM.from_preset(), or from a model class like keras_hub.models.BertTextClassifier.from_preset(). If calling from a base class, the subclass of the returned object will be inferred from the config in the preset directory.

Arguments

  • preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
  • load_weights: bool. If True, saved weights will be loaded into the model architecture. If False, all weights will be randomly initialized.

Examples

```python
import keras_hub

# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "gemma_2b_en",
)

# Load a BERT classification task.
model = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
```

| Preset | Parameters | Description |
|---|---|---|
| albert_base_en_uncased | 11.68M | 12-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
| albert_large_en_uncased | 17.68M | 24-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
| albert_extra_large_en_uncased | 58.72M | 24-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
| albert_extra_extra_large_en_uncased | 222.60M | 12-layer ALBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
| bert_tiny_en_uncased | 4.39M | 2-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
| bert_tiny_en_uncased_sst2 | 4.39M | The bert_tiny_en_uncased backbone model fine-tuned on the SST-2 sentiment analysis dataset. |
| bert_small_en_uncased | 28.76M | 4-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
| bert_medium_en_uncased | 41.37M | 8-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
| bert_base_zh | 102.27M | 12-layer BERT model. Trained on Chinese Wikipedia. |
| bert_base_en | 108.31M | 12-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus. |
| bert_base_en_uncased | 109.48M | 12-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
| bert_base_multi | 177.85M | 12-layer BERT model where case is maintained. Trained on Wikipedias of 104 languages. |
| bert_large_en | 333.58M | 24-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus. |
| bert_large_en_uncased | 335.14M | 24-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
| deberta_v3_extra_small_en | 70.68M | 12-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. |
| deberta_v3_small_en | 141.30M | 6-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. |
| deberta_v3_base_en | 183.83M | 12-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. |
| deberta_v3_base_multi | 278.22M | 12-layer DeBERTaV3 model where case is maintained. Trained on the 2.5TB multilingual CC100 dataset. |
| deberta_v3_large_en | 434.01M | 24-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText. |
| distil_bert_base_en | 65.19M | 6-layer DistilBERT model where case is maintained. Trained on English Wikipedia + BooksCorpus using BERT as the teacher model. |
| distil_bert_base_en_uncased | 66.36M | 6-layer DistilBERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus using BERT as the teacher model. |
| distil_bert_base_multi | 134.73M | 6-layer DistilBERT model where case is maintained. Trained on Wikipedias of 104 languages. |
| f_net_base_en | 82.86M | 12-layer FNet model where case is maintained. Trained on the C4 dataset. |
| f_net_large_en | 236.95M | 24-layer FNet model where case is maintained. Trained on the C4 dataset. |
| roberta_base_en | 124.05M | 12-layer RoBERTa model where case is maintained. Trained on English Wikipedia, BooksCorpus, CommonCrawl, and OpenWebText. |
| roberta_large_en | 354.31M | 24-layer RoBERTa model where case is maintained. Trained on English Wikipedia, BooksCorpus, CommonCrawl, and OpenWebText. |
| xlm_roberta_base_multi | 277.45M | 12-layer XLM-RoBERTa model where case is maintained. Trained on CommonCrawl in 100 languages. |
| xlm_roberta_large_multi | 558.84M | 24-layer XLM-RoBERTa model where case is maintained. Trained on CommonCrawl in 100 languages. |

compile method

TextClassifier.compile(optimizer="auto", loss="auto", metrics="auto", **kwargs)

Configures the TextClassifier task for training.

The TextClassifier task extends the default compilation signature of keras.Model.compile with defaults for optimizer, loss, and metrics. To override these defaults, pass any value to these arguments during compilation.

Arguments

  • optimizer: "auto", an optimizer name, or a keras.Optimizer instance. Defaults to "auto", which uses the default optimizer for the given model and task. See keras.Model.compile and keras.optimizers for more info on possible optimizer values.
  • loss: "auto", a loss name, or a keras.losses.Loss instance. Defaults to "auto", where a keras.losses.SparseCategoricalCrossentropy loss will be applied for the classification task. See keras.Model.compile and keras.losses for more info on possible loss values.
  • metrics: "auto", or a list of metrics to be evaluated by the model during training and testing. Defaults to "auto", where a keras.metrics.SparseCategoricalAccuracy will be applied to track the accuracy of the model during training. See keras.Model.compile and keras.metrics for more info on possible metrics values.
  • **kwargs: See keras.Model.compile for a full list of arguments supported by the compile method.

save_to_preset method

TextClassifier.save_to_preset(preset_dir)

Save task to a preset directory.

Arguments

  • preset_dir: The path to the local model preset directory.

preprocessor property

keras_hub.models.TextClassifier.preprocessor

A keras_hub.models.Preprocessor layer used to preprocess input.


backbone property

keras_hub.models.TextClassifier.backbone

A keras_hub.models.Backbone model with the core architecture.