nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn#

Context-Parallel-aware wrapper for Qwen3.5 MoE GatedDeltaNet linear attention.

When a CP mesh is attached (via apply_cp), the forward pass:

  1. Recovers dense sequence order from PyTorch’s load-balanced CP layout using seq_index or position_ids.

  2. Runs the causal conv1d and FLA gated delta rule on that dense ordering.

  3. Restores the output back to the original load-balanced CP layout.

When no CP mesh is set, the module delegates to the original HF forward.
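The layout that steps 1–3 undo and redo can be sketched in plain Python. This is a hypothetical helper, assuming the standard round-robin load-balanced scheme used by PyTorch's context-parallel attention: the sequence is split into 2 * cp_size equal chunks and rank r holds chunks r and (2 * cp_size - 1 - r), which balances causal-attention cost across ranks. Linear attention needs the dense 0..S-1 ordering instead, hence the undo/redo.

```python
# Hypothetical sketch of the load-balanced CP layout (not the library's API).
# Assumes the round-robin scheme: rank r holds chunks r and 2*cp_size-1-r.

def load_balanced_positions(seq_len: int, cp_size: int, rank: int) -> list[int]:
    """Global token positions held locally by `rank`."""
    chunk = seq_len // (2 * cp_size)
    pair = 2 * cp_size - 1 - rank
    return (list(range(rank * chunk, (rank + 1) * chunk))
            + list(range(pair * chunk, (pair + 1) * chunk)))
```

With seq_len=8 and cp_size=2, rank 0 holds positions [0, 1, 6, 7] and rank 1 holds [2, 3, 4, 5]; sorting the all-gathered positions recovers the dense order used in step 1.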

Module Contents#

Classes#

_AllGatherConcatFn

All-gather + concat with autograd-safe backward.

CPAwareGatedDeltaNet

Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism.

_Fp32ParamHolder

Holder for float32 params (A_log) that need a separate FSDP group.

Functions#

_make_fp32_getattr

Create a __getattr__ that resolves fp32 params from _fp32_params.

patch_hf_model

Patch HF Qwen3.5 GatedDeltaNet modules for FSDP and optional CP support.

API#

class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn#

Bases: torch.autograd.Function

All-gather + concat with autograd-safe backward.

The forward concatenates equal-sized local shards from all ranks along dim. Backward all-reduces the concatenated gradient across ranks, then slices out the local shard for the current rank.
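The gradient rule can be illustrated on a single process by modeling each rank as a list entry. These are hypothetical helpers sketching only the math; the real implementation uses dist.all_gather in forward and dist.all_reduce in backward on actual tensors.

```python
# Single-process model of the _AllGatherConcatFn math (hypothetical helpers).
# Each element of `local_shards` / `grads_per_rank` stands for one CP rank.

def forward_concat(local_shards):
    # Forward: all-gather equal-sized shards and concatenate; every rank
    # ends up with the same full sequence.
    return [x for shard in local_shards for x in shard]

def backward_local_grad(rank, grads_per_rank, shard_len):
    # Backward: sum (all-reduce) the full gradients from all ranks, then
    # slice out the shard this rank contributed in forward.
    reduced = [sum(cols) for cols in zip(*grads_per_rank)]
    return reduced[rank * shard_len:(rank + 1) * shard_len]
```

Summing before slicing is what makes the backward correct: every rank's copy of the concatenated output feeds downstream computation, so the gradient of the local shard is the sum over ranks of the gradient slice at this rank's offset.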

static forward(
ctx,
local_tensor: torch.Tensor,
group: torch.distributed.ProcessGroup,
dim: int,
)#
static backward(ctx, grad_output: torch.Tensor)#
class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet(config, layer_idx: int)#

Bases: transformers.models.qwen3_5_moe.modeling_qwen3_5_moe.Qwen3_5MoeGatedDeltaNet

Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism.

All __init__ parameters and weights are inherited unchanged from the HF class. The only addition is _cp_mesh, which is set externally by apply_cp in the parallelizer.

Initialization

_cp_mesh: torch.distributed.device_mesh.DeviceMesh | None#

None

_compute_gate(a: torch.Tensor) torch.Tensor#

Compute the gating value g using fp32 params.

When _fp32_params exists (FSDP mixed-dtype), delegates to the holder’s forward so that FSDP’s unshard/reshard lifecycle works naturally. Otherwise it falls back to the inline computation.
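In HF's GatedDeltaNet the gate is, roughly, g = -exp(A_log) * softplus(a + dt_bias), computed in float32. A scalar sketch of that math (an illustration of the formula under that assumption, not the module's tensor code):

```python
import math

def softplus(x: float) -> float:
    # Numerically stable softplus(x) = log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def compute_gate(a: float, a_log: float, dt_bias: float) -> float:
    # Assumed gate formula: g = -exp(A_log) * softplus(a + dt_bias).
    # g is always <= 0, so exp(g) acts as a decay factor in (0, 1]
    # for the gated delta rule.
    return -math.exp(a_log) * softplus(a + dt_bias)
```

Doing this in float32 matters because exp/softplus on the small learned A_log and dt_bias parameters is precision-sensitive, which is why these params are kept out of the low-precision FSDP group.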

_forward_no_cp(
hidden_states: torch.Tensor,
cache_params=None,
cache_position=None,
attention_mask: torch.Tensor | None = None,
)#

HF GatedDeltaNet forward with FSDP-safe fp32 gate computation.

Mirrors the transformers==5.5 Qwen3_5MoeGatedDeltaNet.forward (uses the per-layer cache API: has_previous_state(layer_idx), cache_params.layers[layer_idx].{conv,recurrent}_states, and the update_{conv,recurrent}_state methods), with the gate computation replaced by self._compute_gate(a).

forward(
hidden_states: torch.Tensor,
cache_params=None,
cache_position=None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
qkv_format: str | None = None,
cu_seqlens: torch.Tensor | None = None,
seq_index: torch.Tensor | None = None,
)#
_conv1d_with_cp(mixed_qkv: torch.Tensor, cp_context) torch.Tensor#

Run causal conv1d via FLA’s CP-aware conv implementation.

Parameters:
  • mixed_qkv – [B, D, S_local] tensor (channels-first for conv).

  • cp_context – FLA CP context built by build_cp_context.

Returns:

[B, D, S_local] conv output with correct boundary handling.
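Why the conv needs CP-aware boundary handling: a causal conv at position t reads positions t-K+1..t, so the first K-1 outputs of each local chunk depend on the tail of the preceding chunk, which may live on another rank. A single-channel plain-Python sketch (hypothetical helper, not FLA's API):

```python
def causal_conv1d(x, weights, left_context=None):
    # Causal conv: output[t] depends on x[t-K+1..t]. `left_context` supplies
    # the K-1 trailing elements of the previous chunk; without CP it is
    # implicit zero padding at the sequence start.
    k = len(weights)
    pad = left_context if left_context is not None else [0.0] * (k - 1)
    padded = pad[len(pad) - (k - 1):] + list(x)
    return [sum(w * padded[t + i] for i, w in enumerate(weights))
            for t in range(len(x))]
```

Running the conv chunk-by-chunk with the correct left context reproduces the full-sequence result exactly, which is the boundary handling FLA's CP-aware conv arranges across ranks.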

_extract_local_positions(
position_ids: torch.Tensor | None,
seq_index: torch.Tensor | None,
seq_len: int,
) torch.Tensor | None#
_all_gather_concat(
tensor: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
*,
dim: int,
differentiable: bool = False,
) torch.Tensor#
_undo_attention_load_balancing(
hidden_states: torch.Tensor,
original_positions: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
) tuple[torch.Tensor, torch.Tensor]#
_redo_attention_load_balancing(
output: torch.Tensor,
original_positions: torch.Tensor,
sorted_positions: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
) torch.Tensor#
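The undo/redo pair can be sketched on one process by standing a plain list in for the all-gather. The names are hypothetical and loosely mirror the method arguments; the real methods operate on tensors gathered across cp_group.

```python
# Hypothetical single-process sketch of the undo/redo pair.

def undo_load_balancing(gathered, original_positions):
    # Sort the all-gathered sequence into dense 0..S-1 order; also return
    # the permutation needed to invert the sort later.
    order = sorted(range(len(original_positions)),
                   key=lambda i: original_positions[i])
    sorted_positions = [original_positions[i] for i in order]
    return [gathered[i] for i in order], sorted_positions, order

def redo_load_balancing(dense_out, order, rank, local_len):
    # Scatter the dense output back to the gathered (rank-ordered) slots,
    # then keep only this rank's local slice.
    scattered = [None] * len(dense_out)
    for j, slot in enumerate(order):
        scattered[slot] = dense_out[j]
    return scattered[rank * local_len:(rank + 1) * local_len]
```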
_forward_with_cp(
hidden_states: torch.Tensor,
*,
position_ids: torch.Tensor | None,
seq_index: torch.Tensor | None,
) torch.Tensor#
class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._Fp32ParamHolder#

Bases: torch.nn.Module

Holder for float32 params (A_log) that need a separate FSDP group.

The forward computes the gating value g that HF’s Qwen3_5MoeGatedDeltaNet.forward would normally compute inline. By doing the computation inside this module’s forward, FSDP’s unshard/reshard lifecycle works naturally: the params are unsharded during the computation and resharded afterwards.

forward(a: torch.Tensor, dt_bias: torch.Tensor) torch.Tensor#
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._make_fp32_getattr(orig_getattr)#

Create a __getattr__ that resolves fp32 params from _fp32_params.

Allows self.A_log to resolve from the holder submodule so that code outside forward (e.g. state_dict, checkpointing) can still access the parameter by name.
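The delegation pattern can be sketched with plain classes (hypothetical stand-ins for the nn.Module machinery; the real patch wraps the module's original __getattr__):

```python
class Fp32Holder:
    # Stand-in for _Fp32ParamHolder; A_log plays the fp32 parameter.
    def __init__(self):
        self.A_log = [0.0, 0.0]

class Layer:
    def __init__(self):
        self._fp32_params = Fp32Holder()

    def __getattr__(self, name):
        # Invoked only when normal attribute lookup fails: resolve the name
        # from the fp32 holder so `self.A_log` keeps working for
        # state_dict-style code that accesses the param by its old name.
        holder = self.__dict__["_fp32_params"]
        if hasattr(holder, name):
            return getattr(holder, name)
        raise AttributeError(name)
```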

nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.patch_hf_model(model, cp_enabled=False)#

Patch HF Qwen3.5 GatedDeltaNet modules for FSDP and optional CP support.

For FSDP compatibility, bare float32 params (A_log) are moved into a _fp32_params submodule so that fully_shard_by_dtype can wrap them in a separate FSDP group.

Each GatedDeltaNet module’s __class__ is swapped to CPAwareGatedDeltaNet, whose forward() calls self._fp32_params() to trigger FSDP unshard before accessing the fp32 params. When cp_enabled=True, the CP mesh is also configured.
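The class-swap pattern, reduced to plain Python (hypothetical classes; patch_hf_model does the equivalent for each GatedDeltaNet instance in the model):

```python
class HFGatedDeltaNet:
    # Stand-in for the original HF module class.
    def forward(self):
        return "original"

class CPAware(HFGatedDeltaNet):
    # Stand-in for CPAwareGatedDeltaNet: same state, new behavior.
    def forward(self):
        return "patched"

def patch(module, cp_enabled=False):
    # Swapping __class__ re-binds the methods while leaving the instance's
    # existing attributes (weights, buffers) untouched.
    module.__class__ = CPAware
    module._cp_mesh = None  # configured later by apply_cp when cp_enabled
    return module
```

Because only __class__ changes, checkpoints saved before and after patching remain compatible: the parameters themselves are never copied or renamed by the swap.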