nemo_automodel.components.models.qwen3_moe.layers#

Module Contents#

Classes#

Qwen3MoeAttention

Qwen3 MoE attention (per-head RMSNorm on queries/keys + RoPE), compatible with TE and SDPA attention backends.

Functions#

_preprocess_for_attn

Preprocess attention inputs based on backend requirements.

_postprocess_from_attn

Postprocess attention output based on backend requirements.

API#

nemo_automodel.components.models.qwen3_moe.layers._preprocess_for_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: torch.Tensor | None,
backend: nemo_automodel.components.moe.utils.BackendConfig,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]#

Preprocess attention inputs based on backend requirements.

Mirrors deepseek_v3.layers.preprocess_args_and_kwargs_for_attn, but is inlined here to avoid import cycles.

nemo_automodel.components.models.qwen3_moe.layers._postprocess_from_attn(
x: torch.Tensor,
backend: nemo_automodel.components.moe.utils.BackendConfig,
) → torch.Tensor#

Postprocess attention output based on backend requirements.
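
Used together, the two private helpers bracket the backend attention call. The sketch below is illustrative only and is not the module's implementation: `attn_fn` (the backend attention callable) and `o_proj` (an output projection) are hypothetical stand-ins, and the intermediate tensor layouts depend on the selected backend.

```python
# Illustrative sketch: `attn_fn` and `o_proj` are hypothetical stand-ins,
# not names from this module.
import torch

from nemo_automodel.components.models.qwen3_moe.layers import (
    _postprocess_from_attn,
    _preprocess_for_attn,
)
from nemo_automodel.components.moe.utils import BackendConfig


def attention_core(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: torch.Tensor | None,
    backend: BackendConfig,
    attn_fn,
    o_proj: torch.nn.Module,
) -> torch.Tensor:
    # Preprocess q/k/v for the chosen backend and collect its keyword arguments.
    q, k, v, attn_kwargs = _preprocess_for_attn(q, k, v, attention_mask, backend)
    out = attn_fn(q, k, v, **attn_kwargs)
    # Postprocess the backend output (layout handling is backend-dependent).
    out = _postprocess_from_attn(out, backend)
    return o_proj(out)
```
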
class nemo_automodel.components.models.qwen3_moe.layers.Qwen3MoeAttention(
config: transformers.models.qwen3_moe.configuration_qwen3_moe.Qwen3MoeConfig,
backend: nemo_automodel.components.moe.utils.BackendConfig,
)#

Bases: torch.nn.Module

Qwen3 MoE attention (per-head RMSNorm on queries/keys + RoPE), compatible with TE and SDPA attention backends.

Shapes:

  • Input: x -> [B, S, H]

  • Projections: q -> [B, S, n_heads, head_dim]; k/v -> [B, S, n_kv_heads, head_dim], repeated to n_heads via KV groups

  • Output: [B, S, H]

Initialization

forward(
x: torch.Tensor,
*,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**attn_kwargs: Any,
) → torch.Tensor#
init_weights(buffer_device: torch.device, init_std: float = 0.02)#
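
A minimal usage sketch, not taken from the repository: it assumes a Qwen3MoeConfig from transformers and an already-constructed BackendConfig, and it leaves the freqs_cis layout to the caller because the RoPE table format depends on the module's rotary-embedding implementation.

```python
# Hedged usage sketch. Qwen3MoeConfig is the real transformers config class;
# the BackendConfig instance and the freqs_cis tensor are assumed to be
# provided by the surrounding model code.
import torch
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig

from nemo_automodel.components.models.qwen3_moe.layers import Qwen3MoeAttention
from nemo_automodel.components.moe.utils import BackendConfig


def run_attention_block(
    config: Qwen3MoeConfig,
    backend: BackendConfig,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    attn = Qwen3MoeAttention(config, backend)
    attn.init_weights(buffer_device=torch.device("cpu"), init_std=0.02)

    batch, seq, hidden = 2, 16, config.hidden_size
    x = torch.randn(batch, seq, hidden)        # input: [B, S, H]
    out = attn(x, freqs_cis=freqs_cis)         # freqs_cis is keyword-only
    assert out.shape == (batch, seq, hidden)   # output mirrors the input layout
    return out
```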