
MaskedLMHead layer


MaskedLMHead class


Masked Language Model (MaskedLM) head.

This layer takes two inputs:

  • inputs: a tensor of encoded tokens with shape (batch_size, sequence_length, hidden_dim).
  • mask_positions: a tensor of integer positions to predict, with shape (batch_size, masks_per_sequence).

The token encodings are usually the final output of an encoder model, and mask_positions holds the integer positions of the tokens you want to predict for the MaskedLM task.

The layer first gathers the token encodings at the mask positions. The gathered encodings are passed through a dense layer the same size as the encoding dimension, then projected to predictions the same size as the vocabulary. The layer produces a single output with shape (batch_size, masks_per_sequence, vocabulary_size), which can be used to compute a MaskedLM loss.
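
To make the data flow above concrete, here is a minimal NumPy sketch of the forward pass. It is not the KerasNLP implementation: the helper masked_lm_head_forward and its weight arguments are hypothetical, the intermediate activation is assumed to be ReLU, a layer normalization step is assumed (as the layer_norm_epsilon argument suggests), and the output projection uses a separate kernel rather than a tied embedding.

```python
import numpy as np


def masked_lm_head_forward(inputs, mask_positions, w_dense, b_dense,
                           gamma, beta, w_out, b_out,
                           epsilon=1e-5, softmax=False):
    """Hypothetical NumPy sketch of a MaskedLM head forward pass.

    inputs:         (batch_size, sequence_length, hidden_dim)
    mask_positions: (batch_size, masks_per_sequence), integer positions
    w_out:          (hidden_dim, vocabulary_size), untied output kernel
    """
    # 1. Gather the encodings at the masked positions.
    #    Result: (batch_size, masks_per_sequence, hidden_dim)
    gathered = np.take_along_axis(inputs, mask_positions[..., None], axis=1)
    # 2. Intermediate dense layer of width hidden_dim (ReLU assumed).
    x = np.maximum(gathered @ w_dense + b_dense, 0.0)
    # 3. Layer normalization over the feature axis.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
    # 4. Project to the vocabulary: (batch_size, masks_per_sequence, vocab).
    logits = x @ w_out + b_out
    if not softmax:
        return logits
    # Optional softmax over the vocabulary axis (max-shifted for stability).
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Usage with small random weights.
rng = np.random.default_rng(0)
batch_size, seq_length, hidden_dim, vocab_size, n_masks = 2, 10, 8, 20, 3
inputs = rng.normal(size=(batch_size, seq_length, hidden_dim))
mask_positions = rng.integers(seq_length, size=(batch_size, n_masks))
params = dict(
    w_dense=rng.normal(size=(hidden_dim, hidden_dim)),
    b_dense=np.zeros(hidden_dim),
    gamma=np.ones(hidden_dim),
    beta=np.zeros(hidden_dim),
    w_out=rng.normal(size=(hidden_dim, vocab_size)),
    b_out=np.zeros(vocab_size),
)
probs = masked_lm_head_forward(inputs, mask_positions, **params, softmax=True)
print(probs.shape)  # (2, 3, 20)
```

The gather happens before any dense computation, so the cost of the head scales with masks_per_sequence rather than sequence_length.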

This layer is often paired with keras_nlp.layers.MaskedLMMaskGenerator, which helps prepare inputs for the MaskedLM task.
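
The sketch below shows, in plain NumPy, the kind of inputs such a preparation step produces and how the head's output is scored against them. It is a simplified stand-in, not MaskedLMMaskGenerator itself: the helpers make_mlm_inputs and masked_lm_loss are hypothetical, this sketch always masks a fixed number of positions and always substitutes the mask token, whereas the library generator is more configurable.

```python
import numpy as np


def make_mlm_inputs(token_ids, masks_per_sequence, mask_token_id, rng):
    """Hypothetical helper: choose positions to mask and record labels."""
    batch_size, seq_length = token_ids.shape
    # Distinct positions per sequence, sorted for readability.
    mask_positions = np.stack([
        np.sort(rng.choice(seq_length, size=masks_per_sequence, replace=False))
        for _ in range(batch_size)
    ])
    # Labels are the original tokens at the masked positions.
    mask_ids = np.take_along_axis(token_ids, mask_positions, axis=1)
    # Overwrite the chosen tokens with the mask token.
    masked = token_ids.copy()
    np.put_along_axis(masked, mask_positions, mask_token_id, axis=1)
    return masked, mask_positions, mask_ids


def masked_lm_loss(logits, mask_ids):
    """Mean sparse categorical cross-entropy over the masked positions.

    logits:   (batch_size, masks_per_sequence, vocabulary_size)
    mask_ids: (batch_size, masks_per_sequence)
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, mask_ids[..., None], axis=-1)
    return -picked.mean()


rng = np.random.default_rng(0)
vocab_size, mask_token_id = 100, 1
# Ids start at 2 so no ordinary token collides with the mask token.
token_ids = rng.integers(2, vocab_size, size=(4, 12))
masked, mask_positions, mask_ids = make_mlm_inputs(
    token_ids, 3, mask_token_id, rng)
# Random logits stand in for the output of a MaskedLM head.
logits = rng.normal(size=(4, 3, vocab_size))
print(masked_lm_loss(logits, mask_ids))
```

The masked token ids would feed the encoder, mask_positions would feed the head, and mask_ids serve as the targets for the loss.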

Arguments

  • vocabulary_size: The total size of the vocabulary for predictions.
  • token_embedding: Optional. A keras_nlp.layers.ReversibleEmbedding instance. If passed, this embedding layer will be used in reverse to project from the model's hidden_dim to the output vocabulary_size, tying the input and output embedding weights.
  • intermediate_activation: The activation function of the intermediate dense layer.
  • activation: The activation function for the outputs of the layer. Usually either None (return logits) or "softmax" (return probabilities).
  • layer_norm_epsilon: float. The epsilon value in layer normalization components. Defaults to 1e-5.
  • kernel_initializer: string or keras.initializers initializer. The kernel initializer for the dense layers. Defaults to "glorot_uniform".
  • bias_initializer: string or keras.initializers initializer. The bias initializer for the dense layers. Defaults to "zeros".
  • **kwargs: other keyword arguments passed to keras.layers.Layer, including name, trainable, dtype etc.
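
Passing token_embedding ties the output projection to the input embedding. The NumPy sketch below illustrates the idea only; embed and project_tied are hypothetical names, not ReversibleEmbedding's implementation, and it assumes that without token_embedding the head would allocate its own hidden_dim by vocabulary_size output kernel.

```python
import numpy as np

vocab_size, hidden_dim = 100, 32
rng = np.random.default_rng(0)

# Shared token embedding matrix: one row per vocabulary entry.
embeddings = rng.normal(size=(vocab_size, hidden_dim))


def embed(token_ids):
    # Forward direction: token ids -> hidden vectors.
    return embeddings[token_ids]


def project_tied(hidden):
    # Reverse direction: hidden vectors -> vocabulary scores, reusing the
    # same matrix (transposed) instead of a separate output kernel.
    return hidden @ embeddings.T


hidden = embed(np.array([[3, 7, 42]]))   # (1, 3, hidden_dim)
logits = project_tied(hidden)            # (1, 3, vocab_size)

# Weights an untied output projection would add (kernel only, no bias).
extra_params_if_untied = hidden_dim * vocab_size
print(logits.shape, extra_params_if_untied)  # (1, 3, 100) 3200
```

Tying saves vocabulary_size * hidden_dim parameters, which matters most when the vocabulary is large relative to hidden_dim.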

Example

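import numpy as np
import keras_nlp

batch_size = 16
vocab_size = 100
hidden_dim = 32
seq_length = 50

# Generate random inputs.
token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
# Choose random positions as the masked inputs.
mask_positions = np.random.randint(seq_length, size=(batch_size, 5))

# Embed tokens in a `hidden_dim` feature space.
token_embedding = keras_nlp.layers.ReversibleEmbedding(
    vocab_size,
    hidden_dim,
)
hidden_states = token_embedding(token_ids)

# Predict a distribution over the vocabulary at each masked position,
# reusing the token embedding for the output projection.
preds = keras_nlp.layers.MaskedLMHead(
    vocabulary_size=vocab_size,
    token_embedding=token_embedding,
    activation="softmax",
)(hidden_states, mask_positions)
# preds has shape (batch_size, 5, vocab_size).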