GroupedQueryAttention class

keras.layers.GroupQueryAttention(
head_dim,
num_query_heads,
num_key_value_heads,
dropout=0.0,
use_bias=True,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
Grouped Query Attention layer.
This is an implementation of grouped-query attention introduced by
Ainslie et al., 2023. Here, num_key_value_heads denotes the number of
groups: setting num_key_value_heads to 1 is equivalent to multi-query
attention, and when num_key_value_heads is equal to num_query_heads it
is equivalent to multi-head attention.
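For example, the same layer class covers all three regimes depending on
num_key_value_heads (the head sizes below are illustrative):

import keras

# num_key_value_heads sets the number of key/value groups.
gqa = keras.layers.GroupQueryAttention(
    head_dim=64, num_query_heads=8, num_key_value_heads=2)  # grouped-query
mqa = keras.layers.GroupQueryAttention(
    head_dim=64, num_query_heads=8, num_key_value_heads=1)  # multi-query
mha = keras.layers.GroupQueryAttention(
    head_dim=64, num_query_heads=8, num_key_value_heads=8)  # multi-head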
This layer first projects the query, key, and value tensors. The key and
value heads are then repeated to match the number of query heads. Next,
the scaled query is multiplied with the key tensors via a dot product,
and the results are passed through a softmax to obtain attention
probabilities. The value tensors are then weighted by these
probabilities and concatenated back into a single tensor.
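The following is a rough sketch of this computation using keras.ops on
already-projected per-head tensors. It only illustrates the steps above
and is not the layer's internal implementation; all shapes and values
are made up for the example.

import numpy as np
from keras import ops

batch, target_len, source_len = 2, 4, 6
num_query_heads, num_key_value_heads, head_dim = 8, 2, 16

# Assume these are the per-head projections of query, key, and value.
q = ops.convert_to_tensor(
    np.random.rand(batch, target_len, num_query_heads, head_dim).astype("float32"))
k = ops.convert_to_tensor(
    np.random.rand(batch, source_len, num_key_value_heads, head_dim).astype("float32"))
v = ops.convert_to_tensor(
    np.random.rand(batch, source_len, num_key_value_heads, head_dim).astype("float32"))

# Repeat the key/value heads so each group serves
# num_query_heads // num_key_value_heads query heads.
repeats = num_query_heads // num_key_value_heads
k = ops.repeat(k, repeats, axis=2)
v = ops.repeat(v, repeats, axis=2)

# Scaled dot product of query and key, softmaxed over the source axis.
scores = ops.einsum("bthd,bshd->bhts", q, k) / (head_dim ** 0.5)
probs = ops.softmax(scores, axis=-1)

# Weight the values by the attention probabilities and merge the heads.
out = ops.einsum("bhts,bshd->bthd", probs, v)
out = ops.reshape(out, (batch, target_len, num_query_heads * head_dim))
print(out.shape)  # (2, 4, 128)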
Arguments

head_dim: Size of each attention head.
num_query_heads: Number of query attention heads.
num_key_value_heads: Number of key and value attention heads; must
    evenly divide num_query_heads.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
flash_attention: If None, the layer attempts to use flash attention for
    faster and more memory-efficient attention computations when
    possible. This behavior can be configured using
    keras.config.enable_flash_attention() or
    keras.config.disable_flash_attention().
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
seed: Optional integer to seed the dropout layer.

Call arguments

query: Query tensor of shape (batch_dim, target_seq_len, feature_dim),
    where batch_dim is the batch size, target_seq_len is the length of
    the target sequence, and feature_dim is the feature dimension.
value: Value tensor of shape (batch_dim, source_seq_len, feature_dim),
    where batch_dim is the batch size, source_seq_len is the length of
    the source sequence, and feature_dim is the feature dimension.
key: Optional key tensor of shape
    (batch_dim, source_seq_len, feature_dim). If not given, value is
    used for both key and value, which is the most common case.
attention_mask: A boolean mask of shape
    (batch_dim, target_seq_len, source_seq_len) that prevents attention
    to certain positions. The mask specifies which query elements can
    attend to which key elements, where 1 indicates attention and 0
    indicates no attention. Broadcasting can happen for the missing
    batch dimension and the head dimension (see the example below).
return_attention_scores: A boolean indicating whether the output should
    be (attention_output, attention_scores) if True, or
    attention_output if False. Defaults to False.
training: Python boolean indicating whether the layer should behave in
    training mode (applying dropout) or in inference mode (no dropout).
    Falls back to the training mode of the parent layer/model, or to
    False (inference) if there is no parent layer.
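For example, calling the layer with an attention_mask of the shape
described above (the concrete dimensions are illustrative):

import numpy as np
import keras

layer = keras.layers.GroupQueryAttention(
    head_dim=32, num_query_heads=4, num_key_value_heads=2)

query = np.random.rand(2, 8, 16).astype("float32")   # (batch_dim, target_seq_len, feature_dim)
value = np.random.rand(2, 10, 16).astype("float32")  # (batch_dim, source_seq_len, feature_dim)

# (batch_dim, target_seq_len, source_seq_len): every query position may
# attend only to the first five source positions.
mask = np.zeros((2, 8, 10), dtype="bool")
mask[:, :, :5] = True

output = layer(query, value, attention_mask=mask)
print(output.shape)  # (2, 8, 16)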
Returns

attention_output: The result of the computation, of shape
    (batch_dim, target_seq_len, feature_dim), where target_seq_len is
    the target sequence length and feature_dim is the last dimension of
    the query input.
attention_scores: (Optional) attention coefficients of shape
    (batch_dim, num_query_heads, target_seq_len, source_seq_len),
    returned only when return_attention_scores=True (see the example
    below).
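For example, with return_attention_scores=True both tensors are
returned with the shapes documented above (sizes are illustrative):

import numpy as np
import keras

layer = keras.layers.GroupQueryAttention(
    head_dim=32, num_query_heads=4, num_key_value_heads=2)
query = np.random.rand(2, 8, 16).astype("float32")
value = np.random.rand(2, 10, 16).astype("float32")

attention_output, attention_scores = layer(
    query, value, return_attention_scores=True)
print(attention_output.shape)  # (2, 8, 16)
print(attention_scores.shape)  # (2, 4, 8, 10)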