MultiHeadAttention class

keras.layers.MultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
MultiHeadAttention layer.
This is an implementation of multi-headed attention as described in the
paper "Attention Is All You Need" (Vaswani et al., 2017).
If query, key, value are the same, then
this is self-attention. Each timestep in query attends to the
corresponding sequence in key, and returns a fixed-width vector.
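A minimal usage sketch of this distinction (all shapes and values below are illustrative assumptions, not part of the API description): passing the same tensor as query and value gives self-attention, while a separate source tensor gives cross-attention.

```python
import numpy as np
import keras

layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)

# Self-attention: query, key and value are the same sequence.
x = np.random.normal(size=(4, 10, 32)).astype("float32")  # (B, T, dim)
self_attn_out = layer(query=x, value=x)  # key defaults to value
print(self_attn_out.shape)  # (4, 10, 32)

# Cross-attention: query attends to a different source sequence.
source = np.random.normal(size=(4, 20, 32)).astype("float32")  # (B, S, dim)
cross_attn_out = layer(query=x, value=source)
print(cross_attn_out.shape)  # (4, 10, 32)
```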
This layer first projects query, key and value. These are
(effectively) a list of tensors of length num_heads, where the
corresponding shapes are (batch_size, <query dimensions>, key_dim),
(batch_size, <key/value dimensions>, key_dim),
(batch_size, <key/value dimensions>, value_dim).
Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities and concatenated back into a single tensor.
Finally, the result tensor, whose last dimension is value_dim, can take a
linear projection and is returned.
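The steps described above can be sketched directly with keras.ops. This is an illustrative re-implementation under simplifying assumptions (no bias, no dropout, no masking, one fused einsum per projection), not the layer's actual code; the function and weight names are hypothetical.

```python
import numpy as np
import keras
from keras import ops

def multi_head_attention_sketch(query, key, value, wq, wk, wv, wo):
    """Illustrative multi-head scaled dot-product attention.

    query: (B, T, dim); key, value: (B, S, dim)
    wq, wk: (dim, num_heads, key_dim); wv: (dim, num_heads, value_dim)
    wo: (num_heads, value_dim, dim)
    """
    # Per-head linear projections of query, key and value.
    q = ops.einsum("btd,dhk->bthk", query, wq)   # (B, T, H, key_dim)
    k = ops.einsum("bsd,dhk->bshk", key, wk)     # (B, S, H, key_dim)
    v = ops.einsum("bsd,dhv->bshv", value, wv)   # (B, S, H, value_dim)

    # Scaled dot product of query and key, softmaxed into attention probabilities.
    scores = ops.einsum("bthk,bshk->bhts", q, k) / float(np.sqrt(wq.shape[-1]))
    probs = ops.softmax(scores, axis=-1)         # (B, H, T, S)

    # Interpolate the value heads by the probabilities, then concatenate the
    # heads and project back to the feature dimension.
    context = ops.einsum("bhts,bshv->bthv", probs, v)   # (B, T, H, value_dim)
    return ops.einsum("bthv,hvd->btd", context, wo)     # (B, T, dim)

# Example with random weights (shapes are arbitrary).
B, T, S, dim, H, key_dim = 2, 4, 6, 8, 2, 5
rng = np.random.default_rng(0)
out = multi_head_attention_sketch(
    rng.standard_normal((B, T, dim)).astype("float32"),
    rng.standard_normal((B, S, dim)).astype("float32"),
    rng.standard_normal((B, S, dim)).astype("float32"),
    rng.standard_normal((dim, H, key_dim)).astype("float32"),
    rng.standard_normal((dim, H, key_dim)).astype("float32"),
    rng.standard_normal((dim, H, key_dim)).astype("float32"),  # value_dim == key_dim here
    rng.standard_normal((H, key_dim, dim)).astype("float32"),
)
print(out.shape)  # (2, 4, 8)
```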
Arguments

num_heads: Number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch
    and sequence dims. If not specified, projects back to the query
    feature dim (the query input's last dimension).
attention_axes: Axes over which the attention is applied. None means
    attention over all axes, but batch, heads, and features.
flash_attention: If None, the layer attempts to use flash
    attention for faster and more memory-efficient attention
    computations when possible. This behavior can be configured using
    keras.config.enable_flash_attention() or
    keras.config.disable_flash_attention().
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
seed: Optional integer to seed the dropout layer.
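As noted for flash_attention, the behavior can be toggled globally with the config helpers named above. The snippet below is an assumed usage sketch that combines this with a non-default attention_axes; the input shape and values are illustrative.

```python
import numpy as np
import keras

# Globally disable (or re-enable) flash attention for all layers.
keras.config.disable_flash_attention()
# keras.config.enable_flash_attention()

# Attention applied over two spatial axes of a 5D input of shape (B, 5, 3, 4, 16).
layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3))
x = np.random.normal(size=(1, 5, 3, 4, 16)).astype("float32")
out = layer(query=x, value=x)
print(out.shape)  # (1, 5, 3, 4, 16)
```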
Call arguments

query: Query tensor of shape (B, T, dim), where B is the batch size,
    T is the target sequence length, and dim is the feature dimension.
value: Value tensor of shape (B, S, dim), where B is the batch size,
    S is the source sequence length, and dim is the feature dimension.
key: Optional key tensor of shape (B, S, dim). If not given, will
    use value for both key and value, which is the most common case.
attention_mask: A boolean mask of shape (B, T, S), that prevents
    attention to certain positions. The boolean mask specifies which
    query elements can attend to which key elements, 1 indicates
    attention and 0 indicates no attention. Broadcasting can happen for
    the missing batch dimensions and the head dimension.
return_attention_scores: A boolean to indicate whether the output should
    be (attention_output, attention_scores) if True, or
    attention_output if False. Defaults to False.
training: Python boolean indicating whether the layer should behave in
    training mode (adding dropout) or in inference mode (no dropout).
    Will go with either using the training mode of the parent
    layer/model, or False (inference) if there is no parent layer.
Returns

attention_output: The result of the computation, of shape (B, T, E),
    where T is for target sequence shapes and E is the query input
    last dimension if output_shape is None. Otherwise, the
    multi-head outputs are projected to the shape specified by
    output_shape.
attention_scores: (Only returned when return_attention_scores is True)
    Multi-head attention coefficients over attention axes.
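Putting the call arguments together, a sketch of a cross-attention call with a boolean attention_mask and return_attention_scores=True (the shapes and mask pattern are made-up assumptions for illustration):

```python
import numpy as np
import keras

B, T, S, dim = 2, 4, 6, 32
layer = keras.layers.MultiHeadAttention(num_heads=4, key_dim=8)

query = np.random.normal(size=(B, T, dim)).astype("float32")
value = np.random.normal(size=(B, S, dim)).astype("float32")

# Boolean mask of shape (B, T, S): each query position may only attend to
# the first three source positions (1 = attend, 0 = do not attend).
mask = np.zeros((B, T, S), dtype="bool")
mask[:, :, :3] = True

output, scores = layer(
    query=query,
    value=value,
    attention_mask=mask,
    return_attention_scores=True,
)
print(output.shape)  # (B, T, dim)          -> (2, 4, 32)
print(scores.shape)  # (B, num_heads, T, S) -> (2, 4, 4, 6)
```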