ReversibleEmbedding class


An embedding layer which can project backwards to the input dim.

This layer extends keras.layers.Embedding for language models. It can be called "in reverse" with reverse=True, in which case it linearly projects from output_dim back to input_dim.

By default, the reverse projection uses the transpose of the embedding weights to project to input_dim (the weights are "tied"). If tie_weights=False, the layer uses a separate, trainable variable for the reverse projection.
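
To make the two call modes concrete, the following is a minimal pure-Python sketch of what the forward lookup and the tied reverse projection compute. It does not use Keras; the helper names embed and project_back and the tiny 3 x 2 matrix are hypothetical and chosen only for illustration, and the real layer operates on batched tensors.

```python
# Hypothetical embeddings matrix with input_dim=3 (vocabulary size)
# and output_dim=2 (embedding dimension).
embeddings = [
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
]


def embed(token_ids):
    """Forward call: look up one row of the matrix per token id."""
    return [embeddings[i] for i in token_ids]


def project_back(hidden_states):
    """Reverse call with tied weights: multiply by the transposed matrix.

    Each hidden vector is dotted with every row of the embeddings matrix,
    producing input_dim values per position. No bias is added.
    """
    return [
        [sum(h * w for h, w in zip(hidden, row)) for row in embeddings]
        for hidden in hidden_states
    ]


hidden = embed([2, 0])         # 2 positions, output_dim values each
logits = project_back(hidden)  # 2 positions, input_dim values each
print(logits)
```

Because the same matrix serves both directions, a token's logit is the dot product of the hidden state with that token's own embedding row.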

This layer has no bias terms.

Arguments


  • input_dim: Integer. Size of the vocabulary, i.e. maximum integer index + 1.
  • output_dim: Integer. Dimension of the dense embedding.
  • tie_weights: Boolean. Whether the embedding matrix and the reverse projection matrix share the same weights. Defaults to True.
  • embeddings_initializer: Initializer for the embeddings matrix (see keras.initializers).
  • embeddings_regularizer: Regularizer function applied to the embeddings matrix (see keras.regularizers).
  • embeddings_constraint: Constraint function applied to the embeddings matrix (see keras.constraints).
  • mask_zero: Boolean. Whether the input value 0 is a special "padding" value that should be masked out.
  • reverse_dtype: The dtype for the reverse projection computation. For stability, it is usually best to use full precision even when working with half or mixed precision training.
  • **kwargs: Other keyword arguments passed to keras.layers.Embedding, such as name, trainable, and dtype.
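
As a rough illustration of what tie_weights trades off, the sketch below counts trainable weights in each mode. The helper count_weights is hypothetical and not part of the library. It assumes the separate reverse variable maps output_dim to input_dim and therefore has as many entries as the embeddings matrix, and that there are no bias terms, as stated above.

```python
def count_weights(input_dim, output_dim, tie_weights=True):
    """Count trainable weights for a bias-free reversible embedding (sketch)."""
    embedding_weights = input_dim * output_dim
    if tie_weights:
        # The reverse projection reuses the transposed embeddings matrix,
        # so no extra weights are needed.
        return embedding_weights
    # Assumption: the separate reverse variable maps output_dim -> input_dim,
    # so it has the same number of entries as the embeddings matrix.
    return embedding_weights + output_dim * input_dim


tied = count_weights(100, 32)
untied = count_weights(100, 32, tie_weights=False)
print(tied, untied)
```

Under these assumptions, untying the weights doubles the weight count of the layer, which can be significant for large vocabularies.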

Call arguments

  • inputs: The tensor inputs to the layer.
  • reverse: Boolean. If True, the layer performs a linear projection from output_dim to input_dim instead of a normal embedding call. Defaults to False.

Example


import numpy as np
import keras_nlp

batch_size = 16
vocab_size = 100
hidden_dim = 32
seq_length = 50

# Generate random inputs.
token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))

embedding = keras_nlp.layers.ReversibleEmbedding(vocab_size, hidden_dim)
# Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
hidden_states = embedding(token_ids)
# Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
logits = embedding(hidden_states, reverse=True)