nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn
Context-Parallel-aware wrapper for Qwen3.5 MoE GatedDeltaNet linear attention.
When a CP mesh is attached (via apply_cp), the forward pass:
1. Recovers dense sequence order from PyTorch’s load-balanced CP layout using seq_index or position_ids.
2. Runs the causal conv1d and FLA gated delta rule on that dense ordering.
3. Restores the output back to the original load-balanced CP layout.
When no CP mesh is set, the module delegates to the original HF forward.
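The three CP steps can be illustrated without any distributed setup. The sketch below is a single-process, pure-Python simulation, not the module's code. It assumes the load-balanced layout gives rank r the chunk pair (r, 2*cp_size-1-r), and the helper names `load_balanced_positions`, `to_dense`, and `to_load_balanced` are hypothetical.

```python
def load_balanced_positions(seq_len, cp_size):
    """Global positions held by each rank (assumed head-tail chunk pairing)."""
    chunk = seq_len // (2 * cp_size)
    chunks = [list(range(i * chunk, (i + 1) * chunk)) for i in range(2 * cp_size)]
    return [chunks[r] + chunks[2 * cp_size - 1 - r] for r in range(cp_size)]


def to_dense(shards, positions):
    """Step 1: gather every rank's tokens and sort them by global position."""
    flat_tokens = [t for shard in shards for t in shard]
    flat_pos = [p for pos in positions for p in pos]
    order = sorted(range(len(flat_pos)), key=flat_pos.__getitem__)
    return [flat_tokens[i] for i in order]


def to_load_balanced(dense, positions):
    """Step 3: hand each rank the entries at the positions it originally held."""
    return [[dense[p] for p in pos] for pos in positions]


positions = load_balanced_positions(seq_len=8, cp_size=2)
tokens = [f"t{i}" for i in range(8)]
shards = to_load_balanced(tokens, positions)        # what each rank starts with
dense = to_dense(shards, positions)                 # step 1
processed = [t.upper() for t in dense]              # step 2: stand-in for conv1d + delta rule
restored = to_load_balanced(processed, positions)   # step 3
```

The point of the round trip is that a causal recurrence needs tokens in their true order, while the rest of the model expects the load-balanced layout it was given.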
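Module Contents
Classes
_AllGatherConcatFn | All-gather + concat with autograd-safe backward. |
CPAwareGatedDeltaNet | Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism. |
_Fp32ParamHolder | Holder for float32 params (A_log) that need a separate FSDP group. |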
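Functions
apply_model_runtime_patches | Apply Qwen3.5 runtime patches after model construction. |
_make_fp32_getattr | Create a __getattr__ that resolves fp32 params from _fp32_params. |
patch_hf_model | Patch HF Qwen3.5 GatedDeltaNet modules for FSDP and optional CP support. |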
API
- nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.apply_model_runtime_patches(model, mesh=None)
Apply Qwen3.5 runtime patches after model construction.
The GatedDeltaNet wrapper is needed for both distributed training and single-GPU packed-sequence runs, so it must run before sharding or first forward rather than only from the FSDP parallelization strategy.
- class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn
Bases: torch.autograd.Function
All-gather + concat with autograd-safe backward.
The forward concatenates equal-sized local shards from all ranks along dim. Backward all-reduces the concatenated gradient across ranks, then slices out the local shard for the current rank.
- static forward(ctx, local_tensor: torch.Tensor, group: torch.distributed.ProcessGroup, dim: int)
- static backward(ctx, grad_output: torch.Tensor)
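The backward rule described for `_AllGatherConcatFn` follows from the fact that every rank consumes the full concatenated tensor, so the gradient for one shard is the sum of what all ranks contribute to that slice. Below is a minimal single-process sketch with plain lists standing in for tensors and ranks. The function names are hypothetical and no `torch.distributed` calls are made.

```python
def all_gather_concat(shards):
    """Forward: every rank ends up with the same concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def all_gather_concat_backward(grad_outputs, shard_len):
    """Backward: sum the full-size gradients over ranks, then slice per rank."""
    reduced = [sum(vals) for vals in zip(*grad_outputs)]  # all-reduce (sum)
    return [
        reduced[r * shard_len:(r + 1) * shard_len]
        for r in range(len(grad_outputs))
    ]


shards = [[1.0, 2.0], [3.0, 4.0]]            # rank 0, rank 1
gathered = all_gather_concat(shards)

# Each rank holds its own gradient w.r.t. the full gathered tensor.
grad_outputs = [[1.0, 2.0, 3.0, 4.0],        # from rank 0
                [10.0, 20.0, 30.0, 40.0]]    # from rank 1
local_grads = all_gather_concat_backward(grad_outputs, shard_len=2)
```

Slicing without the reduction would drop the gradient that other ranks computed for the local shard, which is the failure mode the "autograd-safe" wording refers to.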
- class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet(config, layer_idx: int)
Bases: transformers.models.qwen3_5_moe.modeling_qwen3_5_moe.Qwen3_5MoeGatedDeltaNet
Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism.
All __init__ parameters and weights are inherited unchanged from the HF class. The only addition is _cp_mesh, which is set externally by apply_cp in the parallelizer.
Initialization
- _cp_mesh: torch.distributed.device_mesh.DeviceMesh | None
None
- _compute_gate(a: torch.Tensor) -> torch.Tensor
Compute the gating value g using fp32 params.
When _fp32_params exists (FSDP mixed-dtype), delegates to the holder’s forward so that the FSDP unshard/reshard lifecycle runs naturally around the computation. Otherwise falls back to the inline computation.
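For orientation, here is a scalar sketch of what a gate computation of this kind looks like. It assumes the gate has the form g = -exp(A_log) * softplus(a + dt_bias), as in HF's GatedDeltaNet-style layers. That formula is an assumption rather than something this page states, and `compute_gate` is a hypothetical stand-in for the real method.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def compute_gate(a, A_log, dt_bias):
    """Assumed gate form: g = -exp(A_log) * softplus(a + dt_bias), per element of a."""
    return [-math.exp(A_log) * softplus(x + dt_bias) for x in a]


gates = compute_gate([-2.0, 0.0, 3.0], A_log=0.0, dt_bias=0.0)
```

Under this assumed form the gate is always negative, so exp(g) acts as a decay factor in (0, 1). Keeping A_log in float32 protects the exponential from low-precision rounding.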
- _forward_no_cp(hidden_states: torch.Tensor, cache_params=None, cache_position=None, attention_mask: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None, indices: torch.Tensor | None = None)
HF GatedDeltaNet forward with FSDP-safe fp32 gate computation.
Mirrors transformers==5.5 Qwen3_5GatedDeltaNet.forward (per-layer cache API; gate via self._compute_gate(a)) and adds packing-aware plumbing:
cu_seqlens – per-document cumulative lengths from the indexed attention mask. When supplied, FLA’s chunk kernel resets state at every document boundary.
indices – non-padding token indices. When supplied AND padding is actually present (B>1 case), the layer unpads activations to [1, total_valid, ...] before conv/FLA and re-pads on the way out. For B=1 with no padding, indices covers the whole sequence and unpadding is skipped (preserves the bit-exact fast path).
Both kwargs are produced by Qwen3_5DecoderLayerWithPacking. As a safety net for direct callers (e.g. unit tests that bypass the decoder-layer subclass), the layer derives them from attention_mask when both are None and the mask is indexed.
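To make the two packing kwargs concrete, the sketch below derives them from a small indexed mask in pure Python. It assumes an indexed mask uses 0 for padding and 1, 2, ... for successive documents within a row. The function name `packing_metadata` is hypothetical and this is not the layer's own derivation code.

```python
def packing_metadata(indexed_mask):
    """Return (cu_seqlens, indices) for a [B, S] indexed mask given as nested lists."""
    width = len(indexed_mask[0])
    indices, cu_seqlens, prev_key = [], [0], None
    for b, row in enumerate(indexed_mask):
        for s, doc_id in enumerate(row):
            if doc_id == 0:                      # padding token: dropped when unpadding
                continue
            key = (b, doc_id)
            if prev_key is not None and key != prev_key:
                cu_seqlens.append(len(indices))  # document boundary
            indices.append(b * width + s)        # index into the flattened [B*S] layout
            prev_key = key
    cu_seqlens.append(len(indices))
    return cu_seqlens, indices


mask = [
    [1, 1, 1, 2, 2, 0],   # two documents, one padding token
    [1, 1, 0, 0, 0, 0],   # one document, four padding tokens
]
cu_seqlens, indices = packing_metadata(mask)
```

Here `len(indices)` is the `total_valid` length of the unpadded `[1, total_valid, ...]` activations, and consecutive `cu_seqlens` entries bracket each document so that recurrent state can be reset at every boundary.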
- forward(hidden_states: torch.Tensor, cache_params=None, cache_position=None, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, qkv_format: str | None = None, cu_seqlens: torch.Tensor | None = None, indices: torch.Tensor | None = None, seq_index: torch.Tensor | None = None)
- _conv1d_with_cp(mixed_qkv: torch.Tensor, cp_context) -> torch.Tensor
Run causal conv1d via FLA’s CP-aware conv implementation.
- Parameters:
mixed_qkv – [B, D, S_local] tensor (channels-first for conv).
cp_context – FLA CP context built by build_cp_context.
- Returns:
[B, D, S_local] conv output with correct boundary handling.
- _extract_local_positions(position_ids: torch.Tensor | None, seq_index: torch.Tensor | None, seq_len: int)
- _all_gather_concat(tensor: torch.Tensor, cp_group: torch.distributed.ProcessGroup, *, dim: int, differentiable: bool = False)
- _undo_attention_load_balancing(hidden_states: torch.Tensor, original_positions: torch.Tensor, cp_group: torch.distributed.ProcessGroup)
- _redo_attention_load_balancing(output: torch.Tensor, original_positions: torch.Tensor, sorted_positions: torch.Tensor, cp_group: torch.distributed.ProcessGroup)
- _forward_with_cp(hidden_states: torch.Tensor, *, position_ids: torch.Tensor | None, seq_index: torch.Tensor | None)
- class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._Fp32ParamHolder
Bases: torch.nn.Module
Holder for float32 params (A_log) that need a separate FSDP group.
The forward computes the gating value g that HF’s Qwen3_5GatedDeltaNet.forward would normally compute inline. By doing the computation inside this module’s forward, FSDP’s unshard/reshard lifecycle works naturally: the params are unsharded during the computation and resharded after.
- forward(a: torch.Tensor, dt_bias: torch.Tensor) -> torch.Tensor
- nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._make_fp32_getattr(orig_getattr)
Create a __getattr__ that resolves fp32 params from _fp32_params.
Allows self.A_log to resolve from the holder submodule so that code outside forward (e.g. state_dict, checkpointing) can still access the parameter by name.
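The fallback pattern can be shown with plain Python objects. The sketch below is illustrative only: `Holder`, `Layer`, and `make_fallback_getattr` are hypothetical names, and a real torch.nn.Module stores submodules in `_modules` rather than the instance `__dict__`, so the actual lookup differs in detail.

```python
class Holder:
    """Stand-in for the fp32 holder submodule."""

    def __init__(self, **params):
        self.__dict__.update(params)


def make_fallback_getattr(orig_getattr, holder_name="_fp32_params"):
    """Build a __getattr__ that consults the holder before the original hook."""

    def __getattr__(self, name):
        # __getattr__ only runs after normal attribute lookup has failed.
        holder = self.__dict__.get(holder_name)
        if holder is not None and name in vars(holder):
            return getattr(holder, name)
        if orig_getattr is not None:
            return orig_getattr(self, name)
        raise AttributeError(name)

    return __getattr__


class Layer:
    pass


layer = Layer()
layer._fp32_params = Holder(A_log=0.5)   # parameter now lives on the holder
Layer.__getattr__ = make_fallback_getattr(getattr(Layer, "__getattr__", None))

value = layer.A_log                       # resolved through the holder
```

Because the hook is installed on the class and chains to the original `__getattr__`, attributes that never moved keep resolving exactly as before.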
- nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.patch_hf_model(model, cp_enabled=False)
Patch HF Qwen3.5 GatedDeltaNet modules for FSDP and optional CP support.
For FSDP compatibility, move float32 bare params (A_log) into a _fp32_params submodule so fully_shard_by_dtype can wrap them in a separate FSDP group.
Every Qwen3_5GatedDeltaNet instance’s __class__ is swapped to CPAwareGatedDeltaNet, whose forward() calls self._fp32_params() to trigger FSDP unshard before accessing the fp32 params. When cp_enabled=True, the CP mesh is also configured.
Additionally, every Qwen3_5DecoderLayer instance is class-swapped to Qwen3_5DecoderLayerWithPacking so that NEAT-packed sequence metadata (cu_seqlens, indices, position_ids) reaches linear_attn via real keyword arguments instead of relying on instance-attribute side-channels (issue #2131).
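Class-swapping an existing instance keeps its state (weights, buffers, attributes) and changes only method resolution. A minimal sketch of the idea with toy classes follows. Every name in it is hypothetical, and the real patch also relocates fp32 params, which is not shown.

```python
class GatedDeltaNet:
    """Toy stand-in for the original HF layer."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale


class CPAwareSketch(GatedDeltaNet):
    """Toy subclass: adds one attribute and overrides forward."""

    _cp_mesh = None

    def forward(self, x):
        if self._cp_mesh is None:
            return super().forward(x)        # no CP mesh: original behaviour
        return ("cp", super().forward(x))    # placeholder for the CP path


class Model:
    def __init__(self):
        self.layers = [GatedDeltaNet(2), GatedDeltaNet(3)]


def patch(model, cp_mesh=None):
    """Swap each matching layer's class in place, without re-running __init__."""
    for layer in model.layers:
        if type(layer) is GatedDeltaNet:
            layer.__class__ = CPAwareSketch
            layer._cp_mesh = cp_mesh
    return model


model = patch(Model())
out = model.layers[1].forward(4)
```

Since `__init__` is not called again, the swap is only safe when the subclass adds no required constructor state. That matches the statement that `CPAwareGatedDeltaNet` inherits all `__init__` parameters and weights unchanged.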