nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn#
Context-Parallel-aware wrapper for Qwen3.5 MoE GatedDeltaNet linear attention.
When a CP mesh is attached (via apply_cp), the forward pass:

1. Recovers dense sequence order from PyTorch's load-balanced CP layout using seq_index or position_ids.
2. Runs the causal conv1d and FLA gated delta rule on that dense ordering.
3. Restores the output back to the original load-balanced CP layout.

When no CP mesh is set, the module delegates to the original HF forward.
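The undo/redo steps above can be sketched in a single process, simulating the CP ranks as list entries (`lb_positions` is an illustrative helper, not part of the module; PyTorch's load-balanced layout gives rank r the chunks r and 2*cp-1-r so causal work stays balanced):

```python
import torch

def lb_positions(seq_len: int, cp: int, rank: int) -> torch.Tensor:
    """Global token positions held by `rank` under the load-balanced CP layout."""
    chunks = torch.arange(seq_len).chunk(2 * cp)
    return torch.cat([chunks[rank], chunks[2 * cp - 1 - rank]])

cp, seq_len = 2, 8
hidden = torch.arange(seq_len)  # stand-in for hidden_states along the seq dim

# Shard according to the load-balanced layout.
shards = [hidden[lb_positions(seq_len, cp, r)] for r in range(cp)]

# Undo: all-gather (here: concat) + sort by gathered positions -> dense order.
gathered = torch.cat(shards)
positions = torch.cat([lb_positions(seq_len, cp, r) for r in range(cp)])
order = torch.argsort(positions)
dense = gathered[order]
assert torch.equal(dense, hidden)

# Redo: apply the inverse permutation, then each rank slices its shard back out.
rebalanced = dense[torch.argsort(order)]
local_len = seq_len // cp
for r in range(cp):
    assert torch.equal(rebalanced[r * local_len:(r + 1) * local_len], shards[r])
```

Gathering the positions alongside the data means one argsort recovers dense order, and its inverse permutation restores the original layout exactly.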
Module Contents#
Classes#

- _AllGatherConcatFn: All-gather + concat with autograd-safe backward.
- CPAwareGatedDeltaNet: Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism.
API#
- class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn#
Bases: torch.autograd.Function

All-gather + concat with autograd-safe backward.

The forward concatenates equal-sized local shards from all ranks along dim. Backward all-reduces the concatenated gradient across ranks, then slices out the local shard for the current rank.

- static forward(
- ctx,
- local_tensor: torch.Tensor,
- group: torch.distributed.ProcessGroup,
- dim: int,
- )#
- static backward(ctx, grad_output: torch.Tensor)#
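A minimal sketch of this forward/backward contract (written against the documented behavior above, not the library source; the class name here mirrors the private one in the module):

```python
import torch
import torch.distributed as dist

# Sketch: forward all-gathers equal-sized shards and concatenates along
# `dim`; backward all-reduces the concatenated gradient across ranks,
# then returns the slice belonging to the current rank.
class AllGatherConcatFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, local_tensor, group, dim):
        ctx.group, ctx.dim = group, dim
        ctx.rank = dist.get_rank(group)
        ctx.local_size = local_tensor.size(dim)
        shards = [torch.empty_like(local_tensor)
                  for _ in range(dist.get_world_size(group))]
        dist.all_gather(shards, local_tensor.contiguous(), group=group)
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank holds a gradient for the full concatenated tensor;
        # summing them, then slicing, yields the local shard's gradient.
        grad = grad_output.contiguous()
        dist.all_reduce(grad, group=ctx.group)
        local_grad = grad.narrow(ctx.dim, ctx.rank * ctx.local_size,
                                 ctx.local_size)
        return local_grad, None, None
```

The all-reduce in backward is what makes the op autograd-safe: every rank's downstream gradient touches every shard, so the local gradient must sum contributions from all ranks before slicing.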
- class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet(config, layer_idx: int)#
Bases: transformers.models.qwen3_5_moe.modeling_qwen3_5_moe.Qwen3_5MoeGatedDeltaNet

Drop-in replacement for Qwen3_5MoeGatedDeltaNet with FLA Context Parallelism.

All __init__ parameters and weights are inherited unchanged from the HF class. The only addition is _cp_mesh, which is set externally by apply_cp in the parallelizer.

Initialization
- _cp_mesh: torch.distributed.device_mesh.DeviceMesh | None#
None
- forward(
- hidden_states: torch.Tensor,
- cache_params=None,
- cache_position=None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.Tensor | None = None,
- qkv_format: str | None = None,
- cu_seqlens: torch.Tensor | None = None,
- seq_index: torch.Tensor | None = None,
- )#
- _conv1d_with_cp(mixed_qkv: torch.Tensor, cp_context) -> torch.Tensor#
Run causal conv1d via FLA’s CP-aware conv implementation.
- Parameters:
mixed_qkv – [B, D, S_local] tensor (channels-first for conv).
cp_context – FLA CP context built by build_cp_context.
- Returns:
[B, D, S_local] conv output with correct boundary handling.
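The "correct boundary handling" can be illustrated without FLA: a causal depthwise conv over sequence shards matches the dense result only if each shard is prepended with the previous shard's last K-1 inputs. A hypothetical standalone example (not FLA's implementation):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, S, K = 1, 4, 8, 3
x = torch.randn(B, D, S)   # [B, D, S], channels-first as in mixed_qkv
w = torch.randn(D, 1, K)   # depthwise causal kernel

def causal_conv(x):
    # Left-pad with K-1 zeros so output[t] only sees inputs <= t.
    return F.conv1d(F.pad(x, (K - 1, 0)), w, groups=D)

dense = causal_conv(x)

# Two "CP shards" along the sequence dim, with a K-1 halo carried across.
half = S // 2
x0, x1 = x[..., :half], x[..., half:]
y0 = causal_conv(x0)
# Prepend x0's last K-1 steps instead of zero padding; no output is
# produced for those halo positions, so lengths line up.
y1 = F.conv1d(torch.cat([x0[..., -(K - 1):], x1], dim=-1), w, groups=D)
chunked = torch.cat([y0, y1], dim=-1)
assert torch.allclose(dense, chunked, atol=1e-6)
```

Without the halo exchange, the first K-1 outputs of each non-initial shard would see zeros instead of the true preceding tokens, which is exactly the boundary error the CP-aware conv avoids.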
- _extract_local_positions(
- position_ids: torch.Tensor | None,
- seq_index: torch.Tensor | None,
- seq_len: int,
- )#
- _all_gather_concat(
- tensor: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- *,
- dim: int,
- differentiable: bool = False,
- )#
- _undo_attention_load_balancing(
- hidden_states: torch.Tensor,
- original_positions: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- )#
- _redo_attention_load_balancing(
- output: torch.Tensor,
- original_positions: torch.Tensor,
- sorted_positions: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- )#
- _forward_with_cp(
- hidden_states: torch.Tensor,
- *,
- position_ids: torch.Tensor | None,
- seq_index: torch.Tensor | None,