nemo_automodel.components.models.qwen3_next.layers
Module Contents
Classes
Functions
API
Bases: Module
Bases: Qwen3NextGatedDeltaNet
Qwen3-Next GatedDeltaNet that computes the decay gate via an fp32 holder.
HF’s Qwen3NextGatedDeltaNet computes the gate inline as
g = -exp(A_log) * softplus(a + dt_bias), reading the bare A_log / dt_bias
parameters directly. A_log and dt_bias need to be kept in fp32: A_log is
exponentiated, so bf16 rounding of A_log becomes a proportional error on the decay rate,
and the recurrence compounds that error across the sequence.
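The precision argument can be made concrete with a small pure-Python sketch. It emulates bf16 by rounding the float32 bit pattern to its top 16 bits (round-half-to-even), evaluates the gate formula above with an exact and a bf16-rounded A_log, and measures how the error in the accumulated log-decay grows with sequence length. The parameter values are illustrative only; dt_bias = -3.0 is chosen because it is exactly representable in bf16, so only the A_log error shows. Python floats (float64) stand in for the fp32 reference, and the assumption that the recurrent state is scaled by exp(g) each step follows the usual gated-delta-rule formulation rather than anything stated on this page.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (round-half-to-even).

    bfloat16 keeps the top 16 bits of an IEEE-754 float32 (sign, 8 exponent
    bits, 7 mantissa bits), so we round the float32 bit pattern at bit 16.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # tie-break toward an even bf16 mantissa
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def decay_gate(a: float, a_log: float, dt_bias: float) -> float:
    """Per-step log-decay g = -exp(A_log) * softplus(a + dt_bias)."""
    return -math.exp(a_log) * softplus(a + dt_bias)


def accumulated_log_decay_error(g_approx: float, g_ref: float, steps: int) -> float:
    """Error in sum_t g_t after `steps` steps with a constant gate.

    Assumption: the recurrent state is scaled by exp(g) each step, so the
    accumulated decay is exp(sum_t g_t) and per-step errors add up.
    """
    return steps * abs(g_approx - g_ref)


# Illustrative values only; dt_bias = -3.0 is exact in bf16, so only A_log is rounded.
a_log, dt_bias, a = -1.7, -3.0, 0.0
g_ref = decay_gate(a, a_log, dt_bias)            # float64 stands in for fp32
g_bf16 = decay_gate(a, to_bf16(a_log), dt_bias)  # A_log stored in bf16

print(f"A_log: {a_log} -> bf16 {to_bf16(a_log)}")
print(f"per-step relative error in g: {abs(g_bf16 - g_ref) / abs(g_ref):.3%}")
for steps in (1, 1024, 8192):
    err = accumulated_log_decay_error(g_bf16, g_ref, steps)
    # The accumulated decay factor exp(sum g) is off by a factor of exp(+/-err).
    print(f"steps={steps:5d}  accumulated log-decay error: {err:.4f}")
```

Because exp(A_log + δ) = exp(A_log)·exp(δ), an absolute rounding error δ in A_log becomes the same relative error in every step's gate; under the exp(g)-per-step assumption, the error in the accumulated log-decay then grows linearly with sequence length.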
The constructor moves those params into a native _fp32_params holder so that they
are resident in fp32 before any dtype cast or FSDP wrapping. To keep the gate
computation in fp32, and to make FSDP’s unshard/reshard and gradient
reduce-scatter fire for that unit, the gate is computed inside the holder’s
forward. This subclass therefore overrides forward to route the gate through
self._compute_gate(a) while otherwise reproducing HF’s forward verbatim.
Compute the decay gate g in fp32, via the holder when it exists.
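The dispatch described above can be sketched in plain Python. Every name here is a hypothetical stand-in: HookedModule mimics only the part of torch.nn.Module.__call__ that runs hooks before forward (the kind of hook FSDP relies on to unshard a unit's parameters), GateHolder plays the role of the _fp32_params holder, and Python floats stand in for fp32 tensors. What it illustrates is that _compute_gate calls the holder instead of reading its parameters directly, so the holder's hooks fire, and falls back to the inline formula when no holder exists.

```python
import math
from typing import Callable, List, Optional


def gate_formula(a: List[float], a_log: List[float], dt_bias: List[float]) -> List[float]:
    """Per-head g = -exp(A_log) * softplus(a + dt_bias)."""
    return [-math.exp(al) * math.log1p(math.exp(x + b))
            for x, al, b in zip(a, a_log, dt_bias)]


class HookedModule:
    """Hypothetical stand-in for nn.Module.__call__: hooks run before forward."""

    def __init__(self) -> None:
        self._pre_forward_hooks: List[Callable[["HookedModule"], None]] = []

    def add_pre_forward_hook(self, hook: Callable[["HookedModule"], None]) -> None:
        self._pre_forward_hooks.append(hook)

    def __call__(self, *args):
        for hook in self._pre_forward_hooks:
            hook(self)  # where an FSDP-style wrapper would unshard params
        return self.forward(*args)


class GateHolder(HookedModule):
    """Hypothetical _fp32_params holder: owns the params, computes the gate."""

    def __init__(self, a_log: List[float], dt_bias: List[float]) -> None:
        super().__init__()
        self.A_log = list(a_log)
        self.dt_bias = list(dt_bias)

    def forward(self, a: List[float]) -> List[float]:
        return gate_formula(a, self.A_log, self.dt_bias)


class GatedDeltaNetSketch:
    """Hypothetical stand-in that shows only the gate dispatch."""

    def __init__(self, a_log: List[float], dt_bias: List[float], use_holder: bool) -> None:
        self._fp32_params: Optional[GateHolder] = None
        if use_holder:
            self._fp32_params = GateHolder(a_log, dt_bias)
        else:
            self.A_log, self.dt_bias = list(a_log), list(dt_bias)

    def _compute_gate(self, a: List[float]) -> List[float]:
        if self._fp32_params is not None:
            # Call the holder instead of reading its params so its hooks run.
            return self._fp32_params(a)
        # No holder: fall back to the inline formula on the bare params.
        return gate_formula(a, self.A_log, self.dt_bias)


events: List[str] = []
with_holder = GatedDeltaNetSketch([0.5, -1.0], [0.1, 0.2], use_holder=True)
with_holder._fp32_params.add_pre_forward_hook(lambda m: events.append("unshard"))
without_holder = GatedDeltaNetSketch([0.5, -1.0], [0.1, 0.2], use_holder=False)

a = [0.3, -0.4]
print(with_holder._compute_gate(a) == without_holder._compute_gate(a))  # True
print(events)  # ['unshard']
```

Both paths compute the same values; the difference is only whether the computation passes through the holder's call, which is what lets a wrapper attached to the holder observe it.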
Bases: Module
Bases: Module
Owns Qwen3-Next fp32 SSM-gating params and computes the decay gate.
Get-only descriptor exposing a param from _fp32_params when present.
Move HF-created bare A_log/dt_bias into a native fp32 holder.
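The holder, the get-only descriptor, and the move step can be sketched together in plain Python. This is an illustration under stated assumptions, not the module's code: Fp32Holder, Fp32ParamView, GatedDeltaNetStub and move_gate_params_to_holder are hypothetical names, parameters are Python lists instead of tensors, the holder's gate computation is left out, and the fp32 cast is emulated by rounding each float64 value through struct. Because the descriptor defines only __get__ (a non-data descriptor), a bare attribute in the instance __dict__ shadows it until the move deletes that attribute, after which reads of A_log / dt_bias resolve to the holder.

```python
import struct
from typing import List, Optional


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


class Fp32Holder:
    """Hypothetical holder that owns the gating params in fp32."""

    def __init__(self, A_log: List[float], dt_bias: List[float]) -> None:
        self.A_log = [to_fp32(v) for v in A_log]
        self.dt_bias = [to_fp32(v) for v in dt_bias]


class Fp32ParamView:
    """Get-only (non-data) descriptor that reads a param from _fp32_params."""

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name  # "A_log" or "dt_bias"

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self  # class-level access returns the descriptor itself
        holder: Optional[Fp32Holder] = obj.__dict__.get("_fp32_params")
        if holder is None:
            raise AttributeError(self.name)
        return getattr(holder, self.name)


class GatedDeltaNetStub:
    """Hypothetical stand-in for the HF module, which creates bare params."""

    A_log = Fp32ParamView()
    dt_bias = Fp32ParamView()

    def __init__(self, A_log: List[float], dt_bias: List[float]) -> None:
        # Non-data descriptor: these assignments land in the instance __dict__.
        self.A_log = list(A_log)
        self.dt_bias = list(dt_bias)


def move_gate_params_to_holder(module: GatedDeltaNetStub) -> None:
    """Move bare A_log/dt_bias into an fp32 holder; a no-op if already moved."""
    if module.__dict__.get("_fp32_params") is not None:
        return
    A_log = module.__dict__.pop("A_log")
    dt_bias = module.__dict__.pop("dt_bias")
    module._fp32_params = Fp32Holder(A_log, dt_bias)


net = GatedDeltaNetStub(A_log=[0.1, 1.3], dt_bias=[0.7, -3.0])
move_gate_params_to_holder(net)
print("A_log" in vars(net))                 # False: the bare attribute is gone
print(net.A_log is net._fp32_params.A_log)  # True: reads go through the holder
```

Making the move idempotent means a second call leaves the existing holder in place instead of failing on the already-removed bare attributes.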