GroupedQueryAttention
class keras.layers.GroupQueryAttention(
    head_dim,
    num_query_heads,
    num_key_value_heads,
    dropout=0.0,
    use_bias=True,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    **kwargs
)
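The head arguments also determine how many weights the projections hold. The sketch below counts them under an assumption this page does not state: that query, key and value are each a dense projection from feature_dim into per-head vectors, and that an output projection maps the concatenated heads back to feature_dim. The helper projection_param_counts and its feature_dim argument are hypothetical, for illustration only.

```python
def projection_param_counts(feature_dim, head_dim, num_query_heads,
                            num_key_value_heads, use_bias=True):
    """Hypothetical helper: weight counts of the four projections, assuming
    dense query/key/value projections into per-head vectors and a dense
    output projection back to feature_dim."""

    def dense(in_dim, out_dim):
        # Kernel weights plus (optionally) one bias per output unit.
        return in_dim * out_dim + (out_dim if use_bias else 0)

    q_width = num_query_heads * head_dim       # all query heads side by side
    kv_width = num_key_value_heads * head_dim  # only the shared key/value heads
    return {
        "query": dense(feature_dim, q_width),
        "key": dense(feature_dim, kv_width),
        "value": dense(feature_dim, kv_width),
        "output": dense(q_width, feature_dim),
    }


gqa = projection_param_counts(512, 64, num_query_heads=8, num_key_value_heads=2)
mha = projection_param_counts(512, 64, num_query_heads=8, num_key_value_heads=8)
print(gqa["key"], mha["key"])                # 65664 262656
print(sum(gqa.values()), sum(mha.values()))  # 656640 1050624
```

Under these assumptions, sharing 2 key/value heads among 8 query heads shrinks the key and value projections by a factor of 4, while the query and output projections are unchanged.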
Grouped Query Attention layer.
This is an implementation of grouped-query attention introduced by
Ainslie et al., 2023. Here, num_key_value_heads denotes the number of
groups: setting num_key_value_heads to 1 is equivalent to multi-query
attention, and setting it equal to num_query_heads is equivalent to
multi-head attention.
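As a small illustration of the grouping, the sketch below maps each query head to the key/value head it reads from. It assumes num_query_heads is divisible by num_key_value_heads and that each key/value head serves a contiguous block of query heads (which is what an element-wise repeat along the head axis produces). The function names are hypothetical.

```python
def kv_head_for_query_head(query_head, num_query_heads, num_key_value_heads):
    """Index of the key/value head (group) that a given query head uses."""
    if num_query_heads % num_key_value_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_key_value_heads")
    group_size = num_query_heads // num_key_value_heads  # query heads per group
    return query_head // group_size


def head_mapping(num_query_heads, num_key_value_heads):
    """Key/value head index for every query head, in order."""
    return [kv_head_for_query_head(h, num_query_heads, num_key_value_heads)
            for h in range(num_query_heads)]


print(head_mapping(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]  grouped-query
print(head_mapping(8, 1))  # [0, 0, 0, 0, 0, 0, 0, 0]  multi-query
print(head_mapping(8, 8))  # [0, 1, 2, 3, 4, 5, 6, 7]  multi-head
```

The two special cases from the paragraph fall out directly: with one group every query head shares the same key/value head, and with as many groups as query heads each query head gets its own.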
This layer first projects the query, key, and value tensors. The key and
value heads are then repeated to match the number of query heads. Next,
the query is scaled and its dot product with the key tensors is computed;
a softmax over these scores gives the attention probabilities. Finally,
the value tensors are weighted by these probabilities, and the per-head
results are concatenated back into a single tensor.
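The steps above can be sketched in plain Python for a single batch element. This is an illustrative sketch, not the Keras implementation: it starts from already-projected heads, omits the batch axis, dropout and any output projection, and takes an optional boolean mask in which True means "may attend". The function names are hypothetical.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability; masked logits are -inf and
    # become exactly 0.0 (assumes at least one unmasked position per row).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def grouped_query_attention(q_heads, k_heads, v_heads, mask=None):
    """q_heads: [num_query_heads][target_len][head_dim]
    k_heads, v_heads: [num_kv_heads][source_len][head_dim]
    mask: optional [target_len][source_len] booleans, True = may attend.
    Returns (output, scores): output is [target_len][num_query_heads * head_dim],
    scores is [num_query_heads][target_len][source_len]."""
    num_query_heads, num_kv_heads = len(q_heads), len(k_heads)
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    repeats = num_query_heads // num_kv_heads
    # Repeat each key/value head consecutively so every query head has a partner.
    k_rep = [k for k in k_heads for _ in range(repeats)]
    v_rep = [v for v in v_heads for _ in range(repeats)]
    head_dim = len(q_heads[0][0])
    scale = 1.0 / math.sqrt(head_dim)

    per_head_out, scores = [], []
    for q, k, v in zip(q_heads, k_rep, v_rep):
        head_out, head_scores = [], []
        for t, q_vec in enumerate(q):
            q_scaled = [scale * x for x in q_vec]  # scale the query first
            logits = [sum(a * b for a, b in zip(q_scaled, k_vec)) for k_vec in k]
            if mask is not None:
                logits = [x if mask[t][s] else float("-inf")
                          for s, x in enumerate(logits)]
            probs = softmax(logits)
            head_scores.append(probs)
            # Weighted average of the value vectors using the probabilities.
            head_out.append([sum(p * v_vec[d] for p, v_vec in zip(probs, v))
                             for d in range(head_dim)])
        per_head_out.append(head_out)
        scores.append(head_scores)

    # Concatenate the heads along the feature axis at each target position.
    target_len = len(q_heads[0])
    output = [[x for h in range(num_query_heads) for x in per_head_out[h][t]]
              for t in range(target_len)]
    return output, scores


# Two query heads sharing one key/value head (multi-query attention).
q = [[[1.0]], [[-1.0]]]
k = [[[0.0], [0.0]]]
v = [[[2.0], [4.0]]]
out, attn = grouped_query_attention(q, k, v)
print(out)  # [[3.0, 3.0]]
out, _ = grouped_query_attention(q, k, v, mask=[[True, False]])
print(out)  # [[2.0, 2.0]]
```

Because the single key/value head is shared, both query heads see equal logits here and average the two values; masking the second source position leaves only the first value.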
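Arguments
- head_dim: Size of each attention head.
- num_query_heads: Number of query attention heads.
- num_key_value_heads: Number of key and value attention heads.
- dropout: Dropout probability.
- use_bias: Boolean, whether the dense layers use bias vectors/matrices.
- kernel_initializer: Initializer for dense layer kernels.
- bias_initializer: Initializer for dense layer biases.
- kernel_regularizer: Regularizer for dense layer kernels.
- bias_regularizer: Regularizer for dense layer biases.
- activity_regularizer: Regularizer for dense layer activity.
- kernel_constraint: Constraint for dense layer kernels.
- bias_constraint: Constraint for dense layer biases.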
Call arguments
- query: Query tensor of shape (batch_dim, target_seq_len, feature_dim),
  where batch_dim is the batch size, target_seq_len is the length of the
  target sequence, and feature_dim is the feature dimension.
- value: Value tensor of shape (batch_dim, source_seq_len, feature_dim),
  where batch_dim is the batch size, source_seq_len is the length of the
  source sequence, and feature_dim is the feature dimension.
- key: Optional key tensor of shape (batch_dim, source_seq_len, feature_dim).
  If not given, value is used for both key and value, which is the most
  common case.
- attention_mask: A boolean mask of shape
  (batch_dim, target_seq_len, source_seq_len) that prevents attention to
  certain positions. The mask specifies which query elements can attend to
  which key elements: 1 indicates attention and 0 indicates no attention.
  The mask can be broadcast over missing batch dimensions and over the head
  dimension.
- return_attention_scores: Boolean indicating whether the output should be
  (attention_output, attention_scores) if True, or attention_output if
  False. Defaults to False.
- training: Python boolean indicating whether the layer should behave in
  training mode (adding dropout) or in inference mode (no dropout). Follows
  the training mode of the parent layer/model, or False (inference) if
  there is no parent layer.

Returns
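- attention_output: Result of the computation, of shape
  (batch_dim, target_seq_len, feature_dim), where target_seq_len is the
  target sequence length and feature_dim is the last dimension of the query
  input.
- attention_scores: (Optional) attention coefficients of shape
  (batch_dim, num_query_heads, target_seq_len, source_seq_len).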
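The shape contract of the call arguments and return values can be summarized with a small checker. gqa_call_shapes is a hypothetical helper, not part of Keras. It treats a mask with a missing or size-1 batch axis as broadcastable (the size-1 case is an assumption based on ordinary broadcasting rules), and it does not model a head axis in the mask, since the layer broadcasts the mask over heads.

```python
def gqa_call_shapes(query_shape, value_shape, num_query_heads,
                    key_shape=None, mask_shape=None):
    """Hypothetical helper: (attention_output shape, attention_scores shape)."""
    batch, target_len, feature_dim = query_shape
    value_batch, source_len, _ = value_shape
    if key_shape is None:
        key_shape = value_shape  # key defaults to value
    if value_batch != batch or tuple(key_shape[:2]) != (batch, source_len):
        raise ValueError("query/key/value batch sizes or key/value lengths differ")
    if mask_shape is not None:
        mask_shape = tuple(mask_shape)
        # Last two axes must be (target_seq_len, source_seq_len); a missing
        # or (assumed) size-1 batch axis is broadcast.
        if len(mask_shape) not in (2, 3) or mask_shape[-2:] != (target_len, source_len):
            raise ValueError(f"unexpected attention_mask shape {mask_shape}")
        if len(mask_shape) == 3 and mask_shape[0] not in (1, batch):
            raise ValueError(f"mask batch axis {mask_shape[0]} does not match {batch}")
    output_shape = (batch, target_len, feature_dim)  # feature_dim from query
    scores_shape = (batch, num_query_heads, target_len, source_len)
    return output_shape, scores_shape


print(gqa_call_shapes((2, 5, 16), (2, 7, 16), num_query_heads=8))
# ((2, 5, 16), (2, 8, 5, 7))
print(gqa_call_shapes((2, 5, 16), (2, 7, 16), 8, mask_shape=(5, 7))[1])
# (2, 8, 5, 7)
```

Note that attention_scores carries num_query_heads, not num_key_value_heads: after the key/value heads are repeated, every query head produces its own score map.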