nemo_automodel.components.models.hy_mt2.layers

Module Contents

Classes#

HyMT2Attention

Hy-MT2-30B-A3B attention: GQA, per-head Q/K RMSNorm, and RoPE.

API

class nemo_automodel.components.models.hy_mt2.layers.HyMT2Attention(
config: Any,
backend: nemo_automodel.components.models.common.BackendConfig,
)

Bases: torch.nn.Module

Hy-MT2-30B-A3B attention: GQA, per-head Q/K RMSNorm, and RoPE.

Differences vs. the existing Hy3-preview HYV3Attention:

  • Per-head Q/K RMSNorm is gated by config.qk_norm (defaults to True). For Hy-MT2-30B-A3B the flag is always True; it exists so the same module can serve variants that skip Q/K normalization, with no code changes.

  • Dimensions follow Hy-MT2-30B-A3B: 32 Q heads / 4 KV heads (so each KV head is shared by 8 Q heads), head_dim=128, hidden_size=2048.
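The bullets above can be illustrated with a minimal PyTorch sketch of GQA attention with per-head Q/K RMSNorm and RoPE, using the Hy-MT2-30B-A3B dimensions as defaults. It is not the NeMo AutoModel implementation: the names `SketchGQAAttention`, `RMSNorm`, `apply_rope`, and `precompute_freqs_cis` are hypothetical, and the complex-valued `freqs_cis` layout, RoPE base of 10000, RMSNorm epsilon, bias-free projections, norm-before-RoPE ordering, and causal masking are all assumptions. Note that 32 × 128 = 4096, so the query projection widens the 2048-dim hidden state, while keys and values use only 4 × 128 = 512 channels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dimension (here: head_dim)."""

    def __init__(self, dim: int, eps: float = 1e-6):  # eps is an assumed value
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * rms).type_as(x) * self.weight


def precompute_freqs_cis(head_dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Assumption: Llama-style complex rotations, shape [seq_len, head_dim // 2].
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: [batch, seq, heads, head_dim]; adjacent channel pairs are rotated as complex numbers.
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    out = torch.view_as_real(xc * freqs_cis[None, :, None, :]).flatten(-2)
    return out.type_as(x)


class SketchGQAAttention(nn.Module):
    # Hypothetical stand-in; defaults mirror the Hy-MT2-30B-A3B dimensions.
    def __init__(self, hidden_size=2048, n_heads=32, n_kv_heads=4, head_dim=128, qk_norm=True):
        super().__init__()
        self.n_heads, self.n_kv_heads, self.head_dim = n_heads, n_kv_heads, head_dim
        self.q_proj = nn.Linear(hidden_size, n_heads * head_dim, bias=False)     # 2048 -> 4096
        self.k_proj = nn.Linear(hidden_size, n_kv_heads * head_dim, bias=False)  # 2048 -> 512
        self.v_proj = nn.Linear(hidden_size, n_kv_heads * head_dim, bias=False)  # 2048 -> 512
        self.o_proj = nn.Linear(n_heads * head_dim, hidden_size, bias=False)
        # Per-head norms act on head_dim only; Identity when qk_norm is disabled.
        self.q_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(head_dim) if qk_norm else nn.Identity()

    def forward(self, x: torch.Tensor, *, freqs_cis: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        q = self.q_proj(x).view(b, s, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(b, s, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(b, s, self.n_kv_heads, self.head_dim)
        # Assumed ordering: normalize each head, then rotate with RoPE.
        q, k = apply_rope(self.q_norm(q), freqs_cis), apply_rope(self.k_norm(k), freqs_cis)
        # GQA: query head i reads KV head i // (n_heads // n_kv_heads).
        rep = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(rep, dim=2)
        v = v.repeat_interleave(rep, dim=2)
        # scaled_dot_product_attention expects [batch, heads, seq, head_dim].
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        return self.o_proj(out.transpose(1, 2).reshape(b, s, -1))


attn = SketchGQAAttention()
x = torch.randn(2, 16, 2048)
y = attn(x, freqs_cis=precompute_freqs_cis(128, 16))
print(y.shape)  # torch.Size([2, 16, 2048])
```

Repeating K/V with `repeat_interleave` keeps the sketch readable, but it materializes the repeated tensors; fused attention backends with native GQA support can avoid that copy.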


forward(
x: torch.Tensor,
*,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**attn_kwargs: Any,
) → torch.Tensor
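In this signature `x` is positional, and everything after the bare `*` is keyword-only: `freqs_cis` must be passed by name, `attention_mask` may be omitted, and any further keywords are collected into `attn_kwargs`. The hypothetical `forward_stub` below copies that parameter layout and only echoes what it receives, to show which calls Python accepts; how the real module consumes `attn_kwargs` is not stated here, and `some_backend_option` is an invented keyword.

```python
from typing import Any, Optional


def forward_stub(
    x: Any, *, freqs_cis: Any, attention_mask: Optional[Any] = None, **attn_kwargs: Any
) -> dict:
    # Hypothetical stand-in with the same parameter layout as HyMT2Attention.forward.
    return {"x": x, "freqs_cis": freqs_cis, "attention_mask": attention_mask, "attn_kwargs": attn_kwargs}


# Accepted: freqs_cis by keyword, mask omitted, extra keywords gathered into attn_kwargs.
ok = forward_stub("x", freqs_cis="f", some_backend_option=1)
print(ok["attention_mask"], ok["attn_kwargs"])  # None {'some_backend_option': 1}

# Rejected: freqs_cis passed positionally raises TypeError.
try:
    forward_stub("x", "f")
except TypeError as err:
    print("TypeError:", err)
```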
init_weights(buffer_device: torch.device, init_std: float = 0.02)
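`init_weights` is documented only by its signature. As a hedged illustration, assuming a torchtitan-style scheme that the source does not confirm, the sketch below draws projection weights from a normal distribution with standard deviation `init_std`, truncated at ±3·`init_std`, and resets the Q/K norm scales to one. `TinyAttnParams` and `init_weights_sketch` are hypothetical. The stand-in owns no buffers, so `buffer_device` is accepted only to mirror the documented signature; what the real method does with it is not stated here.

```python
import torch
import torch.nn as nn


class TinyAttnParams(nn.Module):
    # Hypothetical stand-in holding the parameter kinds a GQA block with Q/K norm owns.
    def __init__(self, hidden=256, n_heads=4, n_kv_heads=2, head_dim=64):
        super().__init__()
        self.q_proj = nn.Linear(hidden, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, hidden, bias=False)
        # Per-head RMSNorm scales over head_dim, left uninitialized on purpose.
        self.q_norm_weight = nn.Parameter(torch.empty(head_dim))
        self.k_norm_weight = nn.Parameter(torch.empty(head_dim))


def init_weights_sketch(module: nn.Module, buffer_device: torch.device, init_std: float = 0.02) -> None:
    # Assumed scheme, not the NeMo AutoModel implementation.
    # buffer_device is unused: this stand-in registers no buffers.
    with torch.no_grad():
        for sub in module.modules():
            if isinstance(sub, nn.Linear):
                # trunc_normal_'s a/b are absolute bounds, so scale them by init_std.
                nn.init.trunc_normal_(sub.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std)
                if sub.bias is not None:
                    nn.init.zeros_(sub.bias)
        for name, param in module.named_parameters():
            if name.endswith("norm_weight"):
                nn.init.ones_(param)


block = TinyAttnParams()
init_weights_sketch(block, buffer_device=torch.device("cpu"), init_std=0.02)
# Slightly below 0.02, because truncation at 3 sigma trims the tails.
print(block.q_proj.weight.std().item())
```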