nemo_automodel.components.models.qwen3_moe.layers


Module Contents

Classes

Name | Description
--- | ---
Qwen3MoeAttention | Qwen3 MoE attention (query/key per-head RMSNorm + RoPE) compatible with TE/SDPA backends.

Data

logger

API

class nemo_automodel.components.models.qwen3_moe.layers.Qwen3MoeAttention(
config: transformers.models.qwen3_moe.configuration_qwen3_moe.Qwen3MoeConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: torch.nn.Module

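Qwen3 MoE attention that applies a per-head RMSNorm to queries and keys, followed by rotary position embeddings (RoPE). It is compatible with both the Transformer Engine (TE) and PyTorch scaled-dot-product attention (SDPA) backends.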
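The query/key normalization and rotary embedding described above can be sketched for a single head vector in plain Python. This is an illustration of the technique, not this module's code. It assumes the half-split ("rotate_half") RoPE convention and a base frequency `theta` of 10000.0. The real layer receives precomputed frequencies through `freqs_cis`, whose layout may differ from this sketch, and it operates on batched tensors.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    # Divide x by its root-mean-square, then apply a learned per-channel weight.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def apply_rope(x, position, theta=10000.0):
    # Half-split RoPE (assumed convention): channel i is paired with channel
    # i + d/2, and the pair is rotated by angle position * theta**(-2i/d).
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        angle = position * theta ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out

def qk_norm_rope(q_head, k_head, q_weight, k_weight, position):
    # Normalize each head's query and key over head_dim first, then rotate them.
    q = apply_rope(rms_norm(q_head, q_weight), position)
    k = apply_rope(rms_norm(k_head, k_weight), position)
    return q, k
```

Because each channel pair is rotated, the dot product of a rotated query and key depends only on their relative position. Normalizing before rotating keeps the query/key scale bounded regardless of the projection outputs.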

head_dim
k_norm
k_proj
num_heads = config.num_attention_heads
num_kv_heads = config.num_key_value_heads
o_proj
q_norm
q_proj
v_proj
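The attributes above imply a grouped-query layout whenever `num_kv_heads` is smaller than `num_heads`. The sketch below works through the projection shapes and the query-head to KV-head mapping. It assumes that `head_dim` comes from a config field of that name and that the projections are plain linear maps. `AttnShapes` is a hypothetical stand-in for the relevant `Qwen3MoeConfig` fields, and the numbers in the usage note are illustrative only.

```python
from dataclasses import dataclass

@dataclass
class AttnShapes:
    # Hypothetical stand-in for the config fields this sketch reads.
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int

def projection_shapes(cfg):
    # (in_features, out_features) per projection. Note that torch.nn.Linear
    # stores its weight transposed, as (out_features, in_features).
    q_out = cfg.num_attention_heads * cfg.head_dim
    kv_out = cfg.num_key_value_heads * cfg.head_dim
    return {
        "q_proj": (cfg.hidden_size, q_out),
        "k_proj": (cfg.hidden_size, kv_out),
        "v_proj": (cfg.hidden_size, kv_out),
        "o_proj": (q_out, cfg.hidden_size),
    }

def kv_head_for_query_head(q_head, cfg):
    # Grouped-query attention: each run of consecutive query heads shares one KV head.
    if cfg.num_attention_heads % cfg.num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group_size = cfg.num_attention_heads // cfg.num_key_value_heads
    return q_head // group_size
```

With `AttnShapes(2048, 32, 4, 128)`, `q_proj` maps 2048 features to 4096 (32 × 128), so the attention width need not equal the hidden size. Query heads 0 through 7 then share KV head 0.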
nemo_automodel.components.models.qwen3_moe.layers.Qwen3MoeAttention._forward_impl(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.qwen3_moe.layers.Qwen3MoeAttention.forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
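As a reference for what an SDPA-style backend computes inside `forward`, here is single-head causal scaled dot-product attention in plain Python. This is a sketch only. It does not model how this module interprets `attention_mask` or `attn_kwargs`, and it does not claim that causal masking is applied by default. Those details are backend-specific and are not stated on this page.

```python
import math

def causal_sdpa(q, k, v):
    # q, k, v: lists of per-position vectors for a single head.
    # Position i attends only to positions 0..i (causal mask).
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max before exponentiating for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) for c in range(len(v[0]))])
    return out
```

When all keys are identical, the softmax weights are uniform, so each output is the running mean of the values seen so far.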
nemo_automodel.components.models.qwen3_moe.layers.Qwen3MoeAttention.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
)
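This page does not say what `init_weights` does with `init_std`. A common convention, assumed here rather than confirmed, is to draw projection weights from a normal distribution with mean 0 and standard deviation `init_std`, and to reset RMSNorm scales to one. The helpers below are hypothetical and sketch only that convention. The real method may instead use a truncated normal or per-layer scaling, and the `buffer_device` argument is not modeled.

```python
import random

def init_linear_weight(out_features, in_features, init_std=0.02, seed=0):
    # Hypothetical helper: fill an (out, in) matrix with N(0, init_std) samples.
    rng = random.Random(seed)
    return [[rng.gauss(0.0, init_std) for _ in range(in_features)] for _ in range(out_features)]

def init_norm_weight(dim):
    # Hypothetical helper: RMSNorm scales start at one, so the norm begins as a pure rescale.
    return [1.0] * dim
```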
nemo_automodel.components.models.qwen3_moe.layers.logger = logging.getLogger(__name__)