nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn


Context-Parallel-aware wrapper for Qwen3.5 MoE GatedDeltaNet linear attention.

When a CP mesh is attached (via apply_cp), the forward pass:

  1. Recovers dense sequence order from PyTorch’s load-balanced CP layout using a local seq_index when provided, otherwise deriving it from the CP DualChunkSwap layout.
  2. Runs the causal conv1d and FLA gated delta rule on that dense ordering.
  3. Restores the output back to the original load-balanced CP layout.
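The three forward steps above can be illustrated with a single-process simulation. This is a minimal sketch, not the module's code. Each inner list plays one CP rank, concatenation stands in for the all-gather, and a prefix sum stands in for the causal conv1d plus gated delta rule. The helper names (undo_load_balancing, causal_scan, redo_load_balancing) are hypothetical, and the hard-coded positions assume a head-tail DualChunkSwap split of 8 tokens over 2 ranks.

```python
# Single-process simulation: each inner list is one CP rank's shard.

def undo_load_balancing(shards, positions):
    """Concatenate every rank's shard (stand-in for all-gather) and sort into dense order."""
    all_tokens = [t for shard in shards for t in shard]
    all_positions = [p for pos in positions for p in pos]
    # Argsort by global position recovers the dense 0..S-1 ordering.
    order = sorted(range(len(all_positions)), key=all_positions.__getitem__)
    return [all_tokens[i] for i in order], [all_positions[i] for i in order]


def causal_scan(tokens):
    """Stand-in for conv1d + gated delta rule: any causal op works; a prefix sum here."""
    out, acc = [], 0
    for t in tokens:
        acc += t
        out.append(acc)
    return out


def redo_load_balancing(dense_out, sorted_positions, local_positions):
    """Return one rank's outputs in that rank's original load-balanced order."""
    by_position = dict(zip(sorted_positions, dense_out))
    return [by_position[p] for p in local_positions]


# Assumed head-tail (DualChunkSwap) layout: 8 tokens, 2 ranks, chunk size 2.
positions = [[0, 1, 6, 7], [2, 3, 4, 5]]
tokens = [[p + 1 for p in pos] for pos in positions]  # token value = position + 1

dense, sorted_pos = undo_load_balancing(tokens, positions)
dense_out = causal_scan(dense)
per_rank = [redo_load_balancing(dense_out, sorted_pos, pos) for pos in positions]
print(per_rank)  # [[1, 3, 28, 36], [6, 10, 15, 21]]
```

Running the causal scan on each rank's shard in place would be wrong. Rank 0 holds positions 6 and 7 but never sees positions 2 to 5, so `causal_scan(tokens[0])` gives `[1, 3, 10, 18]` instead of `[1, 3, 28, 36]`. This is why dense order is restored first.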

When no CP mesh is set, the module falls back to _forward_no_cp, which mirrors the HF forward but computes the gate through the fp32 SSMGate.

Module Contents

Classes

| Name | Description |
| --- | --- |
| CPAwareGatedDeltaNet | Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism. |
| SSMGate | Owns the fp32 SSM-gating params (A_log/dt_bias) and computes the gate. |
| _AllGatherConcatFn | All-gather + concat with autograd-safe backward. |
| _SSMGateParam | Get-only (non-data) descriptor exposing an SSMGate param as an attribute. |

Functions

| Name | Description |
| --- | --- |
| _resolve_ssm_dtype | Resolve the fp32 storage dtype for the SSM-gating params from config. |
| install_ssm_gate | Move mod's HF-created bare A_log/dt_bias into an fp32 SSMGate. |

Data

_FP32_PARAM_NAMES

API

class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet(
config,
layer_idx: int
)

Bases: Qwen3_5MoeGatedDeltaNet

Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism.

The SSM-gating params (A_log/dt_bias) are moved into an fp32 SSMGate submodule (_fp32_params) at construction, so they keep fp32 storage (master weights) even under a bf16 bulk dtype and FSDP can shard them in their own dtype-uniform fp32 group. A_log/dt_bias remain readable as attributes through get-only descriptors that resolve to the submodule, so no __getattr__ patch is needed.

_cp_mesh is set externally by the parallelizer to enable context parallelism.

A_log = _SSMGateParam('A_log')
_cp_mesh: DeviceMesh | None = None
dt_bias = _SSMGateParam('dt_bias')
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._all_gather_concat(
tensor: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
dim: int,
differentiable: bool = False
) -> torch.Tensor
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._build_dual_chunk_local_positions(
seq_len: int,
cp_size: int,
cp_rank: int,
device: torch.device
) -> torch.Tensor
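_build_dual_chunk_local_positions carries no docstring. The sketch below shows one plausible reading, assuming the common head-tail DualChunkSwap split: the sequence is cut into 2 * cp_size equal chunks, and rank r holds chunk r followed by chunk 2 * cp_size - 1 - r. Whether the real function takes the global or the local sequence length is not stated. This hypothetical helper takes the global length and returns a list rather than a tensor on device.

```python
def build_dual_chunk_local_positions(global_seq_len, cp_size, cp_rank):
    """Global token positions held by cp_rank under an assumed head-tail split."""
    num_chunks = 2 * cp_size
    if global_seq_len % num_chunks != 0:
        raise ValueError("global_seq_len must be divisible by 2 * cp_size")
    chunk = global_seq_len // num_chunks
    # Rank r keeps one chunk from the front of the sequence ...
    head = range(cp_rank * chunk, (cp_rank + 1) * chunk)
    # ... and its mirror chunk from the back, roughly balancing causal work.
    tail_idx = num_chunks - 1 - cp_rank
    tail = range(tail_idx * chunk, (tail_idx + 1) * chunk)
    return list(head) + list(tail)


for rank in range(2):
    print(rank, build_dual_chunk_local_positions(8, cp_size=2, cp_rank=rank))
# 0 [0, 1, 6, 7]
# 1 [2, 3, 4, 5]
```

Pairing an early chunk with a late one means each rank holds tokens that attend to short and long prefixes alike. That layout is what the CP path has to undo before running the recurrent kernel.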
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._compute_gate(
a: torch.Tensor
) -> torch.Tensor

Compute the gating value g via the fp32 SSMGate submodule.

Computing inside the submodule’s forward keeps FSDP’s unshard/reshard lifecycle natural for the isolated fp32 group.

nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._conv1d_with_cp(
mixed_qkv: torch.Tensor,
cp_context
) -> torch.Tensor

Run causal conv1d via FLA’s CP-aware conv implementation.

Parameters:

mixed_qkv (torch.Tensor): [B, D, S_local] tensor (channels-first for conv).
cp_context: FLA CP context built by build_cp_context.

Returns: torch.Tensor of shape [B, D, S_local], the conv output with correct boundary handling.
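A causal conv1d with kernel size K needs K - 1 inputs of left context for each output. A rank that holds a contiguous slice of the dense sequence therefore has to see the tail of the previous rank's slice. The sketch below illustrates that boundary requirement on a single channel with plain lists. It does not show how FLA's CP conv actually exchanges the halo, and causal_conv1d and its history argument are hypothetical names.

```python
def causal_conv1d(x, weights, history=None):
    """One channel of a depthwise causal conv: y[t] = sum_i w[i] * x[t - (K - 1) + i]."""
    k = len(weights)
    # Left context: the previous rank's last K - 1 inputs, or zeros at sequence start.
    left = list(history) if history is not None else [0] * (k - 1)
    assert len(left) == k - 1, "history must hold exactly K - 1 inputs"
    padded = left + list(x)
    return [sum(w * padded[t + i] for i, w in enumerate(weights)) for t in range(len(x))]


weights = [1, 2, 3]  # K = 3
full = causal_conv1d([1, 2, 3, 4, 5, 6], weights)

# Dense sequence split over two ranks as contiguous slices.
rank0, rank1 = [1, 2, 3], [4, 5, 6]
out0 = causal_conv1d(rank0, weights)
out1 = causal_conv1d(rank1, weights, history=rank0[-(len(weights) - 1):])
print(full)                           # [3, 8, 14, 20, 26, 32]
print(out0 + out1)                    # [3, 8, 14, 20, 26, 32]
print(causal_conv1d(rank1, weights))  # [12, 23, 32]: wrong without the halo
```

Only the first K - 1 outputs of each non-first rank depend on the halo. Everything after that is purely local.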

nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._extract_local_seq_index(
seq_index: torch.Tensor | None,
seq_len: int
) -> torch.Tensor | None
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._forward_no_cp(
hidden_states: torch.Tensor,
cache_params = None,
cache_position = None,
attention_mask: torch.Tensor | None = None,
cu_seqlens: torch.Tensor | None = None,
indices: torch.Tensor | None = None
)

HF GatedDeltaNet forward with FSDP-safe fp32 gate computation.

Mirrors transformers==5.5 Qwen3_5GatedDeltaNet.forward (per-layer cache API; gate via self._compute_gate(a)) and adds packing-aware plumbing:

  • cu_seqlens — per-document cumulative lengths from the indexed attention mask. When supplied, FLA’s chunk kernel resets state at every document boundary.
  • indices — non-padding token indices. When supplied AND padding is actually present (B>1 case), the layer unpads activations to [1, total_valid, ...] before conv/FLA and re-pads on the way out. For B=1 with no padding, indices covers the whole sequence and unpadding is skipped (preserves the bit-exact fast path).

Both kwargs are produced by Qwen3_5DecoderLayerWithPacking. As a safety net for direct callers (e.g. unit tests that bypass the decoder-layer subclass), the layer derives them from attention_mask when both are None and the mask is indexed.
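The cu_seqlens/indices plumbing can be sketched in pure Python. The sketch assumes the indexed attention mask is a [B, S] grid with 0 for padding and a positive document id per token, with each document occupying a contiguous run. The encoding actually produced upstream may differ, and packing_metadata, unpad, and repad are hypothetical helpers. Here indices are flat offsets into [B * S], and cu_seqlens are cumulative lengths over the unpadded [1, total_valid] stream.

```python
def packing_metadata(mask):
    """(indices, cu_seqlens) from a [B][S] indexed mask; 0 = padding, >0 = document id."""
    seq_len = len(mask[0])
    indices, cu_seqlens, total = [], [0], 0
    for b, row in enumerate(mask):
        prev = 0  # documents never continue across rows
        for s, doc in enumerate(row):
            if doc == 0:
                prev = 0
                continue
            if doc != prev and total > cu_seqlens[-1]:
                cu_seqlens.append(total)  # close the previous document
            indices.append(b * seq_len + s)  # flat offset into [B * S]
            total += 1
            prev = doc
    if total > cu_seqlens[-1]:
        cu_seqlens.append(total)
    return indices, cu_seqlens


def unpad(flat, indices):
    """[B * S] -> [total_valid]: keep only non-padding tokens."""
    return [flat[i] for i in indices]


def repad(valid, indices, total_len, fill=0):
    """[total_valid] -> [B * S]: scatter back; padding slots get `fill`."""
    out = [fill] * total_len
    for value, i in zip(valid, indices):
        out[i] = value
    return out


mask = [[1, 1, 1, 2, 2, 0],
        [1, 1, 1, 1, 0, 0]]
indices, cu_seqlens = packing_metadata(mask)
print(indices)     # [0, 1, 2, 3, 4, 6, 7, 8, 9]
print(cu_seqlens)  # [0, 3, 5, 9]
```

For this mask, unpadding yields 9 valid tokens in three documents of lengths 3, 2, and 4. When len(indices) == B * S there is no padding, and the unpad/repad round trip is an identity. That is the case the layer skips to keep the fast path bit-exact.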

nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._forward_with_cp(
hidden_states: torch.Tensor,
position_ids: torch.Tensor | None,
seq_index: torch.Tensor | None
) -> torch.Tensor
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._redo_attention_load_balancing(
output: torch.Tensor,
original_positions: torch.Tensor,
sorted_positions: torch.Tensor,
cp_group: torch.distributed.ProcessGroup
) -> torch.Tensor
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._undo_attention_load_balancing(
hidden_states: torch.Tensor,
original_positions: torch.Tensor,
cp_group: torch.distributed.ProcessGroup
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet.forward(
hidden_states: torch.Tensor,
cache_params = None,
cache_position = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
qkv_format: str | None = None,
cu_seqlens: torch.Tensor | None = None,
indices: torch.Tensor | None = None,
seq_index: torch.Tensor | None = None
)
class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.SSMGate(
num_v_heads: int,
dtype: torch.dtype = torch.float32
)

Bases: Module

Owns the fp32 SSM-gating params (A_log/dt_bias) and computes the gate.

Keeping these in a dedicated submodule lets FSDP shard them in their own dtype-uniform fp32 group (true master weights), and computing the gate inside forward keeps FSDP’s unshard/reshard lifecycle natural.

A_log
dt_bias
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.SSMGate.forward(
a: torch.Tensor
) -> torch.Tensor
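The gate computation can be sketched per value head for a single token. This sketch assumes SSMGate.forward follows the HF GatedDeltaNet formula g = -exp(A_log) * softplus(a + dt_bias), where a is the per-head projection of the hidden state. The function names are hypothetical, and the real module evaluates the formula on fp32 tensors rather than Python floats (which are float64).

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def ssm_gate(a, A_log, dt_bias):
    """Per-head log-decay g = -exp(A_log) * softplus(a + dt_bias).

    a, A_log and dt_bias hold one entry per value head (assumed layout for a
    single token); the real SSMGate works on fp32 tensors.
    """
    return [-math.exp(al) * softplus(ai + db) for ai, al, db in zip(a, A_log, dt_bias)]


g = ssm_gate(a=[0.0, 2.0], A_log=[0.0, math.log(0.5)], dt_bias=[0.0, -2.0])
print([round(v, 3) for v in g])  # [-0.693, -0.347]
```

Because exp(A_log) > 0 and softplus is positive, g is negative, so exp(g) lies in (0, 1) and acts as a per-head decay on the recurrent state.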
class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn()

Bases: Function

All-gather + concat with autograd-safe backward.

The forward concatenates equal-sized local shards from all ranks along dim. Backward all-reduces the concatenated gradient across ranks, then slices out the local shard for the current rank.

nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn.backward(
ctx,
grad_output: torch.Tensor
)
staticmethod
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn.forward(
ctx,
local_tensor: torch.Tensor,
group: torch.distributed.ProcessGroup,
dim: int
)
staticmethod
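The forward/backward pair of _AllGatherConcatFn can be simulated without a process group. In this sketch each list entry plays one rank. Concatenation stands in for an all-gather (a real version would use torch.distributed.all_gather followed by torch.cat), and summing across ranks stands in for torch.distributed.all_reduce. The function names are hypothetical. The sum is needed because every rank's output contains this rank's shard, so that shard's gradient collects contributions from all ranks.

```python
def all_gather_concat_forward(local_shards):
    """Each rank ends up with the concatenation of all equal-sized shards (dim 0)."""
    if len({len(s) for s in local_shards}) != 1:
        raise ValueError("shards must be equal-sized")
    full = [x for shard in local_shards for x in shard]
    return [list(full) for _ in local_shards]  # one replica per rank


def all_gather_concat_backward(grad_outputs, rank):
    """Sum the full-size gradients over ranks (all-reduce), then keep this rank's slice."""
    world_size = len(grad_outputs)
    reduced = [sum(column) for column in zip(*grad_outputs)]
    shard_len = len(reduced) // world_size
    return reduced[rank * shard_len:(rank + 1) * shard_len]


outputs = all_gather_concat_forward([[1.0, 2.0], [3.0, 4.0]])
# Suppose rank 0's loss is sum(full) and rank 1's loss is sum(i * full[i]).
grads = [[1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 2.0, 3.0]]
print([all_gather_concat_backward(grads, r) for r in range(2)])  # [[1.0, 2.0], [3.0, 4.0]]
```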
class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._SSMGateParam(
name: str
)

Get-only (non-data) descriptor exposing an SSMGate param as an attribute.

Lets self.A_log / self.dt_bias resolve to the fp32 SSMGate holder (self._fp32_params) without a __getattr__ monkeypatch. Being a non-data descriptor, it does not intercept assignment, so HF’s __init__ doing self.A_log = nn.Parameter(...) still routes through nn.Module.__setattr__ into _parameters (where it lives until install_ssm_gate moves it).

nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._SSMGateParam.__get__(
obj,
owner = None
)
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._resolve_ssm_dtype(
config
)

Resolve the fp32 storage dtype for the SSM-gating params from config.

Honors mamba_ssm_dtype (Qwen3.5 stores A_log/dt_bias in fp32); defaults to torch.float32.

nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.install_ssm_gate(
mod,
fp32_dtype = torch.float32
)

Move mod's HF-created bare A_log/dt_bias into an fp32 SSMGate.

HF's GatedDeltaNet __init__ creates A_log/dt_bias as bare params in mod._parameters. This function relocates them into an SSMGate submodule registered as _fp32_params (casting them to fp32_dtype), so they keep fp32 storage under a bf16 bulk dtype and get their own dtype-uniform FSDP group. Attribute access (self.A_log/self.dt_bias) continues to work through the _SSMGateParam descriptors on CPAwareGatedDeltaNet, so no __getattr__ patch is needed. Returns the gate submodule.
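The descriptor-plus-relocation pattern used by _SSMGateParam and install_ssm_gate can be shown without PyTorch. In this sketch, Param and MiniModule are hypothetical stand-ins that mimic how nn.Module.__setattr__ routes parameters and submodules into _parameters and _modules. String dtype tags replace torch dtypes, and install_gate is a simplified, hypothetical analogue of install_ssm_gate. The fallback to _parameters before relocation is an assumption about how _SSMGateParam.__get__ behaves.

```python
class Param:
    """Hypothetical stand-in for nn.Parameter: a value plus a dtype tag."""
    def __init__(self, value, dtype="bf16"):
        self.value, self.dtype = value, dtype


class MiniModule:
    """Hypothetical stand-in for nn.Module's parameter/submodule registries."""
    def __init__(self):
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Like nn.Module.__setattr__: params and submodules go into registries.
        if isinstance(value, Param):
            self._parameters[name] = value
        elif isinstance(value, MiniModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when normal lookup (including descriptors) fails.
        for registry in (self._parameters, self._modules):
            if name in registry:
                return registry[name]
        raise AttributeError(name)


class GateParam:
    """Get-only (non-data) descriptor: defines __get__ but no __set__."""
    def __init__(self, name):
        self.name = name

    def __get__(self, obj, owner=None):
        if obj is None:
            return self
        holder = obj._modules.get("_fp32_params")
        # Assumed fallback: before relocation, read the bare param.
        registry = holder._parameters if holder is not None else obj._parameters
        if self.name in registry:
            return registry[self.name]
        raise AttributeError(self.name)  # keeps hasattr() well-behaved


class Layer(MiniModule):
    A_log = GateParam("A_log")
    dt_bias = GateParam("dt_bias")

    def __init__(self):
        super().__init__()
        # HF-style init: plain assignment lands in _parameters.
        self.A_log = Param([0.0, 0.5])
        self.dt_bias = Param([1.0, 1.0])


def install_gate(mod, dtype="fp32"):
    """Move bare params into an `_fp32_params` holder, re-tagging their dtype."""
    gate = MiniModule()
    for name in ("A_log", "dt_bias"):
        bare = mod._parameters.pop(name)
        setattr(gate, name, Param(bare.value, dtype))
    mod._fp32_params = gate
    return gate


layer = Layer()
print(layer.A_log.dtype)  # bf16 (still a bare param)
install_gate(layer)
print(layer.A_log.dtype)  # fp32 (resolved through _fp32_params)
```

GateParam never intercepts writes. The assignments in Layer.__init__ go through MiniModule.__setattr__ into _parameters, and reads resolve through the descriptor both before and after relocation.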

nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._FP32_PARAM_NAMES = ('A_log', 'dt_bias')