BertClassifier model


BertClassifier class

keras_nlp.models.BertClassifier(
    backbone, num_classes, preprocessor=None, activation=None, dropout=0.1, **kwargs
)

An end-to-end BERT model for classification tasks.

This model attaches a classification head to a keras_nlp.models.BertBackbone instance, mapping from the backbone outputs to logits suitable for a classification task. To use this model with pre-trained weights, create it with the from_preset() constructor.

This model can optionally be configured with a preprocessor layer, in which case it will automatically apply preprocessing to raw inputs during fit(), predict(), and evaluate(). This is done by default when creating the model with from_preset().

Disclaimer: Pre-trained models are provided on an "as is" basis, without warranties or conditions of any kind.

Arguments

  • backbone: A keras_nlp.models.BertBackbone instance.
  • num_classes: int. Number of classes to predict.
  • preprocessor: A keras_nlp.models.BertPreprocessor or None. If None, this model will not apply preprocessing, and inputs should be preprocessed before calling the model.
  • activation: Optional str or callable, defaults to None. The activation function to use on the model outputs. Set activation="softmax" to return output probabilities, or leave as None to return raw logits (see the sketch after this list).
  • dropout: float. The dropout probability value, applied after the dense layer.
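
For instance, the activation argument controls whether predict() returns logits or probabilities. A minimal sketch, assuming the "bert_tiny_en_uncased" preset listed below:

import keras_nlp

classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased",
    num_classes=4,
    activation="softmax",
)
# With `activation="softmax"` each output row sums to ~1.0; with the
# default `activation=None` the same call would return raw logits.
probs = classifier.predict(["The quick brown fox jumped."])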

Examples

Raw string data.

features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]

# Pretrained classifier.
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=4,
)
classifier.fit(x=features, y=labels, batch_size=2)
classifier.predict(x=features, batch_size=2)

# Re-compile (e.g., with a new learning rate).
classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    jit_compile=True,
)
# Access backbone programmatically (e.g., to change `trainable`).
classifier.backbone.trainable = False
# Fit again.
classifier.fit(x=features, y=labels, batch_size=2)
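
Since the default configuration outputs logits, turning predictions into class ids takes an argmax over the last axis; a small sketch, reusing the classifier above:

import numpy as np

logits = classifier.predict(x=features, batch_size=2)
predicted_classes = np.argmax(logits, axis=-1)  # one class id per input string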

Preprocessed integer data.

import tensorflow as tf

features = {
    "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
    "segment_ids": tf.constant(
        [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
    ),
    "padding_mask": tf.constant(
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
    ),
}
labels = [0, 3]

# Pretrained classifier without preprocessing.
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=4,
    preprocessor=None,
)
classifier.fit(x=features, y=labels, batch_size=2)
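
The dictionary above is the format a keras_nlp.models.BertPreprocessor produces. If you start from raw strings, you can build it by running the matching preprocessor standalone; a sketch, assuming the same preset:

preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
    "bert_base_en_uncased",
    sequence_length=12,
)
# Returns a dict with "token_ids", "segment_ids" and "padding_mask".
features = preprocessor(["The quick brown fox jumped.", "I forgot my homework."])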

Custom backbone and vocabulary.

features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]

vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
vocab += ["The", "quick", "brown", "fox", "jumped", "."]
tokenizer = keras_nlp.models.BertTokenizer(
    vocabulary=vocab,
)
preprocessor = keras_nlp.models.BertPreprocessor(
    tokenizer=tokenizer,
    sequence_length=128,
)
backbone = keras_nlp.models.BertBackbone(
    vocabulary_size=len(vocab),  # match the custom vocabulary above
    num_layers=4,
    num_heads=4,
    hidden_dim=256,
    intermediate_dim=512,
    max_sequence_length=128,
)
classifier = keras_nlp.models.BertClassifier(
    backbone=backbone,
    preprocessor=preprocessor,
    num_classes=4,
)
classifier.fit(x=features, y=labels, batch_size=2)
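
To sanity-check a custom vocabulary before training, you can call the tokenizer directly; a quick sketch:

# Ids index into `vocab`; words missing from it map to "[UNK]" (id 0).
print(tokenizer("The quick brown fox jumped."))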


from_preset method

BertClassifier.from_preset(preset, load_weights=True, **kwargs)

Instantiate a BertClassifier model from a preset architecture and weights.

Arguments

  • preset: string. Must be one of "bert_tiny_en_uncased", "bert_small_en_uncased", "bert_medium_en_uncased", "bert_base_en_uncased", "bert_base_en", "bert_base_zh", "bert_base_multi", "bert_large_en_uncased", "bert_large_en", "bert_tiny_en_uncased_sst2".
  • load_weights: bool. Whether to load pre-trained weights into the model. Defaults to True.

Examples

# Load architecture and weights from preset
model = BertClassifier.from_preset("bert_tiny_en_uncased")

# Load randomly initialized model from preset architecture
model = BertClassifier.from_preset(
    "bert_tiny_en_uncased",
    load_weights=False
)

Preset name                Parameters  Description
bert_tiny_en_uncased       4.39M       2-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus.
bert_small_en_uncased      28.76M      4-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus.
bert_medium_en_uncased     41.37M      8-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus.
bert_base_en_uncased       109.48M     12-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus.
bert_base_en               108.31M     12-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus.
bert_base_zh               102.27M     12-layer BERT model. Trained on Chinese Wikipedia.
bert_base_multi            177.85M     12-layer BERT model where case is maintained. Trained on the Wikipedias of 104 languages.
bert_large_en_uncased      335.14M     24-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus.
bert_large_en              333.58M     24-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus.
bert_tiny_en_uncased_sst2  4.39M       The bert_tiny_en_uncased backbone model fine-tuned on the SST-2 sentiment analysis dataset.

backbone property

keras_nlp.models.BertClassifier.backbone

A keras.Model instance providing the backbone submodel.
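
This makes it easy to run the encoder on its own, for example to extract embeddings from preprocessed inputs; a sketch, assuming a token dictionary like the `features` dict in the examples above:

outputs = classifier.backbone(features)
# BertBackbone returns a dict with "sequence_output" (per-token embeddings)
# and "pooled_output" (one vector per input sequence).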


preprocessor property

keras_nlp.models.BertClassifier.preprocessor

A keras.layers.Layer instance used to preprocess inputs.
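
Besides inspection, this property is useful when moving preprocessing into a tf.data pipeline so it runs asynchronously with training. A sketch, assuming string `features` and integer `labels` as in the examples above:

import tensorflow as tf
import keras_nlp

preprocessor = keras_nlp.models.BertPreprocessor.from_preset("bert_base_en_uncased")
ds = tf.data.Dataset.from_tensor_slices((features, labels))
# The preprocessor maps (x, y) pairs to (token_dict, y) pairs.
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE).batch(2)
# A classifier created with `preprocessor=None` can then `fit(ds)` directly.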