GroupedQueryAttention class
keras.layers.GroupQueryAttention(
    head_dim,
    num_query_heads,
    num_key_value_heads,
    dropout=0.0,
    use_bias=True,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)
Grouped Query Attention layer.
This is an implementation of grouped-query attention, introduced by
Ainslie et al., 2023. Here
num_key_value_heads denotes the number of groups. Setting
num_key_value_heads to 1 is equivalent to multi-query attention, and
setting it equal to num_query_heads is equivalent
to multi-head attention.
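To make the grouping concrete, here is a minimal pure-Python sketch of which key/value head each query head would share. The helper names are hypothetical, and the contiguous grouping and the divisibility check are assumptions of this sketch, not statements about the layer's internals.

```python
def kv_head_for_query_head(query_head, num_query_heads, num_key_value_heads):
    """Return the key/value head (group) index used by a given query head.

    Hypothetical helper: assumes query heads are split into contiguous,
    equally sized groups.
    """
    if num_query_heads % num_key_value_heads != 0:
        raise ValueError(
            "num_query_heads must be divisible by num_key_value_heads"
        )
    group_size = num_query_heads // num_key_value_heads
    return query_head // group_size


def head_assignment(num_query_heads, num_key_value_heads):
    """List the key/value head index for every query head."""
    return [
        kv_head_for_query_head(h, num_query_heads, num_key_value_heads)
        for h in range(num_query_heads)
    ]


# One shared key/value head: multi-query attention.
print(head_assignment(8, 1))  # [0, 0, 0, 0, 0, 0, 0, 0]
# Two groups of four query heads: grouped-query attention.
print(head_assignment(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]
# One key/value head per query head: multi-head attention.
print(head_assignment(8, 8))  # [0, 1, 2, 3, 4, 5, 6, 7]
```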
This layer first projects the query, key, and value tensors. Then, key
and value are repeated to match the number of query heads.
Next, the query is scaled and its dot product with the key tensors is
computed. A softmax over these scores gives the attention probabilities.
The value tensors are then combined as a weighted sum using these
probabilities, and the per-head results are concatenated back into a
single tensor.
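The steps described above can be sketched in NumPy as follows. This is an illustrative forward pass under stated assumptions, not the layer's actual code: the function name, the `params` dictionary and its weight shapes are hypothetical, and bias, dropout, and masking are left out. The final output projection stands in for merging the heads back into one tensor.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gqa_forward(query, value, params, num_query_heads, num_key_value_heads, key=None):
    """Illustrative grouped-query attention (no bias, dropout, or mask)."""
    if key is None:
        key = value  # use value for both key and value
    repeats = num_query_heads // num_key_value_heads

    # 1. Project inputs into per-head query, key, and value tensors.
    q = np.einsum("btf,fhd->bthd", query, params["wq"])  # (B, T, Hq, D)
    k = np.einsum("bsf,fhd->bshd", key, params["wk"])    # (B, S, Hkv, D)
    v = np.einsum("bsf,fhd->bshd", value, params["wv"])  # (B, S, Hkv, D)

    # 2. Repeat key/value heads to match the number of query heads.
    k = np.repeat(k, repeats, axis=2)  # (B, S, Hq, D)
    v = np.repeat(v, repeats, axis=2)  # (B, S, Hq, D)

    # 3. Scale the query and take its dot product with the keys.
    head_dim = q.shape[-1]
    scores = np.einsum("bthd,bshd->bhts", q / np.sqrt(head_dim), k)

    # 4. Softmax over the source axis gives attention probabilities.
    probs = softmax(scores, axis=-1)  # (B, Hq, T, S)

    # 5. Weighted sum of values, then merge heads back to feature_dim.
    context = np.einsum("bhts,bshd->bthd", probs, v)            # (B, T, Hq, D)
    output = np.einsum("bthd,hdf->btf", context, params["wo"])  # (B, T, F)
    return output, probs


rng = np.random.default_rng(0)
B, T, S, F = 2, 4, 6, 32   # batch, target length, source length, feature dim
Hq, Hkv, D = 8, 2, 16      # query heads, key/value heads, head dim
params = {
    "wq": rng.normal(size=(F, Hq, D)),
    "wk": rng.normal(size=(F, Hkv, D)),
    "wv": rng.normal(size=(F, Hkv, D)),
    "wo": rng.normal(size=(Hq, D, F)),
}
query = rng.normal(size=(B, T, F))
value = rng.normal(size=(B, S, F))

output, probs = gqa_forward(query, value, params, Hq, Hkv)
print(output.shape)  # (2, 4, 32)
print(probs.shape)   # (2, 8, 4, 6)
```

Only the key and value projections shrink with num_key_value_heads; the query and output projections keep all query heads, which is where the memory saving relative to multi-head attention comes from in this sketch.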
Arguments
flash_attention: If None, the layer attempts to use flash
  attention for faster and more memory-efficient attention
  computations when possible. This behavior can be configured using
  keras.config.enable_flash_attention() or
  keras.config.disable_flash_attention().
Call arguments
query: Query tensor of shape (batch_dim, target_seq_len, feature_dim),
  where batch_dim is the batch size, target_seq_len is the length of
  the target sequence, and feature_dim is the feature dimension.
value: Value tensor of shape (batch_dim, source_seq_len, feature_dim),
  where batch_dim is the batch size, source_seq_len is the length of
  the source sequence, and feature_dim is the feature dimension.
key: Optional key tensor of shape
  (batch_dim, source_seq_len, feature_dim). If not given, value is
  used for both key and value, which is the most common case.
attention_mask: A boolean mask of shape
  (batch_dim, target_seq_len, source_seq_len) that prevents
  attention to certain positions. The boolean mask specifies which
  query elements can attend to which key elements, where 1 indicates
  attention and 0 indicates no attention. Broadcasting can happen for
  the missing batch dimensions and the head dimension.
return_attention_scores: A boolean indicating whether the output
  should be (attention_output, attention_scores) if True, or
  attention_output if False. Defaults to False.
training: Boolean indicating whether the layer should behave in
  training mode (applying dropout) or in inference mode (no dropout).
  Uses the training mode of the parent layer or model, or
  False (inference) if there is no parent layer.
Returns
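attention_output: Result of the computation, of shape
  (batch_dim, target_seq_len, feature_dim), where target_seq_len
  is the target sequence length and feature_dim is the last dimension
  of the query input.
attention_scores: (Optional) Attention coefficients of shape
  (batch_dim, num_query_heads, target_seq_len, source_seq_len).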