TransformerDecoder class

keras_nlp.layers.TransformerDecoder(
    intermediate_dim,
    num_heads,
    dropout=0,
    activation="relu",
    layer_norm_epsilon=1e-05,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    normalize_first=False,
    **kwargs
)
Transformer decoder.
This class follows the architecture of the transformer decoder layer in the paper Attention is All You Need. Users can instantiate multiple instances of this class to stack up a decoder.
By default, this layer will apply a causal mask to the decoder attention
layer. You can also pass padding or attention masks directly to the layer
during call, e.g. with decoder_padding_mask or decoder_attention_mask.
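For example, a padding mask can be supplied at call time. This is a minimal sketch with arbitrary shapes; the mask values are illustrative only:

import keras_nlp
import numpy as np

# Mask out the last three (padded) positions of the first sequence.
decoder = keras_nlp.layers.TransformerDecoder(intermediate_dim=64, num_heads=8)
decoder_input_data = np.random.uniform(size=(2, 10, 64))
padding_mask = np.array([[1] * 7 + [0] * 3, [1] * 10])  # [batch_size, decoder_sequence_length]
output = decoder(decoder_input_data, decoder_padding_mask=padding_mask)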
This layer can be called with either one or two inputs. The number of inputs
must be consistent across all calls. The options are as follows:

layer(decoder_sequence): no cross-attention will be built into the decoder
    block. This is useful when building a "decoder-only" transformer such as
    GPT-2 (see the sketch after this list).
layer(decoder_sequence, encoder_sequence): cross-attention will be built into
    the decoder block. This is useful when building an "encoder-decoder"
    transformer, such as the original transformer model described in
    Attention is All You Need.
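A minimal decoder-only sketch (arbitrary shapes, no encoder input) looks like this; the full encoder-decoder case is shown in the Example below:

import keras
import keras_nlp

# Calling the layer with a single input builds a block without cross-attention.
decoder = keras_nlp.layers.TransformerDecoder(intermediate_dim=64, num_heads=8)
decoder_input = keras.Input(shape=(10, 64))
output = decoder(decoder_input)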
Arguments

intermediate_dim: int, the hidden size of the feedforward network.
num_heads: int, the number of heads in multi-head attention.
dropout: float. The dropout value, shared by the attention and feedforward
    layers. Defaults to 0.
activation: string or keras.activations. The activation function of the
    feedforward network. Defaults to "relu".
layer_norm_epsilon: float. The epsilon value in layer normalization
    components. Defaults to 1e-5.
kernel_initializer: string or keras.initializers initializer. The kernel
    initializer for the dense and multiheaded attention layers. Defaults to
    "glorot_uniform".
bias_initializer: string or keras.initializers initializer. The bias
    initializer for the dense and multiheaded attention layers. Defaults to
    "zeros".
normalize_first: bool. If True, layer normalization is applied to the inputs
    of the attention and feedforward layers (pre-normalization); if False, it
    is applied to their outputs (post-normalization). Defaults to False.
**kwargs: other keyword arguments passed to keras.layers.Layer, including
    name, trainable, dtype etc.

Example
import keras
import keras_nlp
import numpy as np

# Create a single transformer decoder layer.
decoder = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=64, num_heads=8)

# Create a simple model containing the decoder.
decoder_input = keras.Input(shape=(10, 64))
encoder_input = keras.Input(shape=(10, 64))
output = decoder(decoder_input, encoder_input)
model = keras.Model(
    inputs=(decoder_input, encoder_input),
    outputs=output,
)

# Call decoder on the inputs.
decoder_input_data = np.random.uniform(size=(2, 10, 64))
encoder_input_data = np.random.uniform(size=(2, 10, 64))
decoder_output = model((decoder_input_data, encoder_input_data))
References

Vaswani et al., 2017. Attention Is All You Need.
call method

TransformerDecoder.call(
    decoder_sequence,
    encoder_sequence=None,
    decoder_padding_mask=None,
    decoder_attention_mask=None,
    encoder_padding_mask=None,
    encoder_attention_mask=None,
    self_attention_cache=None,
    self_attention_cache_update_index=None,
    cross_attention_cache=None,
    cross_attention_cache_update_index=None,
    use_causal_mask=True,
    training=None,
)
Forward pass of the TransformerDecoder.
Arguments

decoder_sequence: a Tensor. The decoder input sequence.
encoder_sequence: a Tensor. The encoder input sequence. For decoder-only
    models (like GPT-2), this should be left as None. Once the model is
    called once without an encoder_sequence, you cannot call it again with
    encoder_sequence.
decoder_padding_mask: a boolean Tensor. The padding mask of the decoder
    sequence, of shape [batch_size, decoder_sequence_length].
decoder_attention_mask: a boolean Tensor. A customized decoder sequence mask,
    of shape [batch_size, decoder_sequence_length, decoder_sequence_length].
encoder_padding_mask: a boolean Tensor. The padding mask of the encoder
    sequence, of shape [batch_size, encoder_sequence_length].
encoder_attention_mask: a boolean Tensor. A customized encoder sequence mask,
    of shape [batch_size, encoder_sequence_length, encoder_sequence_length].
self_attention_cache: a dense float Tensor. The cache of key/value pairs in
    the self-attention layer, of shape
    [batch_size, 2, max_seq_len, num_heads, key_dims].
self_attention_cache_update_index: an int or int Tensor. The index at which
    to update the self_attention_cache. Usually, this is the index of the
    current token being processed during decoding.
cross_attention_cache: a dense float Tensor. The cache of key/value pairs in
    the cross-attention layer, of shape
    [batch_size, 2, S, num_heads, key_dims].
cross_attention_cache_update_index: an int or int Tensor. The index at which
    to update the cross_attention_cache. Usually, this is either 0 (compute
    the entire cross_attention_cache), or None (reuse a previously computed
    cross_attention_cache).
use_causal_mask: bool, defaults to True. If true, a causal mask (masking out
    future input) is applied on the decoder sequence.
training: a boolean indicating whether the layer should behave in training
    mode or in inference mode.

Returns
One of three things, depending on call arguments:

outputs, if self_attention_cache is None.
(outputs, self_attention_cache), if self_attention_cache is set and the
    layer has no cross-attention.
(outputs, self_attention_cache, cross_attention_cache), if
    self_attention_cache and cross_attention_cache are set and the layer has
    cross-attention.
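As a rough sketch of how the cache arguments fit together (arbitrary sizes, decoder-only so only the self-attention cache is used, and assuming key_dims equals the hidden size divided by num_heads), token-by-token decoding might look like this:

import keras_nlp
import numpy as np

batch_size, max_seq_len, hidden_dim, num_heads = 2, 10, 64, 8
key_dims = hidden_dim // num_heads  # assumed head size for the cache shape

decoder = keras_nlp.layers.TransformerDecoder(intermediate_dim=128, num_heads=num_heads)

# Zero-initialized key/value cache: [batch_size, 2, max_seq_len, num_heads, key_dims].
cache = np.zeros((batch_size, 2, max_seq_len, num_heads, key_dims), dtype="float32")
tokens = np.random.uniform(size=(batch_size, max_seq_len, hidden_dim)).astype("float32")

for index in range(max_seq_len):
    # Feed one timestep at a time; the layer returns (outputs, updated cache)
    # because no cross-attention is built.
    step_input = tokens[:, index : index + 1, :]
    outputs, cache = decoder(
        step_input,
        self_attention_cache=cache,
        self_attention_cache_update_index=index,
    )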