# nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn

Context-Parallel-aware wrapper for Qwen3.5 MoE GatedDeltaNet linear attention.

When a CP mesh is attached (via `apply_cp`), the forward pass:

1. Recovers dense sequence order from PyTorch's load-balanced CP layout using
   a local `seq_index` when provided, otherwise deriving it from the CP
   DualChunkSwap layout.
2. Runs the causal conv1d and FLA gated delta rule on that dense ordering.
3. Restores the output back to the original load-balanced CP layout.

When no CP mesh is set, the module delegates to the original HF forward.
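The three steps above can be simulated in a single process to show why the reordering is needed. This is a minimal pure-Python sketch, not the module's implementation. Each "rank" is a list of `(global position, value)` tokens in load-balanced order. A running sum stands in for the causal conv1d and gated delta rule. The helper names (`undo_load_balancing`, `causal_running_sum`, `redo_load_balancing`) are hypothetical, and the all-gather is assumed to be a plain list concatenation.

```python
from typing import List, Tuple

Token = Tuple[int, float]  # (global position, value)


def undo_load_balancing(rank_tokens: List[List[Token]]) -> List[float]:
    """Hypothetical: concatenate every rank's tokens (stand-in for an all-gather), sort by position."""
    gathered = [tok for tokens in rank_tokens for tok in tokens]
    gathered.sort(key=lambda tok: tok[0])
    return [value for _, value in gathered]


def causal_running_sum(dense: List[float]) -> List[float]:
    """Toy causal op: the output at position t depends only on inputs at positions <= t."""
    out, acc = [], 0.0
    for v in dense:
        acc += v
        out.append(acc)
    return out


def redo_load_balancing(dense_out: List[float], rank_tokens: List[List[Token]]) -> List[List[float]]:
    """Hypothetical: give each rank the outputs at its own positions, in its original local order."""
    return [[dense_out[pos] for pos, _ in tokens] for tokens in rank_tokens]


# 8 tokens on 2 ranks in a DualChunkSwap-style split:
# rank 0 holds positions [0, 1, 6, 7], rank 1 holds [2, 3, 4, 5].
values = [1.0] * 8
rank_tokens = [
    [(p, values[p]) for p in (0, 1, 6, 7)],
    [(p, values[p]) for p in (2, 3, 4, 5)],
]
dense = undo_load_balancing(rank_tokens)                     # step 1: dense order
dense_out = causal_running_sum(dense)                        # step 2: causal op
local_outputs = redo_load_balancing(dense_out, rank_tokens)  # step 3: back to CP layout
print(local_outputs)  # [[1.0, 2.0, 7.0, 8.0], [3.0, 4.0, 5.0, 6.0]]
```

Running the causal op directly on rank 0's local shard would produce `[1.0, 2.0, 3.0, 4.0]`, which is wrong for positions 6 and 7 because they depend on tokens held by rank 1.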

## Module Contents

### Classes

| Name                                                                                                        | Description                                                                     |
| ----------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------- |
| [`CPAwareGatedDeltaNet`](#nemo_automodel-components-models-qwen3_5_moe-cp_linear_attn-CPAwareGatedDeltaNet) | Drop-in replacement for `Qwen3_5MoeGatedDeltaNet` with FLA Context Parallelism. |
| [`SSMGate`](#nemo_automodel-components-models-qwen3_5_moe-cp_linear_attn-SSMGate)                           | Owns the fp32 SSM-gating params (`A_log`/`dt_bias`) and computes the gate.      |
| [`_AllGatherConcatFn`](#nemo_automodel-components-models-qwen3_5_moe-cp_linear_attn-_AllGatherConcatFn)     | All-gather + concat with autograd-safe backward.                                |
| [`_SSMGateParam`](#nemo_automodel-components-models-qwen3_5_moe-cp_linear_attn-_SSMGateParam)               | Get-only (non-data) descriptor exposing an `SSMGate` param as an attribute.     |

### Functions

| Name                                                                                                    | Description                                                             |
| ------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------- |
| [`_resolve_ssm_dtype`](#nemo_automodel-components-models-qwen3_5_moe-cp_linear_attn-_resolve_ssm_dtype) | Resolve the fp32 storage dtype for the SSM-gating params from `config`. |
| [`install_ssm_gate`](#nemo_automodel-components-models-qwen3_5_moe-cp_linear_attn-install_ssm_gate)     | Move `mod`'s HF-created bare `A_log`/`dt_bias` into an fp32 `SSMGate`.  |

### Data

[`_FP32_PARAM_NAMES`](#nemo_automodel-components-models-qwen3_5_moe-cp_linear_attn-_FP32_PARAM_NAMES)

### API

```python
class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet(
    config,
    layer_idx: int
)
```

**Bases:** `Qwen3_5MoeGatedDeltaNet`

Drop-in replacement for `Qwen3_5MoeGatedDeltaNet` with FLA Context Parallelism.

The SSM-gating params (`A_log`/`dt_bias`) are moved into an fp32 `SSMGate`
submodule (`_fp32_params`) at construction so they keep fp32 storage (master
weights) even under a bf16 bulk dtype, and so FSDP can shard them in their own
dtype-uniform fp32 group. `A_log`/`dt_bias` remain readable as attributes via
get-only descriptors that resolve to the submodule, so no `__getattr__` patch is needed.

`_cp_mesh` is set externally by the parallelizer to enable context parallelism.

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._all_gather_concat(
    tensor: torch.Tensor,
    cp_group: torch.distributed.ProcessGroup,
    dim: int,
    differentiable: bool = False
) -> torch.Tensor
```

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._build_dual_chunk_local_positions(
    seq_len: int,
    cp_size: int,
    cp_rank: int,
    device: torch.device
) -> torch.Tensor
```

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._compute_gate(
    a: torch.Tensor
) -> torch.Tensor
```

Compute the gating value `g` via the fp32 `SSMGate` submodule.

Computing inside the submodule's forward keeps FSDP's unshard/reshard
lifecycle natural for the isolated fp32 group.
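For intuition about what the fp32 gate computes, the sketch below assumes the gate follows the Hugging Face Qwen3-Next-style GatedDeltaNet formula `g = -exp(A_log) * softplus(a + dt_bias)`, evaluated per value head in fp32. That formula is an assumption to verify against the source, and the list-based `compute_gate` helper is hypothetical. Because `exp(A_log) > 0` and `softplus > 0`, `g` is always negative, so `exp(g)` is a per-head decay factor in `(0, 1)`.

```python
import math
from typing import List


def softplus(x: float) -> float:
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def compute_gate(a: List[float], A_log: List[float], dt_bias: List[float]) -> List[float]:
    """Hypothetical per-head gate: g_h = -exp(A_log_h) * softplus(a_h + dt_bias_h)."""
    return [-math.exp(al) * softplus(ah + db) for ah, al, db in zip(a, A_log, dt_bias)]


g = compute_gate(a=[0.0, 2.0], A_log=[0.0, math.log(0.5)], dt_bias=[0.0, -2.0])
decay = [math.exp(x) for x in g]  # per-head multiplicative decay of the recurrent state
print(g)  # approximately [-0.6931, -0.3466]
```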

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._conv1d_with_cp(
    mixed_qkv: torch.Tensor,
    cp_context
) -> torch.Tensor
```

Run causal conv1d via FLA's CP-aware conv implementation.

**Parameters:**

* **mixed_qkv** (`torch.Tensor`): `[B, D, S_local]` tensor (channels-first for conv).
* **cp_context**: FLA CP context built by `build_cp_context`.

**Returns:** `torch.Tensor`

`[B, D, S_local]` conv output with correct boundary handling.
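The boundary problem that a CP-aware conv has to solve can be shown on one channel. This is a pure-Python sketch and not FLA's API: it splits a sequence into two contiguous shards and shows that a causal conv with kernel size `K` needs the last `K - 1` inputs of the previous shard (a "halo"). How FLA actually exchanges that context between ranks is not described here. The helper `causal_conv1d` and its `history` argument are hypothetical.

```python
from typing import List, Optional


def causal_conv1d(x: List[float], weight: List[float], history: Optional[List[float]] = None) -> List[float]:
    """Hypothetical single-channel causal conv: y[t] = sum_j weight[j] * x[t - (K - 1) + j].

    `history` holds the K - 1 inputs preceding `x`; zeros are used at the true sequence start.
    """
    k = len(weight)
    padded = (history if history is not None else [0.0] * (k - 1)) + x
    return [sum(w * padded[t + j] for j, w in enumerate(weight)) for t in range(len(x))]


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
w = [0.25, 0.5, 1.0]  # kernel size 3, so each shard needs 2 tokens of left context
full = causal_conv1d(x, w)

shards = [x[:3], x[3:]]  # two contiguous "ranks"
naive = causal_conv1d(shards[0], w) + causal_conv1d(shards[1], w)  # zero-pads at the shard boundary
halo = shards[0][-(len(w) - 1):]  # last K - 1 inputs of the previous shard
fixed = causal_conv1d(shards[0], w) + causal_conv1d(shards[1], w, history=halo)
print(full)   # [1.0, 2.5, 4.25, 6.0, 7.75, 9.5]
print(naive)  # [1.0, 2.5, 4.25, 4.0, 7.0, 9.5]
```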

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._extract_local_seq_index(
    seq_index: torch.Tensor | None,
    seq_len: int
) -> torch.Tensor | None
```

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._forward_no_cp(
    hidden_states: torch.Tensor,
    cache_params = None,
    cache_position = None,
    attention_mask: torch.Tensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    indices: torch.Tensor | None = None
)
```

HF GatedDeltaNet forward with FSDP-safe fp32 gate computation.

Mirrors transformers==5.5 `Qwen3_5GatedDeltaNet.forward` (per-layer
cache API; gate via `self._compute_gate(a)`) and adds packing-aware
plumbing:

* `cu_seqlens` -- per-document cumulative lengths from the indexed
  attention mask. When supplied, FLA's chunk kernel resets state at
  every document boundary.
* `indices` -- non-padding token indices. When supplied AND padding
  is actually present (B>1 case), the layer unpads activations to
  `[1, total_valid, ...]` before conv/FLA and re-pads on the way
  out. For B=1 with no padding, `indices` covers the whole sequence
  and unpadding is skipped (preserves the bit-exact fast path).

Both kwargs are produced by `Qwen3_5DecoderLayerWithPacking`. As a
safety net for direct callers (e.g. unit tests that bypass the
decoder-layer subclass), the layer derives them from `attention_mask`
when both are `None` and the mask is indexed.
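The packing plumbing above can be sketched with plain lists. This sketch assumes an "indexed" attention mask stores a document id (1, 2, ...) per valid token and 0 for padding, with each document contiguous within a row; that encoding is an assumption. The helpers `cu_seqlens_and_indices`, `unpad`, and `repad` are hypothetical. The real layer works on `[B, S, ...]` tensors and unpads to `[1, total_valid, ...]`.

```python
from typing import List, Tuple


def cu_seqlens_and_indices(mask: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Hypothetical: from a [B, S] indexed mask, derive cumulative document
    lengths and flattened non-padding token indices."""
    seq_len = len(mask[0])
    indices: List[int] = []
    cu_seqlens = [0]
    for b, row in enumerate(mask):
        prev_doc = None  # a new row always starts a new document
        for s, doc in enumerate(row):
            if doc == 0:
                continue  # padding token
            indices.append(b * seq_len + s)
            if doc != prev_doc:
                cu_seqlens.append(cu_seqlens[-1])  # open a new document
                prev_doc = doc
            cu_seqlens[-1] += 1
    return cu_seqlens, indices


def unpad(x: List[List[float]], indices: List[int]) -> List[float]:
    """Gather valid tokens into one packed row of length total_valid."""
    flat = [v for row in x for v in row]
    return [flat[i] for i in indices]


def repad(y: List[float], indices: List[int], batch: int, seq_len: int, fill: float = 0.0) -> List[List[float]]:
    """Scatter packed outputs back into the padded [B, S] layout."""
    flat = [fill] * (batch * seq_len)
    for v, i in zip(y, indices):
        flat[i] = v
    return [flat[b * seq_len:(b + 1) * seq_len] for b in range(batch)]


mask = [[1, 1, 1, 2, 2, 0],  # row 0: two packed documents + 1 pad
        [1, 1, 1, 1, 0, 0]]  # row 1: one document + 2 pads
cu_seqlens, indices = cu_seqlens_and_indices(mask)
print(cu_seqlens)  # [0, 3, 5, 9]
print(indices)     # [0, 1, 2, 3, 4, 6, 7, 8, 9]

x = [[float(b * 10 + s) for s in range(6)] for b in range(2)]
roundtrip = repad(unpad(x, indices), indices, batch=2, seq_len=6)
```

The state reset at each `cu_seqlens` boundary is what keeps one packed document from leaking recurrent state into the next.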

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._forward_with_cp(
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor | None,
    seq_index: torch.Tensor | None
) -> torch.Tensor
```

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._redo_attention_load_balancing(
    output: torch.Tensor,
    original_positions: torch.Tensor,
    sorted_positions: torch.Tensor,
    cp_group: torch.distributed.ProcessGroup
) -> torch.Tensor
```

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet._undo_attention_load_balancing(
    hidden_states: torch.Tensor,
    original_positions: torch.Tensor,
    cp_group: torch.distributed.ProcessGroup
) -> tuple[torch.Tensor, torch.Tensor]
```

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.CPAwareGatedDeltaNet.forward(
    hidden_states: torch.Tensor,
    cache_params = None,
    cache_position = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    qkv_format: str | None = None,
    cu_seqlens: torch.Tensor | None = None,
    indices: torch.Tensor | None = None,
    seq_index: torch.Tensor | None = None
)
```

```python
class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.SSMGate(
    num_v_heads: int,
    dtype: torch.dtype = torch.float32
)
```

**Bases:** `Module`

Owns the fp32 SSM-gating params (`A_log`/`dt_bias`) and computes the gate.

Keeping these in a dedicated submodule lets FSDP shard them in their own
dtype-uniform fp32 group (true master weights), and computing the gate inside
`forward` keeps FSDP's unshard/reshard lifecycle natural.
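The reason for keeping `A_log`/`dt_bias` as fp32 master weights can be shown with a small rounding experiment. bf16 keeps only 8 significant bits, so values near 1.0 are spaced 2^-7 (about 0.0078) apart and small optimizer updates are rounded away. The sketch below emulates bf16 round-to-nearest-even in pure Python via `struct`. It is an illustration of the numerics only, not how FSDP mixed precision is implemented, and the `to_bf16`/`to_fp32` helpers are hypothetical (they do not handle NaN or infinity).

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE float32."""
    return struct.unpack(">f", struct.pack(">f", x))[0]


def to_bf16(x: float) -> float:
    """Round to bfloat16: round-to-nearest-even on the top 16 bits of the float32 pattern."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


w_bf16 = to_bf16(1.0)
w_fp32 = to_fp32(1.0)
for _ in range(100):
    update = 1e-4                      # a small optimizer step
    w_bf16 = to_bf16(w_bf16 + update)  # rounds back to 1.0 every time
    w_fp32 = to_fp32(w_fp32 + update)  # accumulates the updates

print(w_bf16)            # 1.0
print(round(w_fp32, 4))  # 1.01
```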

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.SSMGate.forward(
    a: torch.Tensor
) -> torch.Tensor
```

```python
class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn()
```

**Bases:** `Function`

All-gather + concat with autograd-safe backward.

The forward concatenates equal-sized local shards from all ranks along `dim`.
Backward all-reduces the concatenated gradient across ranks, then slices out
the local shard for the current rank.
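The backward rule can be checked without a process group by simulating the ranks as list entries. Every rank's copy of the concatenated tensor depends on every local shard, so the gradient for rank `r`'s shard is the sum, over all ranks, of slice `r` of their `grad_output`. That is exactly "all-reduce, then slice" (a reduce-scatter would compute the same result). The helpers below are hypothetical single-process stand-ins for the collectives, not the `autograd.Function` itself.

```python
from typing import List


def all_gather_concat(shards: List[List[float]]) -> List[float]:
    """Forward (same result on every rank): concatenate equal-sized shards in rank order."""
    return [v for shard in shards for v in shard]


def all_gather_concat_backward(grad_outputs: List[List[float]], rank: int, shard_len: int) -> List[float]:
    """Backward on `rank`: sum the full-length grads from all ranks (all-reduce),
    then keep only this rank's slice."""
    summed = [sum(col) for col in zip(*grad_outputs)]
    return summed[rank * shard_len:(rank + 1) * shard_len]


shards = [[1.0, 2.0], [3.0, 4.0]]
full = all_gather_concat(shards)  # [1.0, 2.0, 3.0, 4.0] on both ranks

# Each rank uses the full tensor differently, so each sees a different grad_output.
grad_outputs = [[1.0, 0.0, 0.0, 2.0],  # rank 0
                [0.0, 3.0, 5.0, 0.0]]  # rank 1
print(all_gather_concat_backward(grad_outputs, rank=0, shard_len=2))  # [1.0, 3.0]
print(all_gather_concat_backward(grad_outputs, rank=1, shard_len=2))  # [5.0, 2.0]
```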

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn.backward(
    ctx,
    grad_output: torch.Tensor
)
```

staticmethod

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._AllGatherConcatFn.forward(
    ctx,
    local_tensor: torch.Tensor,
    group: torch.distributed.ProcessGroup,
    dim: int
)
```

staticmethod

```python
class nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._SSMGateParam(
    name: str
)
```

Get-only (non-data) descriptor exposing an `SSMGate` param as an attribute.

Lets `self.A_log` / `self.dt_bias` resolve to the fp32 `SSMGate` holder
(`self._fp32_params`) without a `__getattr__` monkeypatch. Being a non-data
descriptor, it does not intercept assignment, so HF's `__init__` doing
`self.A_log = nn.Parameter(...)` still routes through `nn.Module.__setattr__`
into `_parameters` (where it lives until `install_ssm_gate` moves it).
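The non-data descriptor trick can be reproduced with a toy stand-in for `nn.Module`. In this sketch, `ToyModule` mimics how `nn.Module.__setattr__` routes parameters into `_parameters`, so the class-level descriptor is still found on reads while assignments are not intercepted. All names here (`Param`, `ToyModule`, `GateParam`, `Layer`, `install`) are hypothetical. Falling back to `_parameters` before relocation is an assumption about the real descriptor, not confirmed behavior.

```python
class Param:
    """Stand-in for nn.Parameter."""

    def __init__(self, value):
        self.value = value


class ToyModule:
    """Mimics nn.Module routing: params go to `_parameters`, submodules to `_modules`."""

    def __init__(self):
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, Param):
            self._parameters[name] = value
        elif isinstance(value, ToyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):  # only called when normal lookup fails
        for store in ("_parameters", "_modules"):
            d = object.__getattribute__(self, store)
            if name in d:
                return d[name]
        raise AttributeError(name)


class GateParam:
    """Non-data descriptor: defines __get__ only, so assignment is not intercepted."""

    def __init__(self, name):
        self.name = name

    def __get__(self, obj, owner=None):
        if obj is None:
            return self
        holder = obj._modules.get("_fp32_params")
        if holder is not None:
            return holder._parameters[self.name]  # after relocation
        return obj._parameters[self.name]         # assumed fallback before relocation


class Layer(ToyModule):
    A_log = GateParam("A_log")

    def __init__(self):
        super().__init__()
        self.A_log = Param(0.0)  # like HF __init__: lands in _parameters, not __dict__


def install(layer):
    """Move the bare param into a holder submodule registered as `_fp32_params`."""
    holder = ToyModule()
    holder.A_log = layer._parameters.pop("A_log")
    layer._fp32_params = holder
    return holder


layer = Layer()
before = layer.A_log  # descriptor reads from _parameters
holder = install(layer)
after = layer.A_log   # descriptor now resolves through _fp32_params
```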

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._SSMGateParam.__get__(
    obj,
    owner = None
)
```

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._resolve_ssm_dtype(
    config
)
```

Resolve the fp32 storage dtype for the SSM-gating params from `config`.

Honors `mamba_ssm_dtype` (Qwen3.5 stores `A_log`/`dt_bias` in fp32);
defaults to `torch.float32`.

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn.install_ssm_gate(
    mod,
    fp32_dtype = torch.float32
)
```

Move `mod`'s HF-created bare `A_log`/`dt_bias` into an fp32 `SSMGate`.

HF's GatedDeltaNet `__init__` creates `A_log`/`dt_bias` as bare params in
`mod._parameters`. This relocates them into an `SSMGate` submodule
registered as `_fp32_params` (casting to `fp32_dtype`), so they keep fp32
storage under a bf16 bulk dtype and get their own dtype-uniform FSDP group.
Attribute access (`self.A_log`/`self.dt_bias`) continues to work via the
`_SSMGateParam` descriptors on `CPAwareGatedDeltaNet`, so no
`__getattr__` patch is needed. Returns the gate submodule.

```python
nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn._FP32_PARAM_NAMES = ('A_log', 'dt_bias')
```