MultiHeadAttention class

    keras.layers.MultiHeadAttention(
        num_heads,
        key_dim,
        value_dim=None,
        dropout=0.0,
        use_bias=True,
        output_shape=None,
        attention_axes=None,
        flash_attention=None,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        seed=None,
        **kwargs
    )
MultiHeadAttention layer.

This is an implementation of multi-headed attention as described in the paper "Attention Is All You Need" (Vaswani et al., 2017).
If query, key, and value are the same, then this is self-attention. Each timestep in query attends to the corresponding sequence in key, and returns a fixed-width vector.
This layer first projects query, key, and value. These are (effectively) a list of tensors of length num_attention_heads, where the corresponding shapes are (batch_size, <query dimensions>, key_dim), (batch_size, <key/value dimensions>, key_dim), and (batch_size, <key/value dimensions>, value_dim).
Then, the query and key tensors are dot-producted and scaled. These are softmaxed to obtain attention probabilities. The value tensors are then interpolated by these probabilities and concatenated back into a single tensor.

Finally, the result tensor, whose last dimension is value_dim, can be passed through a final linear projection and returned.
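A minimal usage sketch (the tensor sizes here are made up for illustration): a cross-attention call in which key defaults to value, and the output keeps the query feature dimension because output_shape is not set.

    import numpy as np
    import keras

    # Cross-attention: query and value have different sequence lengths.
    layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)

    query = np.random.normal(size=(4, 8, 32)).astype("float32")   # (B, T, dim)
    value = np.random.normal(size=(4, 10, 32)).astype("float32")  # (B, S, dim)

    # key is not given, so value is used for both key and value.
    output = layer(query, value)
    print(output.shape)  # (4, 8, 32): same as query, since output_shape is None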
Arguments

- num_heads: Number of attention heads.
- key_dim: Size of each attention head for query and key.
- value_dim: Size of each attention head for value.
- dropout: Dropout probability.
- use_bias: Boolean, whether the dense layers use bias vectors/matrices.
- output_shape: The expected shape of an output tensor, besides the batch and sequence dims. If not specified, projects back to the query feature dimension (the last dimension of the query input).
- attention_axes: Axes over which the attention is applied. None means attention over all axes, but batch, heads, and features. See the sketch after this list for attention over specific axes.
- flash_attention: If None, the layer attempts to use flash attention for faster and more memory-efficient attention computations when possible. This behavior can be configured using keras.config.enable_flash_attention() or keras.config.disable_flash_attention().
- kernel_initializer: Initializer for dense layer kernels.
- bias_initializer: Initializer for dense layer biases.
- kernel_regularizer: Regularizer for dense layer kernels.
- bias_regularizer: Regularizer for dense layer biases.
- activity_regularizer: Regularizer for dense layer activity.
- kernel_constraint: Constraint for dense layer kernels.
- bias_constraint: Constraint for dense layer biases.
- seed: Optional integer to seed the dropout layer.
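A sketch of attention_axes (the shapes are illustrative): performing 2D self-attention over the middle axes of a 5D input leaves the output shape unchanged.

    import keras

    # Self-attention over axes 2 and 3 of a 5D input tensor
    # (axis 0 is the batch, the last axis holds the features).
    layer = keras.layers.MultiHeadAttention(
        num_heads=2, key_dim=2, attention_axes=(2, 3)
    )
    input_tensor = keras.Input(shape=(5, 3, 4, 16))
    output_tensor = layer(input_tensor, input_tensor)
    print(output_tensor.shape)  # (None, 5, 3, 4, 16)

Flash attention can be toggled globally with keras.config.enable_flash_attention() or keras.config.disable_flash_attention(), as noted in the flash_attention argument description above.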
Call arguments

- query: Query tensor of shape (B, T, dim), where B is the batch size, T is the target sequence length, and dim is the feature dimension.
- value: Value tensor of shape (B, S, dim), where B is the batch size, S is the source sequence length, and dim is the feature dimension.
- key: Optional key tensor of shape (B, S, dim). If not given, will use value for both key and value, which is the most common case.
- attention_mask: A boolean mask of shape (B, T, S) that prevents attention to certain positions. The boolean mask specifies which query elements can attend to which key elements; 1 indicates attention and 0 indicates no attention. Broadcasting can happen for the missing batch dimensions and the head dimension. See the call example after this list.
- return_attention_scores: A boolean indicating whether the output should be (attention_output, attention_scores) if True, or attention_output if False. Defaults to False.
- training: Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout). Will go with either the training mode of the parent layer/model, or False (inference) if there is no parent layer.
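A sketch of a call that passes an attention mask and requests the attention scores (the tensor sizes are made up; the score shape (B, num_heads, T, S) is the per-head attention map over query and key positions).

    import numpy as np
    import keras

    layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)

    query = np.random.normal(size=(4, 8, 32)).astype("float32")   # (B, T, dim)
    value = np.random.normal(size=(4, 10, 32)).astype("float32")  # (B, S, dim)

    # Boolean mask of shape (B, T, S): 1/True allows attention, 0/False blocks it.
    attention_mask = np.ones((4, 8, 10), dtype="bool")

    output, scores = layer(
        query,
        value,
        attention_mask=attention_mask,
        return_attention_scores=True,
    )
    print(output.shape)  # (4, 8, 32)
    print(scores.shape)  # (4, 2, 8, 10): (B, num_heads, T, S)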
Returns

- attention_output: The result of the computation, of shape (B, T, E), where T is for target sequence shapes and E is the query input's last dimension if output_shape is None. Otherwise, the multi-head outputs are projected to the shape specified by output_shape (see the sketch after this list).
- attention_scores: (Optional) multi-head attention coefficients over attention axes, returned when return_attention_scores=True.
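A sketch of the output_shape projection (sizes are illustrative): with output_shape set, the final dense projection maps to that size instead of back to the query feature dimension.

    import numpy as np
    import keras

    layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=64)

    query = np.random.normal(size=(4, 8, 32)).astype("float32")  # (B, T, dim)

    # Self-attention: query is also used as value (and therefore as key).
    output = layer(query, query)
    print(output.shape)  # (4, 8, 64): last dimension follows output_shape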