nemo_automodel.components.models.qwen3_moe.layers
Module Contents
Classes
Qwen3MoeAttention | Qwen3 MoE attention (query/key per-head RMSNorm + RoPE) compatible with TE/SDPA backends.
Functions
_preprocess_for_attn | Preprocess attention inputs based on backend requirements.
_postprocess_from_attn |
API
nemo_automodel.components.models.qwen3_moe.layers._preprocess_for_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: torch.Tensor | None,
    backend: nemo_automodel.components.moe.utils.BackendConfig,
)
Preprocess attention inputs based on backend requirements.
Mirrors deepseek_v3.layers.preprocess_args_and_kwargs_for_attn, but is inlined here to avoid import cycles.
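A minimal sketch of what backend-dependent preprocessing can look like. It assumes q/k/v arrive as [B, S, n_heads, head_dim], that an SDPA-style backend wants the head axis before the sequence axis, and that a TE-style backend consumes the incoming layout unchanged. The name preprocess_for_attn_sketch and the string backend_name argument are hypothetical stand-ins for _preprocess_for_attn and its BackendConfig; the real function may return different keyword arguments.

```python
import torch
import torch.nn.functional as F


def preprocess_for_attn_sketch(q, k, v, attention_mask, backend_name):
    """Hypothetical stand-in for _preprocess_for_attn (not the library code).

    Assumes q/k/v are [B, S, n_heads, head_dim] on entry.
    """
    if backend_name == "sdpa":
        # torch SDPA expects [B, n_heads, S, head_dim].
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        if attention_mask is None:
            attn_kwargs = {"is_causal": True}
        else:
            attn_kwargs = {"attn_mask": attention_mask}
    elif backend_name == "te":
        # Assumption: the TE-style path keeps the [B, S, n_heads, head_dim] layout.
        attn_kwargs = {}
    else:
        raise ValueError(f"unsupported backend: {backend_name!r}")
    return q, k, v, attn_kwargs


B, S, n_heads, head_dim = 2, 5, 4, 8
q, k, v = (torch.randn(B, S, n_heads, head_dim) for _ in range(3))

q_s, k_s, v_s, kw = preprocess_for_attn_sketch(q, k, v, None, "sdpa")
out = F.scaled_dot_product_attention(q_s, k_s, v_s, **kw)
print(q_s.shape, out.shape)  # torch.Size([2, 4, 5, 8]) torch.Size([2, 4, 5, 8])

q_t, _, _, kw_t = preprocess_for_attn_sketch(q, k, v, None, "te")
print(q_t.shape, kw_t)  # torch.Size([2, 5, 4, 8]) {}
```

Keeping the layout decision in one helper means the attention module itself can stay backend-agnostic and only hand over whatever keyword arguments the helper returns.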
nemo_automodel.components.models.qwen3_moe.layers._postprocess_from_attn(
    x: torch.Tensor,
    backend: nemo_automodel.components.moe.utils.BackendConfig,
)
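_postprocess_from_attn carries no docstring on this page. Given its name and its pairing with _preprocess_for_attn, it presumably undoes any layout change made for the backend; the sketch below assumes that and covers only an SDPA-style case. postprocess_from_attn_sketch and backend_name are hypothetical.

```python
import torch
import torch.nn.functional as F


def postprocess_from_attn_sketch(x, backend_name):
    """Hypothetical inverse of an SDPA-style layout change (not the library code).

    SDPA returns [B, n_heads, S, head_dim]; move heads back after the
    sequence axis so the result is [B, S, n_heads, head_dim].
    """
    if backend_name == "sdpa":
        x = x.transpose(1, 2).contiguous()
    return x


B, S, n_heads, head_dim = 2, 5, 4, 8
q, k, v = (torch.randn(B, n_heads, S, head_dim) for _ in range(3))  # SDPA layout
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

y = postprocess_from_attn_sketch(attn_out, "sdpa")
merged = y.view(B, S, n_heads * head_dim)  # flatten heads for an output projection
print(y.shape, merged.shape)  # torch.Size([2, 5, 4, 8]) torch.Size([2, 5, 32])
```

The .contiguous() call is what makes the later .view() legal: a bare transpose returns a strided view whose memory order no longer matches [B, S, n_heads * head_dim].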
class nemo_automodel.components.models.qwen3_moe.layers.Qwen3MoeAttention(
    config: transformers.models.qwen3_moe.configuration_qwen3_moe.Qwen3MoeConfig,
    backend: nemo_automodel.components.moe.utils.BackendConfig,
)
Bases: torch.nn.Module
Qwen3 MoE attention (query/key per-head RMSNorm + RoPE) compatible with TE/SDPA backends.
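The per-head query/key RMSNorm normalizes each head vector over its head_dim axis, with one learned scale of length head_dim shared by all heads; the Hugging Face Qwen3 implementation applies it right after the q/k projections and before RoPE. A minimal sketch follows. PerHeadRMSNorm is a hypothetical name, and the eps value and the absence of a float32 upcast are assumptions rather than details taken from this module.

```python
import torch


class PerHeadRMSNorm(torch.nn.Module):
    """Hypothetical RMSNorm over the last (head_dim) axis, shared across heads."""

    def __init__(self, head_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, S, n_heads, head_dim]; one statistic per (token, head) pair.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


torch.manual_seed(0)
B, S, n_heads, head_dim = 2, 3, 4, 8
q = 5.0 * torch.randn(B, S, n_heads, head_dim)  # e.g. q_proj(x).view(...)
q_norm = PerHeadRMSNorm(head_dim)
q_normed = q_norm(q)

# With the unit initial weight, every head vector now has RMS close to 1.
rms = q_normed.pow(2).mean(dim=-1).sqrt()
print(rms.shape)  # torch.Size([2, 3, 4])
```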
Shapes (B = batch size, S = sequence length, H = hidden size):
  Input: x -> [B, S, H]
  Projections:
    q: [B, S, n_heads, head_dim]
    k/v: [B, S, n_kv_heads, head_dim], repeated to n_heads across query-head groups (grouped-query attention)
  Output: [B, S, H]
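To make the shape bookkeeping concrete, here is a toy trace with plain nn.Linear projections and torch SDPA. repeat_kv is a hypothetical helper that implements the "repeated to n_heads" step with repeat_interleave; the sizes are arbitrary, and n_heads * head_dim is deliberately different from H on the assumption that head_dim is configured independently of H, as in Qwen3-style configs. The real module may instead let the backend handle the grouping without materializing copies.

```python
import torch
import torch.nn.functional as F


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """[B, S, n_kv_heads, head_dim] -> [B, S, n_kv_heads * n_rep, head_dim]."""
    if n_rep == 1:
        return x
    # Copies of one KV head stay adjacent, so query heads
    # g * n_rep ... (g + 1) * n_rep - 1 all read KV head g.
    return x.repeat_interleave(n_rep, dim=2)


torch.manual_seed(0)
B, S, H = 2, 6, 64
n_heads, n_kv_heads, head_dim = 8, 2, 16
assert n_heads % n_kv_heads == 0
n_rep = n_heads // n_kv_heads  # 4 query heads share each KV head

x = torch.randn(B, S, H)
q_proj = torch.nn.Linear(H, n_heads * head_dim, bias=False)
k_proj = torch.nn.Linear(H, n_kv_heads * head_dim, bias=False)
v_proj = torch.nn.Linear(H, n_kv_heads * head_dim, bias=False)
o_proj = torch.nn.Linear(n_heads * head_dim, H, bias=False)

q = q_proj(x).view(B, S, n_heads, head_dim)
k = repeat_kv(k_proj(x).view(B, S, n_kv_heads, head_dim), n_rep)
v = repeat_kv(v_proj(x).view(B, S, n_kv_heads, head_dim), n_rep)

# SDPA wants [B, heads, S, head_dim]; transpose in and back out.
attn = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
)
out = o_proj(attn.transpose(1, 2).reshape(B, S, n_heads * head_dim))
print(q.shape, k.shape, out.shape)
# torch.Size([2, 6, 8, 16]) torch.Size([2, 6, 8, 16]) torch.Size([2, 6, 64])
```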
Initialization
forward(
    x: torch.Tensor,
    *,
    freqs_cis: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    **attn_kwargs: Any,
)
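forward receives precomputed rotary frequencies as freqs_cis, but this page does not document how they are packed. As an assumption-labeled sketch only, the snippet builds cos/sin tables and rotates query/key heads with the rotate-half convention used by the Hugging Face Qwen3 code. rope_tables, apply_rope, and the theta default are hypothetical, illustrative choices, not this module's API.

```python
import torch


def rope_tables(seq_len: int, head_dim: int, theta: float = 1e6):
    """Hypothetical cos/sin tables of shape [S, head_dim]."""
    # inv_freq[i] = theta ** (-2i / head_dim), i = 0 .. head_dim/2 - 1
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    emb = torch.cat((angles, angles), dim=-1)  # pair dim i with dim i + head_dim/2
    return emb.cos(), emb.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [B, S, n_heads, head_dim]; broadcast the [S, head_dim] tables
    # over the batch and head axes.
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]
    return x * cos + rotate_half(x) * sin


torch.manual_seed(0)
B, S, n_heads, head_dim = 1, 4, 2, 8
q = torch.randn(B, S, n_heads, head_dim)
cos, sin = rope_tables(S, head_dim)
q_rot = apply_rope(q, cos, sin)
```

Because each pair of dimensions (i, i + head_dim/2) is rotated by the same angle, the transformation preserves the norm of every head vector, and position 0 (angle 0) is left unchanged.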
init_weights(buffer_device: torch.device, init_std: float = 0.02)
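init_weights has no docstring here. A common pattern for such a method, assumed rather than confirmed for this class, is to redraw projection weights from a truncated normal with standard deviation init_std and reset norm scales to one. buffer_device presumably says where non-parameter buffers should be materialized; this sketch ignores it. init_weights_sketch is a hypothetical helper.

```python
import torch


def init_weights_sketch(module: torch.nn.Module, init_std: float = 0.02) -> None:
    """Hypothetical re-initialization pass (not the library's routine)."""
    for sub in module.modules():
        if isinstance(sub, torch.nn.Linear):
            # trunc_normal_'s a/b are absolute bounds, so scale them by init_std
            # to truncate at +/- 2 standard deviations.
            torch.nn.init.trunc_normal_(
                sub.weight, mean=0.0, std=init_std, a=-2 * init_std, b=2 * init_std
            )
            if sub.bias is not None:
                torch.nn.init.zeros_(sub.bias)
        elif (
            isinstance(getattr(sub, "weight", None), torch.nn.Parameter)
            and sub.weight.dim() == 1
        ):
            # In this toy, 1-D weights are norm scales (e.g. per-head q/k norms).
            torch.nn.init.ones_(sub.weight)


torch.manual_seed(0)
toy = torch.nn.ModuleDict({
    "q_proj": torch.nn.Linear(64, 128, bias=False),
    "o_proj": torch.nn.Linear(128, 64, bias=False),
    "q_norm": torch.nn.LayerNorm(16),  # stand-in for a per-head norm
})
with torch.no_grad():
    toy["q_norm"].weight.fill_(3.0)

init_weights_sketch(toy, init_std=0.02)
```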