nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn#

Context-Parallel-aware wrapper for Qwen3.5 MoE GatedDeltaNet linear attention.

When a CP mesh is attached (via apply_cp), the forward pass:

  1. Recovers dense sequence order from PyTorch’s load-balanced CP layout using seq_index or position_ids.

  2. Runs the causal conv1d and FLA gated delta rule on that dense ordering.

  3. Restores the output back to the original load-balanced CP layout.

When no CP mesh is set, the module delegates to the original HF forward.
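The recovery in steps 1 and 3 hinges on PyTorch's load-balanced CP layout: the sequence is split into 2·W chunks and rank r holds chunks r and 2·W−1−r, so causal-attention work is balanced across ranks. A single-process sketch of that layout and of recovering dense order (illustrative only; the real module all-gathers position ids across the CP group):

```python
import torch

def load_balanced_positions(seq_len: int, rank: int, world: int) -> torch.Tensor:
    # PyTorch's load-balanced CP layout: split the sequence into
    # 2 * world chunks; rank r holds chunks r and (2 * world - 1 - r).
    chunks = torch.arange(seq_len).chunk(2 * world)
    return torch.cat([chunks[rank], chunks[2 * world - 1 - rank]])

# Step 1 in miniature: gather every rank's positions, then argsort to
# recover the dense sequence order (step 3 inverts this permutation).
world = 2
gathered = torch.cat([load_balanced_positions(8, r, world) for r in range(world)])
dense_order = torch.argsort(gathered)
```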

Module Contents#

Classes#

_AllGatherConcatFn

All-gather + concat with autograd-safe backward.

CPAwareGatedDeltaNet

Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism.

API#

class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn#

Bases: torch.autograd.Function

All-gather + concat with autograd-safe backward.

The forward pass concatenates equal-sized local shards from all ranks along dim. The backward pass all-reduces the concatenated gradient across ranks, then slices out the shard belonging to the current rank.

static forward(
ctx,
local_tensor: torch.Tensor,
group: torch.distributed.ProcessGroup,
dim: int,
)#
static backward(ctx, grad_output: torch.Tensor)#
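A minimal sketch of what this Function's contract implies (an assumed implementation for illustration, not the library source):

```python
import torch
import torch.distributed as dist

class AllGatherConcatFn(torch.autograd.Function):
    # Sketch of _AllGatherConcatFn's contract: forward all-gathers
    # equal-sized local shards and concatenates them along `dim`;
    # backward all-reduces the full gradient across ranks, then slices
    # out this rank's local shard.

    @staticmethod
    def forward(ctx, local_tensor, group, dim):
        world = dist.get_world_size(group)
        ctx.group, ctx.dim = group, dim
        ctx.local_size = local_tensor.size(dim)
        gathered = [torch.empty_like(local_tensor) for _ in range(world)]
        dist.all_gather(gathered, local_tensor.contiguous(), group=group)
        return torch.cat(gathered, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.contiguous()
        # Sum every rank's gradient contribution, then keep the slice
        # that corresponds to this rank's local shard.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.group)
        start = dist.get_rank(ctx.group) * ctx.local_size
        # Non-tensor inputs (group, dim) receive no gradient.
        return grad.narrow(ctx.dim, start, ctx.local_size), None, None
```

With world size 1 the forward is an identity and the backward returns the incoming gradient unchanged, which makes the contract easy to check on a single process.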
class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet(config, layer_idx: int)#

Bases: transformers.models.qwen3_5_moe.modeling_qwen3_5_moe.Qwen3_5MoeGatedDeltaNet

Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism.

All __init__ parameters and weights are inherited unchanged from the HF class. The only addition is _cp_mesh, which is set externally by apply_cp in the parallelizer.

Initialization

_cp_mesh: torch.distributed.device_mesh.DeviceMesh | None#

None

forward(
hidden_states: torch.Tensor,
cache_params=None,
cache_position=None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
qkv_format: str | None = None,
cu_seqlens: torch.Tensor | None = None,
seq_index: torch.Tensor | None = None,
)#
_conv1d_with_cp(mixed_qkv: torch.Tensor, cp_context) torch.Tensor#

Run causal conv1d via FLA’s CP-aware conv implementation.

Parameters:
  • mixed_qkv – [B, D, S_local] tensor (channels-first for conv).

  • cp_context – FLA CP context built by build_cp_context.

Returns:

[B, D, S_local] conv output with correct boundary handling.
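The boundary handling amounts to giving each shard the last kernel_size − 1 tokens of the preceding shard as a left halo (FLA's CP-aware conv exchanges these between ranks; the sketch below is a single-process illustration with a hypothetical helper name):

```python
import torch
import torch.nn.functional as F

def causal_conv1d_with_halo(x_local, halo, weight):
    # x_local: [B, D, S_local] channels-first shard; halo: the last
    # (K - 1) tokens of the previous shard, or zeros for the first shard.
    # Prepending the halo makes the depthwise causal conv produce the
    # same outputs it would on the unsharded sequence.
    x = torch.cat([halo, x_local], dim=-1)
    return F.conv1d(x, weight, groups=x.size(1))
```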

_extract_local_positions(
position_ids: torch.Tensor | None,
seq_index: torch.Tensor | None,
seq_len: int,
) torch.Tensor | None#
_all_gather_concat(
tensor: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
*,
dim: int,
differentiable: bool = False,
) torch.Tensor#
_undo_attention_load_balancing(
hidden_states: torch.Tensor,
original_positions: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
) tuple[torch.Tensor, torch.Tensor]#
_redo_attention_load_balancing(
output: torch.Tensor,
original_positions: torch.Tensor,
sorted_positions: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
) torch.Tensor#
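Conceptually, the undo step argsorts the (all-gathered) original positions to reach dense order, and the redo step applies the inverse permutation before re-slicing the local shard. A single-process sketch under those assumptions (hypothetical helper names; sequence dim = 1):

```python
import torch

def undo_load_balancing(gathered, original_positions):
    # Sort the all-gathered, load-balanced sequence into dense order.
    sorted_positions = torch.argsort(original_positions)
    return gathered[:, sorted_positions], sorted_positions

def redo_load_balancing(dense_out, sorted_positions, rank, world):
    # argsort of a permutation is its inverse: map dense order back to
    # the load-balanced layout, then keep this rank's local slice.
    lb = dense_out[:, torch.argsort(sorted_positions)]
    local = lb.size(1) // world
    return lb[:, rank * local : (rank + 1) * local]
```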
_forward_with_cp(
hidden_states: torch.Tensor,
*,
position_ids: torch.Tensor | None,
seq_index: torch.Tensor | None,
) torch.Tensor#