BertClassifier

```python
keras_nlp.models.BertClassifier(
    backbone, num_classes, preprocessor=None, activation=None, dropout=0.1, **kwargs
)
```
An end-to-end BERT model for classification tasks.
This model attaches a classification head to a `keras_nlp.models.BertBackbone` instance, mapping from the backbone outputs to logits suitable for a classification task. For usage of this model with pre-trained weights, use the `from_preset()` constructor.
This model can optionally be configured with a `preprocessor` layer, in which case it will automatically apply preprocessing to raw inputs during `fit()`, `predict()`, and `evaluate()`. This is done by default when creating the model with `from_preset()`.
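When a preprocessor is attached, raw strings are converted into the packed integer inputs the backbone expects. A minimal sketch of what that preprocessing produces, using the standalone preprocessor for a preset (the preset name and sequence length here are illustrative):

```python
import keras_nlp

# Standalone preprocessor matching a classifier preset.
preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
    "bert_base_en_uncased",
    sequence_length=128,
)
# Produces a dict with "token_ids", "segment_ids", and "padding_mask",
# the same features shown in the preprocessed example further below.
x = preprocessor(["The quick brown fox jumped."])
```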
Disclaimer: Pre-trained models are provided on an "as is" basis, without warranties or conditions of any kind.
Arguments

- `backbone`: A `keras_nlp.models.BertBackbone` instance.
- `num_classes`: int. The number of classes to predict.
- `preprocessor`: A `keras_nlp.models.BertPreprocessor` or `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model.
- `activation`: Optional `str` or callable, defaults to `None`. The activation function to use on the model outputs. Set `activation="softmax"` to return output probabilities.
- `dropout`: float, defaults to `0.1`. The dropout probability applied before the final classification layer.

Examples
Raw string data.
features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]
# Pretrained classifier.
classifier = keras_nlp.models.BertClassifier.from_preset(
"bert_base_en_uncased",
num_classes=4,
)
classifier.fit(x=features, y=labels, batch_size=2)
classifier.predict(x=features, batch_size=2)
# Re-compile (e.g., with a new learning rate).
classifier.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
jit_compile=True,
)
# Access backbone programatically (e.g., to change `trainable`).
classifier.backbone.trainable = False
# Fit again.
classifier.fit(x=features, y=labels, batch_size=2)
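Because `activation` defaults to `None`, `predict()` above returns logits. A minimal sketch of requesting probabilities instead (same preset as above, shown for illustration):

```python
import keras_nlp

classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=4,
    activation="softmax",
)
# Rows of `probs` now sum to 1 across the 4 classes. If re-compiling,
# pair a softmax output with a loss built with `from_logits=False`.
probs = classifier.predict(x=["The quick brown fox jumped."], batch_size=1)
```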
Preprocessed integer data.
```python
import tensorflow as tf
import keras_nlp

features = {
    "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
    "segment_ids": tf.constant(
        [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
    ),
    "padding_mask": tf.constant(
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
    ),
}
labels = [0, 3]

# Pretrained classifier without preprocessing.
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=4,
    preprocessor=None,
)
classifier.fit(x=features, y=labels, batch_size=2)
```
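Feature dicts like the one above are usually produced by running a preprocessor ahead of training rather than written by hand. A minimal sketch using the standalone `BertPreprocessor` over a `tf.data` pipeline (preset and sequence length are illustrative):

```python
import tensorflow as tf
import keras_nlp

preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
    "bert_base_en_uncased",
    sequence_length=12,
)
features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]

# The preprocessor maps (x, y) pairs, tokenizing x and passing y through.
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE).batch(2)
```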
Custom backbone and vocabulary.
features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]
vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
vocab += ["The", "quick", "brown", "fox", "jumped", "."]
tokenizer = keras_nlp.models.BertTokenizer(
vocabulary=vocab,
)
preprocessor = keras_nlp.models.BertPreprocessor(
tokenizer=tokenizer,
sequence_length=128,
)
backbone = keras_nlp.models.BertBackbone(
vocabulary_size=30552,
num_layers=4,
num_heads=4,
hidden_dim=256,
intermediate_dim=512,
max_sequence_length=128,
)
classifier = keras_nlp.models.BertClassifier(
backbone=backbone,
preprocessor=preprocessor,
num_classes=4,
)
classifier.fit(x=features, y=labels, batch_size=2)
from_preset method

```python
BertClassifier.from_preset()
```

Instantiate BertClassifier model from preset architecture and weights.

Arguments

- `preset`: string. Must be one of the preset names listed in the table below.
- `load_weights`: Whether to load pre-trained weights into the model. Defaults to `True`.

Examples
```python
# Load architecture and weights from preset
model = BertClassifier.from_preset("bert_tiny_en_uncased")

# Load randomly initialized model from preset architecture
model = BertClassifier.from_preset(
    "bert_tiny_en_uncased",
    load_weights=False
)
```
Preset name | Parameters | Description |
---|---|---|
bert_tiny_en_uncased | 4.39M | 2-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
bert_small_en_uncased | 28.76M | 4-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
bert_medium_en_uncased | 41.37M | 8-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
bert_base_en_uncased | 109.48M | 12-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
bert_base_en | 108.31M | 12-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus. |
bert_base_zh | 102.27M | 12-layer BERT model. Trained on Chinese Wikipedia. |
bert_base_multi | 177.85M | 12-layer BERT model where case is maintained. Trained on Wikipedias of 104 languages. |
bert_large_en_uncased | 335.14M | 24-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus. |
bert_large_en | 333.58M | 24-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus. |
bert_tiny_en_uncased_sst2 | 4.39M | The bert_tiny_en_uncased backbone model fine-tuned on the SST-2 sentiment analysis dataset. |
backbone property

```python
keras_nlp.models.BertClassifier.backbone
```

A `keras.Model` instance providing the backbone submodel.
preprocessor property

```python
keras_nlp.models.BertClassifier.preprocessor
```

A `keras.layers.Layer` instance used to preprocess inputs.
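A minimal sketch of using the two properties together, e.g. to run the backbone directly on preprocessed inputs (the preset name is illustrative):

```python
import keras_nlp

classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased",
    num_classes=2,
)
# Tokenize and pack raw strings with the attached preprocessor...
x = classifier.preprocessor(["The quick brown fox jumped."])
# ...then call the backbone alone; it returns a dict with
# "sequence_output" and "pooled_output" tensors.
outputs = classifier.backbone(x)
```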