nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn
Context-Parallel-aware wrapper for Qwen3.5 MoE GatedDeltaNet linear attention.
When a CP mesh is attached (via apply_cp), the forward pass:
- Recovers dense sequence order from PyTorch’s load-balanced CP layout, using a local seq_index when provided and otherwise deriving it from the CP DualChunkSwap layout.
- Runs the causal conv1d and FLA gated delta rule on that dense ordering.
- Restores the output back to the original load-balanced CP layout.
When no CP mesh is set, the module delegates to the original HF forward.
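A minimal pure-Python sketch of the layout bookkeeping described above. It assumes the usual DualChunkSwap split (the sequence is cut into 2 * cp_size equal chunks and rank r holds chunks r and 2 * cp_size - 1 - r) and that the per-rank seq_index shards have been all-gathered into one flat list. The helper names are hypothetical, not this module's API.

```python
def dual_chunk_swap_index(seq_len: int, cp_size: int) -> list[list[int]]:
    """Global token positions held by each CP rank under DualChunkSwap (assumed layout)."""
    n_chunks = 2 * cp_size
    assert seq_len % n_chunks == 0, "seq_len must split into 2 * cp_size chunks"
    chunk = seq_len // n_chunks
    layout = []
    for rank in range(cp_size):
        head = rank                 # chunk taken from the front of the sequence
        tail = n_chunks - 1 - rank  # mirrored chunk taken from the back
        layout.append(
            list(range(head * chunk, (head + 1) * chunk))
            + list(range(tail * chunk, (tail + 1) * chunk))
        )
    return layout


def to_dense(values: list, seq_index: list[int]) -> list:
    """Reorder CP-layout values into dense order (an argsort of seq_index)."""
    order = sorted(range(len(seq_index)), key=seq_index.__getitem__)
    return [values[i] for i in order]


def to_cp_layout(dense: list, seq_index: list[int]) -> list:
    """Inverse of to_dense: the value at CP slot i is dense token seq_index[i]."""
    return [dense[pos] for pos in seq_index]


# Usage: two ranks, eight tokens.
layout = dual_chunk_swap_index(8, 2)        # [[0, 1, 6, 7], [2, 3, 4, 5]]
gathered_index = [p for shard in layout for p in shard]
cp_values = [f"t{p}" for p in gathered_index]
dense = to_dense(cp_values, gathered_index)  # ['t0', 't1', ..., 't7']
assert to_cp_layout(dense, gathered_index) == cp_values
```

Pairing each rank's head chunk with its mirrored tail chunk is what balances causal-attention work across ranks; the price is that a rank's local tokens are not contiguous, which is why the linear-attention path has to reorder before its sequential scan.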
Module Contents
Classes
Functions
Data
API
Bases: Qwen3_5MoeGatedDeltaNet
Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism.
The SSM-gating params (A_log/dt_bias) are moved into an fp32 SSMGate
submodule (_fp32_params) at construction, so they keep fp32 storage (master
weights) even under a bf16 bulk dtype and FSDP can shard them in their own
dtype-uniform fp32 group. A_log/dt_bias remain readable as attributes through
get-only descriptors that resolve to the submodule, so no __getattr__ patch is needed.
_cp_mesh is set externally by the parallelizer to enable context parallelism.
Compute the gating value g via the fp32 SSMGate submodule.
Computing inside the submodule’s forward keeps FSDP’s unshard/reshard lifecycle natural for the isolated fp32 group.
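For reference, a scalar sketch of the gate value. It assumes the formula used by the HF Qwen3-Next/Qwen3.5 GatedDeltaNet, g = -exp(A_log) * softplus(a + dt_bias), evaluated per head. compute_gate and softplus are illustrative names; the module itself performs this computation on fp32 tensors inside the SSMGate forward.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def compute_gate(a: list[float], A_log: list[float], dt_bias: list[float]) -> list[float]:
    """Per-head gate g = -exp(A_log) * softplus(a + dt_bias) (assumed formula).

    Python floats are float64; the real module computes this on fp32 tensors.
    """
    return [-math.exp(al) * softplus(x + b) for x, al, b in zip(a, A_log, dt_bias)]


g = compute_gate(a=[0.0, 2.0], A_log=[0.0, -1.0], dt_bias=[0.0, 0.5])
decay = [math.exp(v) for v in g]  # each per-step decay factor lies in (0, 1)
```

Because exp(A_log) and softplus are both positive, g is always negative and exp(g) acts as a decay. Storing A_log in log space keeps that sign constraint without clamping.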
Run causal conv1d via FLA’s CP-aware conv implementation.
Parameters:
[B, D, S_local] tensor (channels-first for conv).
FLA CP context built by build_cp_context.
Returns: torch.Tensor
[B, D, S_local] conv output with correct boundary handling.
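The boundary handling reduces to one rule: a shard's first K - 1 outputs depend on the previous shard's last K - 1 inputs. Below is a single-channel, pure-Python sketch of that rule. It assumes contiguous shards in dense order, each at least K - 1 tokens long. FLA's CP conv exchanges this halo between ranks itself, whereas here neighboring shards are plain list slices and the function names are hypothetical.

```python
def causal_conv1d(x: list[float], w: list[float]) -> list[float]:
    """Single-channel causal conv: y[t] = sum_j w[j] * x[t - (K-1) + j], zero left pad."""
    k = len(w)
    padded = [0.0] * (k - 1) + list(x)
    return [sum(w[j] * padded[t + j] for j in range(k)) for t in range(len(x))]


def cp_causal_conv1d(shards: list[list[float]], w: list[float]) -> list[list[float]]:
    """Convolve contiguous shards independently, seeding each one with the
    previous shard's last K-1 inputs (the halo) instead of zero padding."""
    k = len(w)
    outputs = []
    for r, shard in enumerate(shards):
        halo = shards[r - 1][-(k - 1):] if (r > 0 and k > 1) else []
        y = causal_conv1d(halo + shard, w)
        outputs.append(y[len(halo):])  # drop outputs that belong to the halo
    return outputs


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
w = [0.5, -1.0, 2.0]
local = cp_causal_conv1d([x[0:2], x[2:4], x[4:6]], w)
assert [v for shard in local for v in shard] == causal_conv1d(x, w)
```

Zero-padding every shard instead of seeding it with the halo would silently corrupt the first K - 1 outputs of every rank except rank 0.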
HF GatedDeltaNet forward with FSDP-safe fp32 gate computation.
Mirrors transformers==5.5 Qwen3_5GatedDeltaNet.forward (per-layer
cache API; gate via self._compute_gate(a)) and adds packing-aware
plumbing:
- cu_seqlens: per-document cumulative lengths from the indexed attention mask. When supplied, FLA’s chunk kernel resets state at every document boundary.
- indices: non-padding token indices. When supplied AND padding is actually present (B>1 case), the layer unpads activations to [1, total_valid, ...] before conv/FLA and re-pads on the way out. For B=1 with no padding, indices covers the whole sequence and unpadding is skipped (preserves the bit-exact fast path).
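A sketch of the unpad/re-pad round trip on a [B][S] batch of scalars. It assumes indices are flat positions b * S + s of the non-padding tokens (the same convention as torch.nonzero over the flattened mask). The function names are hypothetical.

```python
def needs_unpad(indices: list[int], batch_size: int, seq_len: int) -> bool:
    """Unpadding is a no-op when every position is valid, so skip it then."""
    return len(indices) < batch_size * seq_len


def unpad(batch: list[list[float]], indices: list[int]) -> list[list[float]]:
    """Gather non-padding tokens of a [B][S] batch into one packed [1][total_valid] row."""
    flat = [tok for row in batch for tok in row]
    return [[flat[i] for i in indices]]


def repad(packed: list[list[float]], indices: list[int], batch_size: int,
          seq_len: int, pad_value: float = 0.0) -> list[list[float]]:
    """Scatter a packed [1][total_valid] row back to [B][S], filling padding slots."""
    flat = [pad_value] * (batch_size * seq_len)
    for value, i in zip(packed[0], indices):
        flat[i] = value
    return [flat[b * seq_len:(b + 1) * seq_len] for b in range(batch_size)]


batch = [[1.0, 2.0, 0.0], [3.0, 0.0, 0.0]]  # B=2, S=3, trailing padding
indices = [0, 1, 3]                         # flat positions of valid tokens
packed = unpad(batch, indices)              # [[1.0, 2.0, 3.0]]
assert repad(packed, indices, 2, 3) == batch
```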
Both kwargs are produced by Qwen3_5DecoderLayerWithPacking. As a
safety net for direct callers (e.g. unit tests that bypass the
decoder-layer subclass), the layer derives them from attention_mask
when both are None and the mask is indexed.
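Deriving both kwargs from an indexed mask can be sketched as follows. This assumes the mask stores a per-token document id (1, 2, ... within each row) with 0 marking padding, and that a new document starts whenever the id changes or a new row begins. The real mask convention and derivation helper may differ; packing_metadata is a hypothetical name.

```python
def packing_metadata(mask: list[list[int]]) -> tuple[list[int], list[int]]:
    """Return (indices, cu_seqlens) for a [B][S] indexed mask (assumed convention).

    indices: flat positions b * S + s of non-padding tokens.
    cu_seqlens: cumulative document lengths over the packed stream, starting at 0.
    """
    seq_len = len(mask[0])
    indices, cu_seqlens = [], [0]
    for b, row in enumerate(mask):
        prev_doc = 0  # a new row always starts a new document
        for s, doc in enumerate(row):
            if doc == 0:
                continue  # padding contributes to neither output
            indices.append(b * seq_len + s)
            if doc != prev_doc:
                cu_seqlens.append(cu_seqlens[-1])  # open a new document
                prev_doc = doc
            cu_seqlens[-1] += 1  # extend the current document by one token
    return indices, cu_seqlens


indices, cu_seqlens = packing_metadata([[1, 1, 2, 0], [1, 1, 1, 0]])
# indices    -> [0, 1, 2, 4, 5, 6]
# cu_seqlens -> [0, 2, 3, 6]
```

cu_seqlens indexes the packed stream produced by unpadding, so the two outputs are consistent only when both are derived from the same mask.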
Bases: Module
Owns the fp32 SSM-gating params (A_log/dt_bias) and computes the gate.
Keeping these in a dedicated submodule lets FSDP shard them in their own
dtype-uniform fp32 group (true master weights), and computing the gate inside
forward keeps FSDP’s unshard/reshard lifecycle natural.
Bases: Function
All-gather + concat with autograd-safe backward.
The forward concatenates equal-sized local shards from all ranks along dim.
Backward all-reduces the concatenated gradient across ranks, then slices out
the local shard for the current rank.
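The backward sums across ranks because every rank's loss sees every shard through the gathered tensor. The gradient for rank r's shard is therefore the sum, over all ranks, of the slice of their gradient that covers shard r. A single-process simulation, with lists standing in for per-rank tensors and hypothetical function names:

```python
def all_gather_concat_forward(local_shards: list[list[float]]) -> list[list[float]]:
    """Simulated forward: every rank receives the concatenation of all shards."""
    assert len({len(s) for s in local_shards}) == 1, "shards must be equal-sized"
    full = [v for shard in local_shards for v in shard]
    return [list(full) for _ in local_shards]


def all_gather_concat_backward(full_grads: list[list[float]]) -> list[list[float]]:
    """Simulated backward: all-reduce (sum) the full-length grads across ranks,
    then each rank keeps the slice that matches its own shard."""
    world = len(full_grads)
    reduced = [sum(col) for col in zip(*full_grads)]  # all-reduce with SUM
    shard = len(reduced) // world
    return [reduced[r * shard:(r + 1) * shard] for r in range(world)]


outs = all_gather_concat_forward([[1.0, 2.0], [3.0, 4.0]])
grads = all_gather_concat_backward([[1.0, 1.0, 1.0, 1.0], [0.0, 2.0, 0.0, 2.0]])
# outs  -> [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]
# grads -> [[1.0, 3.0], [1.0, 3.0]]
```

Sum-then-slice yields exactly what a reduce-scatter would produce. The all-reduce form is simpler to write but moves more data.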
Get-only (non-data) descriptor exposing an SSMGate param as an attribute.
Lets self.A_log / self.dt_bias resolve to the fp32 SSMGate holder
(self._fp32_params) without a __getattr__ monkeypatch. Being a non-data
descriptor, it does not intercept assignment, so HF’s __init__ doing
self.A_log = nn.Parameter(...) still routes through nn.Module.__setattr__
into _parameters (where it lives until install_ssm_gate moves it).
Resolve the fp32 storage dtype for the SSM-gating params from config.
Honors mamba_ssm_dtype (Qwen3.5 stores A_log/dt_bias in fp32);
defaults to torch.float32.
Move mod’s HF-created bare A_log/dt_bias into an fp32 SSMGate.
HF’s GatedDeltaNet __init__ creates A_log/dt_bias as bare params in
mod._parameters. This relocates them into an SSMGate submodule
registered as _fp32_params (casting to fp32_dtype), so they keep fp32
storage under a bf16 bulk dtype and get their own dtype-uniform FSDP group.
Attribute access (self.A_log/self.dt_bias) continues to work via the
_SSMGateParam descriptors on CPAwareGatedDeltaNet, with no
__getattr__ patch. Returns the gate submodule.
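The interplay between the get-only descriptor and the relocation can be mimicked without PyTorch. In this sketch Param, ToyModule, GateParam, Layer and install_gate are hypothetical stand-ins. ToyModule reproduces only nn.Module's routing of parameter assignment into _parameters, and GateParam's fallback to the bare parameter before relocation is a choice of this sketch, not necessarily of _SSMGateParam.

```python
class Param:
    """Hypothetical stand-in for nn.Parameter: a value tagged with a dtype name."""

    def __init__(self, value, dtype: str):
        self.value, self.dtype = value, dtype


class ToyModule:
    """Hypothetical stand-in for nn.Module: parameter assignment is routed
    into self._parameters rather than the instance __dict__."""

    def __init__(self):
        object.__setattr__(self, "_parameters", {})

    def __setattr__(self, name, value):
        if isinstance(value, Param):
            self._parameters[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):  # only reached when normal lookup fails
        params = self.__dict__.get("_parameters", {})
        if name in params:
            return params[name]
        raise AttributeError(name)


class GateParam:
    """Get-only (non-data) descriptor: defines __get__ but no __set__."""

    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        holder = obj.__dict__.get("_fp32_params")
        # Sketch-only choice: before relocation, read the bare parameter.
        source = holder if holder is not None else obj
        return source._parameters[self.name]


class Layer(ToyModule):
    A_log = GateParam()
    dt_bias = GateParam()

    def __init__(self):
        super().__init__()
        # Mirrors HF's __init__: plain assignment of bare params.
        self.A_log = Param([0.0, 0.5], "bfloat16")
        self.dt_bias = Param([1.0, 1.0], "bfloat16")


def install_gate(layer: ToyModule, fp32_dtype: str = "float32") -> ToyModule:
    """Move the bare params into a holder submodule, re-tagged as fp32."""
    gate = ToyModule()
    for name in ("A_log", "dt_bias"):
        bare = layer._parameters.pop(name)
        setattr(gate, name, Param(bare.value, fp32_dtype))
    layer._fp32_params = gate  # not a Param, so it lands in __dict__
    return gate


layer = Layer()
assert layer.A_log.dtype == "bfloat16"
gate = install_gate(layer)
assert layer.A_log is gate._parameters["A_log"] and layer.A_log.dtype == "float32"
```

GateParam defines no __set__, so the assignments in Layer.__init__ go straight to the module's __setattr__. Every read after install_gate resolves to the fp32 holder.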