DebertaV3Classifier model

[source]

DebertaV3Classifier class

keras_nlp.models.DebertaV3Classifier(
    backbone,
    num_classes,
    preprocessor=None,
    activation=None,
    hidden_dim=None,
    dropout=0.0,
    **kwargs
)

An end-to-end DeBERTa model for classification tasks.

This model attaches a classification head to a keras_nlp.models.DebertaV3Backbone model, mapping from the backbone outputs to logits suitable for a classification task. For usage of this model with pre-trained weights, see the from_preset() method.

This model can optionally be configured with a preprocessor layer, in which case it will automatically apply preprocessing to raw inputs during fit(), predict(), and evaluate(). This is done by default when creating the model with from_preset().

Note: DebertaV3Backbone has a performance issue on TPUs, and we recommend other models for TPU training and inference.

Disclaimer: Pre-trained models are provided on an "as is" basis, without warranties or conditions of any kind. The underlying model is provided by a third party and subject to a separate license, available here.

Arguments

  • backbone: A keras_nlp.models.DebertaV3Backbone instance.
  • num_classes: int. Number of classes to predict.
  • preprocessor: A keras_nlp.models.DebertaV3Preprocessor or None. If None, this model will not apply preprocessing, and inputs should be preprocessed before calling the model.
  • activation: Optional str or callable. The activation function to use on the model outputs. Set activation="softmax" to return output probabilities (see the sketch after this list). Defaults to None.
  • hidden_dim: int. The size of the pooler layer.
  • dropout: float. Dropout probability applied to the pooled output. The backbone's own dropout rate (backbone.dropout) is used for the second dropout layer.
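
As a quick illustration of how activation interacts with the compiled loss (a sketch, not part of the upstream examples; it assumes keras and keras_nlp are imported): with activation="softmax" the model outputs probabilities rather than logits, so the loss should be built with from_logits=False.

classifier = keras_nlp.models.DebertaV3Classifier.from_preset(
    "deberta_v3_base_en",
    num_classes=4,
    activation="softmax",  # output probabilities instead of logits
    dropout=0.1,
)
classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer=keras.optimizers.Adam(5e-5),
)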

Examples

Raw string data.

import keras
import keras_nlp

features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]

# Pretrained classifier.
classifier = keras_nlp.models.DebertaV3Classifier.from_preset(
    "deberta_v3_base_en",
    num_classes=4,
)
classifier.fit(x=features, y=labels, batch_size=2)
classifier.predict(x=features, batch_size=2)

# Re-compile (e.g., with a new learning rate).
classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    jit_compile=True,
)
# Access backbone programmatically (e.g., to change `trainable`).
classifier.backbone.trainable = False
# Fit again.
classifier.fit(x=features, y=labels, batch_size=2)

Preprocessed integer data.

import numpy as np

features = {
    "token_ids": np.ones(shape=(2, 12), dtype="int32"),
    "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2),
}
labels = [0, 3]

# Pretrained classifier without preprocessing.
classifier = keras_nlp.models.DebertaV3Classifier.from_preset(
    "deberta_v3_base_en",
    num_classes=4,
    preprocessor=None,
)
classifier.fit(x=features, y=labels, batch_size=2)
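
One way to produce inputs in this format (a sketch using the standalone keras_nlp.models.DebertaV3Preprocessor, not part of the original example) is to run the matching preprocessor over raw strings yourself:

preprocessor = keras_nlp.models.DebertaV3Preprocessor.from_preset(
    "deberta_v3_base_en"
)
# Returns a dict with "token_ids" and "padding_mask" keys.
features = preprocessor(["The quick brown fox jumped.", "I forgot my homework."])
classifier.fit(x=features, y=labels, batch_size=2)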

Custom backbone and vocabulary.

import io

import sentencepiece
import tensorflow as tf

features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]

bytes_io = io.BytesIO()
ds = tf.data.Dataset.from_tensor_slices(features)
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=ds.as_numpy_iterator(),
    model_writer=bytes_io,
    vocab_size=10,
    model_type="WORD",
    pad_id=0,
    bos_id=1,
    eos_id=2,
    unk_id=3,
    pad_piece="[PAD]",
    bos_piece="[CLS]",
    eos_piece="[SEP]",
    unk_piece="[UNK]",
)
tokenizer = keras_nlp.models.DebertaV3Tokenizer(
    proto=bytes_io.getvalue(),
)
preprocessor = keras_nlp.models.DebertaV3Preprocessor(
    tokenizer=tokenizer,
    sequence_length=128,
)
backbone = keras_nlp.models.DebertaV3Backbone(
    vocabulary_size=30552,
    num_layers=4,
    num_heads=4,
    hidden_dim=256,
    intermediate_dim=512,
    max_sequence_length=128,
)
classifier = keras_nlp.models.DebertaV3Classifier(
    backbone=backbone,
    preprocessor=preprocessor,
    num_classes=4,
)
classifier.fit(x=features, y=labels, batch_size=2)

[source]

from_preset method

DebertaV3Classifier.from_preset()

Instantiate DebertaV3Classifier model from preset architecture and weights.

Arguments

  • preset: string. Must be one of "deberta_v3_extra_small_en", "deberta_v3_small_en", "deberta_v3_base_en", "deberta_v3_large_en", "deberta_v3_base_multi".
  • load_weights: Whether to load pre-trained weights into the model. Defaults to True.

Examples

# Load architecture and weights from preset
model = DebertaV3Classifier.from_preset("deberta_v3_extra_small_en")

# Load randomly initialized model from preset architecture
model = DebertaV3Classifier.from_preset(
    "deberta_v3_extra_small_en",
    load_weights=False
)

Preset name                 Parameters   Description
deberta_v3_extra_small_en   70.68M       12-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText.
deberta_v3_small_en         141.30M      6-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText.
deberta_v3_base_en          183.83M      12-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText.
deberta_v3_large_en         434.01M      24-layer DeBERTaV3 model where case is maintained. Trained on English Wikipedia, BookCorpus and OpenWebText.
deberta_v3_base_multi       278.22M      12-layer DeBERTaV3 model where case is maintained. Trained on the 2.5TB multilingual CC100 dataset.

backbone property

keras_nlp.models.DebertaV3Classifier.backbone

A keras.Model instance providing the backbone sub-model.


preprocessor property

keras_nlp.models.DebertaV3Classifier.preprocessor

A keras.layers.Layer instance used to preprocess inputs.
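
A minimal sketch (assuming a classifier built as in the examples above) of using the two properties together to extract per-token features from raw text:

classifier = keras_nlp.models.DebertaV3Classifier.from_preset(
    "deberta_v3_base_en",
    num_classes=4,
)
# Tokenize and pack raw strings with the attached preprocessor.
inputs = classifier.preprocessor(["The quick brown fox jumped."])
# Run only the backbone to get hidden states with shape
# (batch_size, sequence_length, hidden_dim).
hidden_states = classifier.backbone(inputs)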