nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn#
Context-Parallel-aware wrapper for Qwen3.5 MoE GatedDeltaNet linear attention.
When a CP mesh is attached (via apply_cp), the forward pass:

1. Recovers dense sequence order from PyTorch's load-balanced CP layout using seq_index or position_ids.
2. Runs the causal conv1d and FLA gated delta rule on that dense ordering.
3. Restores the output back to the original load-balanced CP layout.

When no CP mesh is set, the module delegates to the original HF forward.
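The load-balanced layout referred to above can be sketched as follows, assuming PyTorch's mirrored-chunk balancing for causal attention: the sequence is cut into 2 × world_size chunks and rank r holds chunk r plus its mirror. The function name is illustrative, not part of the module's API.

```python
def load_balanced_positions(seq_len: int, world_size: int, rank: int) -> list[int]:
    # The sequence is split into 2 * world_size equal chunks; rank r keeps
    # chunk r and the mirror chunk 2*world_size - 1 - r, so every rank holds
    # an equal share of "early" (cheap) and "late" (expensive) causal positions.
    chunk = seq_len // (2 * world_size)
    mirror = 2 * world_size - 1 - rank
    first = list(range(rank * chunk, (rank + 1) * chunk))
    second = list(range(mirror * chunk, (mirror + 1) * chunk))
    return first + second

# world_size=2, seq_len=8: rank 0 holds [0, 1, 6, 7], rank 1 holds [2, 3, 4, 5]
r0 = load_balanced_positions(8, 2, 0)
r1 = load_balanced_positions(8, 2, 1)
```

Step 1 above is exactly the inverse of this assignment, and step 3 reapplies it.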
Module Contents#
Classes#
| Class | Description |
| --- | --- |
| _AllGatherConcatFn | All-gather + concat with autograd-safe backward. |
| CPAwareGatedDeltaNet | Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism. |
| _Fp32ParamHolder | Holder for float32 params (A_log) that need a separate FSDP group. |
Functions#
| Function | Description |
| --- | --- |
| _make_fp32_getattr | Create a __getattr__ that resolves fp32 params from _fp32_params. |
| patch_hf_model | Patch HF Qwen3.5 GatedDeltaNet modules for FSDP and optional CP support. |
API#
- class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn#
Bases: torch.autograd.Function

All-gather + concat with autograd-safe backward.

The forward concatenates equal-sized local shards from all ranks along dim. Backward all-reduces the concatenated gradient across ranks, then slices out the local shard for the current rank.

- static forward(
- ctx,
- local_tensor: torch.Tensor,
- group: torch.distributed.ProcessGroup,
- dim: int,
- )#
- static backward(ctx, grad_output: torch.Tensor)#
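The forward/backward contract can be sketched single-process, with the collectives replaced by operations over a list of simulated per-rank shards (class name and shapes are illustrative; the real module uses torch.distributed all-gather and all-reduce):

```python
import torch

class AllGatherConcatSim(torch.autograd.Function):
    """Single-process stand-in for an autograd-safe all-gather + concat."""

    @staticmethod
    def forward(ctx, dim, *shards):
        # Simulate the all-gather: concat equal-sized per-rank shards along dim.
        ctx.dim = dim
        ctx.shard_size = shards[0].shape[dim]
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # In the real module the gradient is all-reduced (summed) across ranks,
        # then each rank slices out its own shard. With one simulated process
        # the sum is a no-op, so each "rank" just takes its chunk.
        grads = torch.split(grad_output, ctx.shard_size, dim=ctx.dim)
        return (None, *grads)  # no gradient for the `dim` argument

a = torch.ones(2, 3, requires_grad=True)
b = torch.ones(2, 3, requires_grad=True)
out = AllGatherConcatSim.apply(1, a, b)  # shape (2, 6)
out.sum().backward()                     # each shard gets its slice of the grad
```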
- class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet(config, layer_idx: int)#
Bases:
transformers.models.qwen3_5_moe.modeling_qwen3_5_moe.Qwen3_5MoeGatedDeltaNetDrop-in replacement for
Qwen3_5MoeGatedDeltaNetwith FLA Context Parallelism.All
__init__parameters and weights are inherited unchanged from the HF class. The only addition is_cp_meshwhich is set externally byapply_cpin the parallelizer.Initialization
- _cp_mesh: torch.distributed.device_mesh.DeviceMesh | None#
None
- _compute_gate(a: torch.Tensor) → torch.Tensor#
Compute the gating value g using fp32 params.

When _fp32_params exists (FSDP mixed-dtype), delegates to the holder's forward so the FSDP unshard/reshard lifecycle is natural. Otherwise falls back to the inline computation.
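The gate itself can be sketched as the usual Mamba2-style decay, g = -exp(A_log) · softplus(a + dt_bias), computed entirely in float32. This is an assumption about the HF internals, and the helper name is illustrative:

```python
import torch
import torch.nn.functional as F

def compute_gate_fp32(a: torch.Tensor, A_log: torch.Tensor, dt_bias: torch.Tensor) -> torch.Tensor:
    # Cast everything to float32 before exp/softplus: A_log is a numerically
    # sensitive decay parameter, which is why it lives in a separate fp32
    # holder under mixed-precision FSDP.
    return -A_log.float().exp() * F.softplus(a.float() + dt_bias.float())

a = torch.randn(2, 8, 4, dtype=torch.bfloat16)  # bf16 activations
A_log = torch.zeros(4)                          # fp32 params from the holder
dt_bias = torch.zeros(4)
g = compute_gate_fp32(a, A_log, dt_bias)        # fp32 gate, strictly negative
```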
- _forward_no_cp(
- hidden_states: torch.Tensor,
- cache_params=None,
- cache_position=None,
- attention_mask: torch.Tensor | None = None,
- )#
HF GatedDeltaNet forward with FSDP-safe fp32 gate computation.
Mirrors transformers==5.5 Qwen3_5GatedDeltaNet.forward (uses the per-layer cache API: has_previous_state(layer_idx), cache_params.layers[layer_idx].{conv,recurrent}_states, and the update_{conv,recurrent}_state methods) with the gate computation replaced by self._compute_gate(a).
- forward(
- hidden_states: torch.Tensor,
- cache_params=None,
- cache_position=None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.Tensor | None = None,
- qkv_format: str | None = None,
- cu_seqlens: torch.Tensor | None = None,
- seq_index: torch.Tensor | None = None,
- )#
- _conv1d_with_cp(mixed_qkv: torch.Tensor, cp_context) torch.Tensor#
Run causal conv1d via FLA’s CP-aware conv implementation.
- Parameters:
mixed_qkv – [B, D, S_local] tensor (channels-first for conv).
cp_context – FLA CP context built by build_cp_context.
- Returns:
[B, D, S_local] conv output with correct boundary handling.
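The real method delegates to FLA's CP-aware conv (which handles cross-rank boundaries). Ignoring that halo exchange, the underlying channels-first causal depthwise conv can be sketched as (helper name illustrative):

```python
import torch
import torch.nn.functional as F

def causal_depthwise_conv1d(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x: [B, D, S] channels-first, weight: [D, 1, K] depthwise kernel.
    # Left-padding by K-1 makes the conv causal: output at position t
    # only depends on inputs at positions <= t.
    K = weight.shape[-1]
    return F.conv1d(F.pad(x, (K - 1, 0)), weight, groups=x.shape[1])

x = torch.randn(1, 2, 5)   # [B, D, S_local]
w = torch.randn(2, 1, 3)   # depthwise kernel, K = 3
y = causal_depthwise_conv1d(x, w)

# Causality check: perturbing the last position must not change earlier outputs.
x2 = x.clone()
x2[..., -1] += 1.0
y2 = causal_depthwise_conv1d(x2, w)
```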
- _extract_local_positions(
- position_ids: torch.Tensor | None,
- seq_index: torch.Tensor | None,
- seq_len: int,
- )#
- _all_gather_concat(
- tensor: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- *,
- dim: int,
- differentiable: bool = False,
- )#
- _undo_attention_load_balancing(
- hidden_states: torch.Tensor,
- original_positions: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- )#
- _redo_attention_load_balancing(
- output: torch.Tensor,
- original_positions: torch.Tensor,
- sorted_positions: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- )#
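A single-process sketch of the undo/redo pair, with the all-gather simulated by a concat over per-rank shards and scalar per-token values standing in for [B, S_local, D] activations (function names are illustrative):

```python
import torch

def undo_load_balancing(shards, positions):
    # "All-gather" the local shards and their global token positions, then
    # argsort the positions to recover the dense causal order.
    gathered = torch.cat(shards)
    order = torch.argsort(torch.cat(positions))
    return gathered[order], order

def redo_load_balancing(dense, order, rank, s_local):
    # Inverse permutation: scatter the dense result back into the
    # load-balanced layout, then keep only this rank's local slice.
    balanced = torch.empty_like(dense)
    balanced[order] = dense
    return balanced[rank * s_local:(rank + 1) * s_local]

# world_size = 2, seq_len = 8: rank 0 holds positions [0,1,6,7], rank 1 [2,3,4,5]
shards = [torch.tensor([10., 11., 16., 17.]), torch.tensor([12., 13., 14., 15.])]
positions = [torch.tensor([0, 1, 6, 7]), torch.tensor([2, 3, 4, 5])]
dense, order = undo_load_balancing(shards, positions)
local0 = redo_load_balancing(dense, order, rank=0, s_local=4)
```

The round trip is lossless: running an op on `dense` and redoing the balancing gives each rank exactly the slice it would have produced locally.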
- _forward_with_cp(
- hidden_states: torch.Tensor,
- *,
- position_ids: torch.Tensor | None,
- seq_index: torch.Tensor | None,
- )#
- class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._Fp32ParamHolder#
Bases: torch.nn.Module

Holder for float32 params (A_log) that need a separate FSDP group.

The forward computes the gating value g that HF's Qwen3_5GatedDeltaNet.forward would normally compute inline. By doing the computation inside this module's forward, FSDP's unshard/reshard lifecycle works naturally: the params are unsharded during the computation and resharded after.

- forward(a: torch.Tensor, dt_bias: torch.Tensor) → torch.Tensor#
- nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._make_fp32_getattr(orig_getattr)#
Create a __getattr__ that resolves fp32 params from _fp32_params.

Allows self.A_log to resolve from the holder submodule so that code outside forward (e.g. state_dict, checkpointing) can still access the parameter by name.
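A minimal sketch of this pattern, assuming the holder is registered as a `_fp32_params` submodule (the `Holder`/`Layer` classes here are toy stand-ins, not the module's actual classes):

```python
import torch
import torch.nn as nn

def make_fp32_getattr(orig_getattr):
    # Wrap a module class's __getattr__: if normal resolution fails,
    # retry the name on the _fp32_params holder submodule.
    def __getattr__(self, name):
        try:
            return orig_getattr(self, name)
        except AttributeError:
            holder = orig_getattr(self, "_fp32_params")
            return getattr(holder, name)
    return __getattr__

class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.A_log = nn.Parameter(torch.zeros(4, dtype=torch.float32))

class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self._fp32_params = Holder()

Layer.__getattr__ = make_fp32_getattr(nn.Module.__getattr__)
layer = Layer()
# layer.A_log now resolves through the holder, so state_dict/checkpoint code
# that reads the param by name keeps working after the relocation.
```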
- nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.patch_hf_model(model, cp_enabled=False)#
Patch HF Qwen3.5 GatedDeltaNet modules for FSDP and optional CP support.
For FSDP compatibility, moves float32 bare params (A_log) into a _fp32_params submodule so fully_shard_by_dtype can wrap them in a separate FSDP group.

Every module's __class__ is swapped to CPAwareGatedDeltaNet, whose forward() calls self._fp32_params() to trigger FSDP unshard before accessing the fp32 params. When cp_enabled=True, the CP mesh is also configured.
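The parameter-relocation half of the patch can be sketched as follows; `Fp32Holder` and `move_fp32_params` are illustrative stand-ins, and the class swap and `__getattr__` install are noted in comments rather than implemented:

```python
import torch
import torch.nn as nn

class Fp32Holder(nn.Module):
    # Stand-in for _Fp32ParamHolder: owns the relocated fp32 param.
    pass

def move_fp32_params(module: nn.Module) -> None:
    # Relocate the bare float32 A_log parameter into a _fp32_params
    # submodule so a dtype-aware fully_shard can place it in its own
    # FSDP group. (The real patch additionally swaps module.__class__
    # to CPAwareGatedDeltaNet and installs the fp32-aware __getattr__.)
    holder = Fp32Holder()
    holder.A_log = module._parameters.pop("A_log")
    module._fp32_params = holder

layer = nn.Module()
layer.A_log = nn.Parameter(torch.zeros(3, dtype=torch.float32))
move_fp32_params(layer)
# A_log is no longer a direct param of `layer`; it now lives under
# "_fp32_params.A_log" in the parameter namespace.
```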