Gemma4AudioEncoder model


Gemma4AudioEncoder class

keras_hub.models.Gemma4AudioEncoder(
    input_feat_size=128,
    hidden_size=1024,
    num_heads=8,
    num_layers=12,
    chunk_size=12,
    context_left=13,
    context_right=0,
    logit_cap=50.0,
    invalid_logit_value=-1000000000.0,
    conv_kernel_size=5,
    reduction_factor=1,
    residual_weight=0.5,
    gradient_clipping=10000000000.0,
    sscp_conv_channels=(128, 32),
    sscp_kernel_sizes=((3, 3), (3, 3)),
    sscp_stride_sizes=((2, 2), (2, 2)),
    output_proj_dims=1536,
    output_dim=2048,
    norm_eps=1e-06,
    sscp_norm_eps=1e-06,
    dtype=None,
    **kwargs
)

Audio encoder for Gemma4 based on the Universal Speech Model (USM).

Encodes mel spectrograms into audio token embeddings projected into the language model's hidden space. The pipeline is:

  1. SubSampleConvProjection (SSCP): two stacked Conv2D blocks that downsample time by 4× at a 16ms hop rate (a time stride of 2 in each block with the default sscp_stride_sizes), then a linear projection to hidden_size.
  2. Conformer blocks (num_layers of them), each applying: macaron feed-forward (FFW) → chunked local attention with relative position bias → causal depthwise Conv1D → macaron FFW → RMS norm.
  3. Temporal striding (only if reduction_factor > 1): shortens the sequence by keeping every reduction_factor-th token.
  4. Output projection: a linear map hidden_size → output_proj_dims (skipped when output_proj_dims is None), then a linear map output_proj_dims → output_dim (= text hidden size), and a parameter-free RMS norm.
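The sequence-length bookkeeping of the pipeline above can be sketched in plain Python. This is a minimal sketch, not KerasHub code: `audio_token_shape` and `conv_out_len` are hypothetical helpers, and it assumes that each strided SSCP conv pads so it keeps ceil(length / stride) positions per axis, and that the last conv's channels and remaining frequency bins are flattened per time step before the projection to hidden_size.

```python
import math


def conv_out_len(length, stride):
    # Assumption: "same"-style padding, so a strided conv keeps
    # ceil(length / stride) positions. KerasHub's padding may differ.
    return math.ceil(length / stride)


def audio_token_shape(
    num_frames,
    input_feat_size=128,
    sscp_conv_channels=(128, 32),
    sscp_stride_sizes=((2, 2), (2, 2)),
    reduction_factor=1,
    output_dim=2048,
):
    """Hypothetical helper: trace shapes through the encoder pipeline.

    Returns (flattened SSCP feature size, (num_tokens, output_dim)).
    """
    t, f = num_frames, input_feat_size
    # 1. SSCP: each Conv2D block strides over (time, frequency).
    for stride_t, stride_f in sscp_stride_sizes:
        t = conv_out_len(t, stride_t)
        f = conv_out_len(f, stride_f)
    # Assumption: last-conv channels x remaining frequency bins are flattened
    # per time step, then linearly projected to hidden_size.
    sscp_features = sscp_conv_channels[-1] * f
    # 2. Conformer blocks do not change the sequence length.
    # 3. Temporal striding keeps tokens 0, r, 2r, ... (r = reduction_factor).
    t = len(range(0, t, reduction_factor))
    # 4. The output projection maps every token to output_dim features.
    return sscp_features, (t, output_dim)


print(audio_token_shape(100))                      # (1024, (25, 2048))
print(audio_token_shape(100, reduction_factor=2))  # (1024, (13, 2048))
```

With the default strides, the time axis shrinks by the product of the time strides (2 × 2 = 4), which is the 4× downsampling described in step 1.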

Padded positions (indicated by audio_mel_mask) are zeroed out in the final output.
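A minimal plain-Python sketch of this final masking step follows. It rests on assumptions this page does not confirm: that audio_mel_mask is True at padded frames, and that the mask is brought to token resolution by keeping every time_stride-th frame flag. `downsample_mask` and `zero_padded_tokens` are hypothetical helpers.

```python
def downsample_mask(is_padding, time_stride):
    # Assumption: keep every time_stride-th frame flag so the mask lines up
    # with the subsampled tokens (time_stride = 4 with the default SSCP).
    return is_padding[::time_stride]


def zero_padded_tokens(tokens, is_padding):
    # Replace the embedding of every padded token with an all-zero vector.
    return [
        [0.0] * len(vec) if pad else list(vec)
        for vec, pad in zip(tokens, is_padding)
    ]


mel_padding = [False] * 7 + [True] * 3  # 10 mel frames, last 3 padded
token_padding = downsample_mask(mel_padding, time_stride=4)
print(token_padding)  # [False, False, True]
embeddings = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(zero_padded_tokens(embeddings, token_padding))
# [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
```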

Arguments

  • input_feat_size: int. Number of mel filterbank channels. Defaults to 128.
  • hidden_size: int. Conformer hidden dimension. Defaults to 1024.
  • num_heads: int. Number of conformer attention heads. Defaults to 8.
  • num_layers: int. Number of Conformer blocks. Defaults to 12.
  • chunk_size: int. Block size for chunk-based attention. Defaults to 12.
  • context_left: int. Left attention context (inclusive). Defaults to 13.
  • context_right: int. Right attention context. Defaults to 0.
  • logit_cap: float. Soft-cap on attention logits. Defaults to 50.0.
  • invalid_logit_value: float. Fill for masked logits. Defaults to -1e9.
  • conv_kernel_size: int. Depthwise conv kernel size. Defaults to 5.
  • reduction_factor: int. Temporal stride after the conformer stack. Defaults to 1.
  • residual_weight: float. Macaron FFW residual weight. Defaults to 0.5.
  • gradient_clipping: float. Clip value (the large default effectively disables clipping). Defaults to 1e10.
  • sscp_conv_channels: tuple of two ints. Output channels of each SubSampleConvProjection (SSCP) Conv2D block. Defaults to (128, 32).
  • sscp_kernel_sizes: tuple of two (kT, kF) pairs. Time and frequency kernel sizes of each SSCP conv. Defaults to ((3, 3), (3, 3)).
  • sscp_stride_sizes: tuple of two (sT, sF) pairs. Time and frequency strides of each SSCP conv. Defaults to ((2, 2), (2, 2)).
  • output_proj_dims: int or None. Intermediate audio projection dimension. If None, this projection is skipped. Defaults to 1536.
  • output_dim: int. Final output dimension; should equal the text backbone's hidden size. Defaults to 2048.
  • norm_eps: float. Epsilon for conformer RMS norms. Defaults to 1e-6.
  • sscp_norm_eps: float. Epsilon for SSCP LayerNorm. Defaults to 1e-6.
  • dtype: Compute dtype. Defaults to None.
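The attention arguments (context_left, context_right, logit_cap, invalid_logit_value) can be illustrated with a plain-Python sketch of how one attention logit matrix might be post-processed. Assumptions, not taken from this page: "inclusive" left context means the current frame plus context_left - 1 past frames; soft-capping uses the tanh form logit_cap * tanh(logit / logit_cap) found in other Gemma models and is applied before masking; and chunk_size only controls how queries are blocked for efficiency, so it is omitted. `in_context` and `postprocess_logits` are hypothetical helpers.

```python
import math


def in_context(q, k, context_left=13, context_right=0):
    # Assumed window: context_left - 1 past frames, the current frame,
    # and context_right future frames.
    return q - (context_left - 1) <= k <= q + context_right


def postprocess_logits(logits, logit_cap=50.0, invalid_logit_value=-1e9,
                       context_left=13, context_right=0):
    """Soft-cap a [query][key] logit matrix on one time axis, then mask it."""
    out = []
    for q, row in enumerate(logits):
        new_row = []
        for k, logit in enumerate(row):
            if in_context(q, k, context_left, context_right):
                # tanh soft-cap: smoothly bounds logits to (-cap, cap).
                new_row.append(logit_cap * math.tanh(logit / logit_cap))
            else:
                # Out-of-window keys get a large negative fill so softmax
                # assigns them (near-)zero weight.
                new_row.append(invalid_logit_value)
        out.append(new_row)
    return out


capped = postprocess_logits([[100.0] * 3 for _ in range(3)], context_left=2)
print([[round(x, 2) for x in row] for row in capped])
# [[48.2, -1000000000.0, -1000000000.0], [48.2, 48.2, -1000000000.0],
#  [-1000000000.0, 48.2, 48.2]]
```

Using a finite fill such as invalid_logit_value instead of -inf keeps the softmax defined even for a query row whose keys are all masked.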