
BertMaskedLM model


BertMaskedLM class

keras_nlp.models.BertMaskedLM(backbone, preprocessor=None, **kwargs)

An end-to-end BERT model for the masked language modeling task.

This model trains BERT on a masked language modeling task: given an input sequence in which some tokens have been masked, it predicts the original token at each masked position. For usage of this model with pre-trained weights, see the from_preset() constructor.
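Conceptually, predictions are made only at the masked positions. The encoder's output vectors at mask_positions are gathered and projected to one score (logit) per vocabulary entry, so the output has shape (batch_size, num_masks, vocabulary_size). The NumPy sketch below illustrates only that gather-and-project step, using random stand-in values and hypothetical toy sizes. The actual KerasNLP head contains additional layers and learned weights.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_len, hidden_dim, vocab_size = 2, 8, 4, 10  # hypothetical toy sizes

# Stand-in for the encoder output: one hidden vector per input token.
hidden = rng.normal(size=(batch_size, seq_len, hidden_dim))
# Indices of the masked tokens in each example (same layout as "mask_positions").
mask_positions = np.array([[2, 4], [1, 5]])

# Gather the hidden vectors at the masked positions: (batch_size, num_masks, hidden_dim).
masked_hidden = np.take_along_axis(hidden, mask_positions[..., None], axis=1)

# Project each gathered vector to vocabulary logits: (batch_size, num_masks, vocab_size).
projection = rng.normal(size=(hidden_dim, vocab_size))  # stand-in for learned weights
logits = masked_hidden @ projection

# The predicted token ID at each masked position is the highest-scoring entry.
predicted_ids = logits.argmax(axis=-1)  # shape (batch_size, num_masks)
```

This is why the labels in the preprocessed example below have one entry per masked position rather than one per input token.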

This model can optionally be configured with a preprocessor layer, in which case inputs can be raw string features during fit(), predict(), and evaluate(). Inputs will then be tokenized and dynamically masked during training and evaluation. Creating the model with from_preset() attaches such a preprocessor by default.
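Dynamic masking means the masked positions are re-sampled every time a batch is prepared, so the same sentence is masked differently across epochs. The NumPy function below is a simplified sketch of that idea, not the KerasNLP preprocessor's implementation. It assumes already-tokenized integer IDs, a hypothetical mask_token_id, and a default selection rate of 15%, and it always substitutes the mask token (the original BERT recipe replaces only some of the selected tokens with [MASK] and leaves or randomizes the rest).

```python
import numpy as np

def mask_tokens(token_ids, mask_token_id, special_ids=(), rate=0.15, rng=None):
    """Mask a random subset of non-special tokens in one sequence (simplified sketch)."""
    rng = np.random.default_rng() if rng is None else rng
    token_ids = np.asarray(token_ids)
    # Only ordinary tokens are eligible; special start/end tokens are never masked.
    candidates = np.flatnonzero(~np.isin(token_ids, special_ids))
    num_masks = min(len(candidates), max(1, round(rate * len(candidates))))
    mask_positions = np.sort(rng.choice(candidates, size=num_masks, replace=False))
    labels = token_ids[mask_positions]  # original IDs the model must recover
    masked_ids = token_ids.copy()
    masked_ids[mask_positions] = mask_token_id
    return masked_ids, mask_positions, labels

# Hypothetical IDs: 1 and 8 act as special start/end tokens, 0 is the mask token.
sequence = [1, 2, 3, 4, 5, 6, 7, 8]
rng = np.random.default_rng(42)
for epoch in range(2):
    # A fresh mask is drawn on every call, which is what makes the masking dynamic.
    print(mask_tokens(sequence, mask_token_id=0, special_ids=(1, 8), rate=0.3, rng=rng))
```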

Disclaimer: Pre-trained models are provided on an "as is" basis, without warranties or conditions of any kind.

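Arguments

  • backbone: A keras_nlp.models.BertBackbone instance.
  • preprocessor: A keras_nlp.models.BertMaskedLMPreprocessor or None. If None, this model will not apply preprocessing, and inputs should be preprocessed before calling the model.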

Example usage:

Raw string data.

import keras
import keras_nlp

features = ["The quick brown fox jumped.", "I forgot my homework."]

# Pretrained language model.
masked_lm = keras_nlp.models.BertMaskedLM.from_preset(
    "bert_base_en_uncased",
)
masked_lm.fit(x=features, batch_size=2)

# Re-compile (e.g., with a new learning rate).
masked_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    jit_compile=True,
)
# Access backbone programmatically (e.g., to change `trainable`).
masked_lm.backbone.trainable = False
# Fit again.
masked_lm.fit(x=features, batch_size=2)
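The loss passed to compile() uses from_logits=True, meaning the model's outputs are treated as unnormalized scores and the loss applies the softmax itself. As an unweighted NumPy sketch of what that computation amounts to (not Keras's implementation), sparse categorical cross-entropy takes a log-softmax over the vocabulary axis and averages the negative log-probability of each true token ID:

```python
import numpy as np

def sparse_categorical_crossentropy_from_logits(labels, logits):
    """Mean negative log-likelihood of integer labels under softmax(logits)."""
    labels = np.asarray(labels)
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable log-softmax over the last (vocabulary) axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Select the log-probability assigned to each true token ID.
    true_log_probs = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return -true_log_probs.mean()

# Labels shaped like the masked LM targets: (batch_size, num_masks).
labels = np.array([[3, 5], [3, 5]])
# All-zero logits give a uniform distribution over a 10-token toy vocabulary,
# so the loss is approximately log(10), about 2.3026.
uniform_logits = np.zeros((2, 2, 10))
print(sparse_categorical_crossentropy_from_logits(labels, uniform_logits))
```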

Preprocessed integer data.

import numpy as np
import keras_nlp

# Create a preprocessed batch where 0 is the mask token.
features = {
    "token_ids": np.array([[1, 2, 0, 4, 0, 6, 7, 8]] * 2),
    "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1]] * 2),
    "mask_positions": np.array([[2, 4]] * 2),
    "segment_ids": np.array([[0, 0, 0, 0, 0, 0, 0, 0]] * 2)
}
# Labels are the original masked values.
labels = np.array([[3, 5]] * 2)

masked_lm = keras_nlp.models.BertMaskedLM.from_preset(
    "bert_base_en_uncased",
    preprocessor=None,
)
masked_lm.fit(x=features, y=labels, batch_size=2)
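The integer features above follow a simple rule: token_ids is the unmasked sequence with the mask token written into each entry of mask_positions, and the labels are the original IDs at those positions. A short NumPy sketch that rebuilds the same batch from the unmasked sequence implied by the labels:

```python
import numpy as np

# Unmasked sequence implied by the example: positions 2 and 4 originally held 3 and 5.
original = np.array([1, 2, 3, 4, 5, 6, 7, 8])
mask_positions = np.array([2, 4])
mask_token_id = 0  # the example uses 0 as the mask token

token_ids = original.copy()
token_ids[mask_positions] = mask_token_id  # [1, 2, 0, 4, 0, 6, 7, 8]
labels = original[mask_positions]          # [3, 5]

# Repeat each row to form a batch of two, matching the example's "* 2".
batch_size = 2
features = {
    "token_ids": np.tile(token_ids, (batch_size, 1)),
    "padding_mask": np.ones((batch_size, len(original)), dtype=int),  # no padding
    "mask_positions": np.tile(mask_positions, (batch_size, 1)),
    "segment_ids": np.zeros((batch_size, len(original)), dtype=int),  # single segment
}
labels = np.tile(labels, (batch_size, 1))
```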


from_preset method

BertMaskedLM.from_preset(preset, load_weights=True)

Instantiate a BertMaskedLM model from a preset architecture and weights.

Arguments

  • preset: string. Must be one of "bert_tiny_en_uncased", "bert_small_en_uncased", "bert_medium_en_uncased", "bert_base_en_uncased", "bert_base_en", "bert_base_zh", "bert_base_multi", "bert_large_en_uncased", "bert_large_en".
  • load_weights: bool. Whether to load pre-trained weights into the model. Defaults to True.

Examples

from keras_nlp.models import BertMaskedLM

# Load architecture and weights from preset
model = BertMaskedLM.from_preset("bert_tiny_en_uncased")

# Load randomly initialized model from preset architecture
model = BertMaskedLM.from_preset(
    "bert_tiny_en_uncased",
    load_weights=False
)

Preset name              Parameters  Description
bert_tiny_en_uncased     4.39M       2-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus.
bert_small_en_uncased    28.76M      4-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus.
bert_medium_en_uncased   41.37M      8-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus.
bert_base_en_uncased     109.48M     12-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus.
bert_base_en             108.31M     12-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus.
bert_base_zh             102.27M     12-layer BERT model. Trained on Chinese Wikipedia.
bert_base_multi          177.85M     12-layer BERT model where case is maintained. Trained on Wikipedias of 104 languages.
bert_large_en_uncased    335.14M     24-layer BERT model where all input is lowercased. Trained on English Wikipedia + BooksCorpus.
bert_large_en            333.58M     24-layer BERT model where case is maintained. Trained on English Wikipedia + BooksCorpus.
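The Parameters column can be approximately reconstructed from the standard BERT encoder layout: token, position, and segment embeddings with a layer norm; per layer, four attention projections, a two-layer feed-forward network, and two layer norms; and a final pooler dense layer. The sketch below assumes 512 positions, 2 segment types, vocabulary sizes of 30,522 (uncased English) and 28,996 (cased English), and the hidden and feed-forward widths shown in configs. These values are taken from the original BERT releases rather than from this page. Under these assumptions it reproduces the listed figures for the three presets shown, which suggests the table counts the backbone without the masked LM head.

```python
def bert_backbone_params(vocab_size, hidden_dim, num_layers, intermediate_dim,
                         max_sequence_length=512, num_segments=2):
    """Count the weights of a standard BERT encoder, excluding any task head."""

    def dense(n_in, n_out):
        return n_in * n_out + n_out  # kernel + bias

    layer_norm = 2 * hidden_dim  # scale + offset

    # Token, position, and segment embeddings, then one layer norm.
    embeddings = (vocab_size + max_sequence_length + num_segments) * hidden_dim + layer_norm
    # Per transformer layer: query, key, value, and output projections,
    # a two-layer feed-forward network, and two layer norms.
    attention = 4 * dense(hidden_dim, hidden_dim)
    feed_forward = dense(hidden_dim, intermediate_dim) + dense(intermediate_dim, hidden_dim)
    per_layer = attention + feed_forward + 2 * layer_norm
    # Pooler: one dense layer applied to the first ([CLS]) token.
    pooler = dense(hidden_dim, hidden_dim)
    return embeddings + num_layers * per_layer + pooler


# (preset, vocab_size, hidden_dim, num_layers, intermediate_dim); values are assumptions.
configs = [
    ("bert_tiny_en_uncased", 30522, 128, 2, 512),
    ("bert_base_en_uncased", 30522, 768, 12, 3072),
    ("bert_base_en", 28996, 768, 12, 3072),
]
for name, vocab, hidden, layers, intermediate in configs:
    print(f"{name}: {bert_backbone_params(vocab, hidden, layers, intermediate) / 1e6:.2f}M")
# bert_tiny_en_uncased: 4.39M
# bert_base_en_uncased: 109.48M
# bert_base_en: 108.31M
```

Note that the number of attention heads does not appear: splitting the hidden dimension into heads changes how the projections are used, not how many weights they hold.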

backbone property

keras_nlp.models.BertMaskedLM.backbone

A keras.Model instance providing the backbone sub-model.


preprocessor property

keras_nlp.models.BertMaskedLM.preprocessor

A keras.layers.Layer instance used to preprocess inputs.