nemo_automodel.components.models.qwen3_next.layers


Module Contents

Classes

| Name | Description |
| --- | --- |
| Qwen3NextAttention | - |
| Qwen3NextFp32GatedDeltaNet | Qwen3-Next GatedDeltaNet that computes the decay gate via an fp32 holder. |
| Qwen3NextRMSNorm | - |
| Qwen3NextSSMGate | Owns Qwen3-Next fp32 SSM-gating params and computes the decay gate. |
| _SSMGateParam | Get-only descriptor exposing a param from _fp32_params when present. |

Functions

| Name | Description |
| --- | --- |
| _install_ssm_gate | Move HF-created bare A_log/dt_bias into a native fp32 holder. |

API

class nemo_automodel.components.models.qwen3_next.layers.Qwen3NextAttention(
config: transformers.models.qwen3_next.configuration_qwen3_next.Qwen3NextConfig,
layer_idx: int,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

head_dim
k_norm
k_proj
num_heads = config.num_attention_heads
num_key_value_groups = self.num_heads // self.num_kv_heads
num_kv_heads = config.num_key_value_heads
o_proj
q_norm
q_proj
scaling = self.head_dim ** -0.5
v_proj
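The derived attributes above are plain arithmetic on the config. The sketch below evaluates them for illustrative head counts (hypothetical values, not read from any Qwen3-Next checkpoint) and adds the usual grouped-query mapping in which query head h reads key/value head h // num_key_value_groups. That mapping is the standard GQA convention and is assumed here rather than confirmed by this page; attention_geometry is a hypothetical helper name.

```python
def attention_geometry(num_attention_heads: int, num_key_value_heads: int, head_dim: int):
    # Mirrors the attribute definitions listed for Qwen3NextAttention.
    num_heads = num_attention_heads
    num_kv_heads = num_key_value_heads
    num_key_value_groups = num_heads // num_kv_heads
    scaling = head_dim ** -0.5  # softmax temperature 1 / sqrt(head_dim)
    # Assumed GQA layout: each KV head serves num_key_value_groups consecutive query heads.
    kv_head_for_query = [h // num_key_value_groups for h in range(num_heads)]
    return num_key_value_groups, scaling, kv_head_for_query


groups, scaling, mapping = attention_geometry(num_attention_heads=8, num_key_value_heads=2, head_dim=64)
print(groups, mapping)  # 4 [0, 0, 0, 0, 1, 1, 1, 1]
print(scaling)
```

With 8 query heads sharing 2 KV heads, the K/V projections and cache are a quarter of the size they would be under full multi-head attention.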
nemo_automodel.components.models.qwen3_next.layers.Qwen3NextAttention.forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.qwen3_next.layers.Qwen3NextAttention.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
)
class nemo_automodel.components.models.qwen3_next.layers.Qwen3NextFp32GatedDeltaNet(
config: transformers.models.qwen3_next.configuration_qwen3_next.Qwen3NextConfig,
layer_idx: int
)

Bases: Qwen3NextGatedDeltaNet

Qwen3-Next GatedDeltaNet that computes the decay gate via an fp32 holder.

HF’s Qwen3NextGatedDeltaNet computes the gate inline as g = -exp(A_log) * softplus(a + dt_bias) using the bare A_log / dt_bias parameters. A_log and dt_bias should stay in fp32: A_log is exponentiated, and since exp(A_log + δ) = exp(A_log) · exp(δ) ≈ exp(A_log) · (1 + δ), a bf16 rounding error δ in A_log becomes a proportional error on the decay rate, which the recurrence then compounds across the sequence.
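The precision argument can be checked numerically. This stdlib-only sketch emulates bfloat16 rounding with a hypothetical to_bf16 helper (round-to-nearest-even on the fp32 bit pattern, finite inputs only), evaluates the gate formula above for hypothetical single-head values, and shows how a small per-step relative error grows into a large ratio between the bf16 and fp32 cumulative decay once it is multiplied over T steps. Only the parameters are rounded; the activation a is left as is.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (round-to-nearest-even, finite inputs only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # fp32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                  # round the low 16 bits to nearest even
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def decay_gate(a: float, A_log: float, dt_bias: float) -> float:
    # g = -exp(A_log) * softplus(a + dt_bias)
    return -math.exp(A_log) * softplus(a + dt_bias)


# Hypothetical single-head values, chosen only for illustration.
A_log, dt_bias, a = -2.9, 0.35, 0.1

g_fp32 = decay_gate(a, A_log, dt_bias)
g_bf16 = decay_gate(a, to_bf16(A_log), to_bf16(dt_bias))
rel_err = abs(g_bf16 - g_fp32) / abs(g_fp32)

# The state is scaled by exp(g) each step, so after T steps the log of the
# cumulative decay is T * g: the bf16 error in that log grows linearly with T.
for T in (1, 100, 1_000):
    ratio = math.exp(T * (g_bf16 - g_fp32))  # bf16 cumulative decay / fp32 cumulative decay
    print(f"T={T:>5}: relative gate error {rel_err:.2e}, cumulative-decay ratio {ratio:.4f}")
```

A per-step error well under 1% is harmless for a single step, but because it enters the exponent it turns into a multiplicative drift that widens with sequence length.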

The constructor moves those params into a native _fp32_params holder so that they are fp32-resident before any dtype cast or FSDP wrapping. The gate is then computed inside the holder’s forward, for two reasons: it keeps the gate computation in fp32, and it makes FSDP’s unshard/reshard and gradient reduce-scatter fire for the holder as its own unit (those steps are triggered by hooks around a wrapped module’s forward, so reading the params from outside it would bypass them). This subclass overrides forward to route the gate through self._compute_gate(a) while reproducing the rest of HF’s forward verbatim.
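Why the gate has to run inside the holder’s forward can be shown with a toy. ToyModule below is a stdlib stand-in for nn.Module whose __call__ runs pre-forward hooks before forward, roughly where an FSDP-style wrapper would unshard parameters; ToySSMGate and ToyGatedDeltaNet are hypothetical names, and real FSDP hook mechanics are considerably more involved. Reading the holder’s params directly yields the same numbers but never triggers the hook, while routing through the holder’s call does.

```python
import math


def softplus(x: float) -> float:
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


class ToyModule:
    """Minimal stand-in for nn.Module: __call__ runs pre-forward hooks, then forward."""

    def __init__(self):
        self.pre_hooks = []

    def __call__(self, *args):
        for hook in self.pre_hooks:  # an FSDP-style wrapper would unshard params here
            hook(self)
        return self.forward(*args)


class ToySSMGate(ToyModule):
    """Hypothetical holder: per-value-head A_log / dt_bias plus the gate computation."""

    def __init__(self, A_log, dt_bias):
        super().__init__()
        self.A_log, self.dt_bias = list(A_log), list(dt_bias)

    def forward(self, a):
        # g[h] = -exp(A_log[h]) * softplus(a[h] + dt_bias[h]), one entry per value head
        return [-math.exp(A) * softplus(x + b) for x, A, b in zip(a, self.A_log, self.dt_bias)]


class ToyGatedDeltaNet:
    def __init__(self, holder):
        self._fp32_params = holder

    def _compute_gate_bypassing(self, a):
        # Anti-pattern: read the holder's params directly, so its hooks never run.
        h = self._fp32_params
        return [-math.exp(A) * softplus(x + b) for x, A, b in zip(a, h.A_log, h.dt_bias)]

    def _compute_gate(self, a):
        # Route through the holder's __call__ so its pre-forward hooks fire.
        return self._fp32_params(a)


events = []
gate = ToySSMGate(A_log=[0.0, -1.0], dt_bias=[0.5, 0.5])
gate.pre_hooks.append(lambda mod: events.append("unshard"))
layer = ToyGatedDeltaNet(gate)

bypassed = layer._compute_gate_bypassing([0.1, 0.2])
print(events)  # []
routed = layer._compute_gate([0.1, 0.2])
print(events)  # ['unshard']
```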

A_log = _SSMGateParam('A_log')
dt_bias = _SSMGateParam('dt_bias')
nemo_automodel.components.models.qwen3_next.layers.Qwen3NextFp32GatedDeltaNet._compute_gate(
a: torch.Tensor
) -> torch.Tensor

Compute the decay gate g in fp32, via the holder when it exists.

nemo_automodel.components.models.qwen3_next.layers.Qwen3NextFp32GatedDeltaNet.forward(
hidden_states: torch.Tensor,
cache_params: typing.Any | None = None,
attention_mask: torch.Tensor | None = None
)
class nemo_automodel.components.models.qwen3_next.layers.Qwen3NextRMSNorm(
dim: int,
eps: float = 1e-06
)

Bases: Module

weight = nn.Parameter(torch.zeros(dim))
nemo_automodel.components.models.qwen3_next.layers.Qwen3NextRMSNorm._norm(
x
)
nemo_automodel.components.models.qwen3_next.layers.Qwen3NextRMSNorm.extra_repr()
nemo_automodel.components.models.qwen3_next.layers.Qwen3NextRMSNorm.forward(
x
)
nemo_automodel.components.models.qwen3_next.layers.Qwen3NextRMSNorm.reset_parameters()
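The weight is initialised to zeros, which matches the zero-centred convention used by Hugging Face’s Qwen3NextRMSNorm, where the normalised input is scaled by (1 + weight). This page does not show the forward body, so treat that convention as an assumption for this class. The stdlib sketch below applies it to a single vector; rms_norm_zero_centered is a hypothetical name.

```python
import math


def rms_norm_zero_centered(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    # _norm: divide by the root-mean-square of x.
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    # Assumed forward: scale by (1 + weight), so the zero-initialised weight is an identity scale.
    return [v * inv_rms * (1.0 + w) for v, w in zip(x, weight)]


out = rms_norm_zero_centered([3.0, 4.0], weight=[0.0, 0.0])
print(out)
```

Storing the weight as an offset from 1 means weight decay pulls the scale toward 1 rather than toward 0.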
class nemo_automodel.components.models.qwen3_next.layers.Qwen3NextSSMGate(
num_v_heads: int,
dtype: torch.dtype = torch.float32
)

Bases: Module

Owns Qwen3-Next fp32 SSM-gating params and computes the decay gate.

A_log
dt_bias
nemo_automodel.components.models.qwen3_next.layers.Qwen3NextSSMGate.forward(
a: torch.Tensor
) -> torch.Tensor
class nemo_automodel.components.models.qwen3_next.layers._SSMGateParam(
name: str
)

Get-only descriptor exposing a param from _fp32_params when present.

nemo_automodel.components.models.qwen3_next.layers._SSMGateParam.__get__(
obj,
owner = None
)
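The descriptor pattern can be sketched without torch. SSMGateParamSketch, ToyHolder, and ToyLayer below are hypothetical stand-ins: the descriptor defines only __get__, returns the attribute from _fp32_params when a holder is present, and otherwise falls back to a _parameters registry, which is assumed here to mirror how nn.Module stores bare parameters. Because nn.Module keeps parameters in _parameters rather than the instance __dict__, a class-level non-data descriptor like this is still consulted on attribute access.

```python
class SSMGateParamSketch:
    """Hypothetical get-only (non-data) descriptor in the spirit of _SSMGateParam."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __get__(self, obj, owner=None):
        if obj is None:  # class-level access returns the descriptor itself
            return self
        holder = getattr(obj, "_fp32_params", None)
        if holder is not None:  # prefer the fp32 holder once it is installed
            return getattr(holder, self.name)
        return obj._parameters[self.name]  # assumed fallback: the bare parameter


class ToyHolder:
    def __init__(self, A_log, dt_bias):
        self.A_log, self.dt_bias = A_log, dt_bias


class ToyLayer:
    A_log = SSMGateParamSketch("A_log")
    dt_bias = SSMGateParamSketch("dt_bias")

    def __init__(self):
        # Stand-in for nn.Module's parameter registry.
        self._parameters = {"A_log": "bare A_log", "dt_bias": "bare dt_bias"}


layer = ToyLayer()
print(layer.A_log)  # bare A_log
layer._fp32_params = ToyHolder("fp32 A_log", "fp32 dt_bias")
print(layer.A_log, layer.dt_bias)  # fp32 A_log fp32 dt_bias
```

Code that still reads self.A_log (as HF-style code does) keeps working unchanged after the params move.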
nemo_automodel.components.models.qwen3_next.layers._install_ssm_gate(
mod: torch.nn.Module,
fp32_dtype: torch.dtype = torch.float32
) -> nemo_automodel.components.models.qwen3_next.layers.Qwen3NextSSMGate

Move HF-created bare A_log/dt_bias into a native fp32 holder.
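Per its signature, _install_ssm_gate takes the HF-constructed module and returns a Qwen3NextSSMGate. The stdlib toy below only mirrors the move-then-attach idea: it pops the bare A_log and dt_bias out of a parameter registry, re-homes them in a holder that stores fp32 values, and attaches that holder as _fp32_params. ToyHFLayer, ToySSMGate, and install_ssm_gate_sketch are hypothetical names; the real function operates on torch parameters, and its exact steps are not shown on this page. array('f') stands in for an fp32 tensor.

```python
from array import array


class ToySSMGate:
    """Stand-in for the fp32 holder: stores A_log / dt_bias as fp32 arrays."""

    def __init__(self, A_log, dt_bias):
        # array('f') stores C floats, which are IEEE fp32 on mainstream platforms.
        self.A_log = array("f", A_log)
        self.dt_bias = array("f", dt_bias)


class ToyHFLayer:
    """Stand-in for an HF layer whose constructor registered bare A_log / dt_bias."""

    def __init__(self, num_v_heads):
        self._parameters = {"A_log": [0.0] * num_v_heads, "dt_bias": [1.0] * num_v_heads}


def install_ssm_gate_sketch(mod):
    # Hypothetical analogue of _install_ssm_gate: remove the bare params from the
    # layer's registry and re-home them, converted to fp32, in a holder.
    A_log = mod._parameters.pop("A_log")
    dt_bias = mod._parameters.pop("dt_bias")
    holder = ToySSMGate(A_log, dt_bias)
    mod._fp32_params = holder
    return holder


layer = ToyHFLayer(num_v_heads=4)
holder = install_ssm_gate_sketch(layer)
print(sorted(layer._parameters))  # []
print(holder.A_log.typecode, list(holder.dt_bias))  # f [1.0, 1.0, 1.0, 1.0]
```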