# nemo_automodel.components.models.gemma4_moe.cp_attention

Gemma4-specific context-parallel attention helpers.

## Module Contents

### Classes

| Name                                                                                                             | Description                                                                                |
| ---------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
| [`CPRingAttentionContext`](#nemo_automodel-components-models-gemma4_moe-cp_attention-CPRingAttentionContext)     | Inputs for Gemma4 manual ring CP attention (built by the `run_cp_manual_attention` seam).  |
| [`_Gemma4FlexRingAttention`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_Gemma4FlexRingAttention) | -                                                                                          |

### Functions

| Name                                                                                                                                       | Description                                                                       |
| ------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------- |
| [`_base_gemma4_cp_mask`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_base_gemma4_cp_mask)                                   | -                                                                                 |
| [`_block_mask_set_generation`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_block_mask_set_generation)                       | Reset the per-step block-mask cache when a new batch (new metadata) arrives.      |
| [`_cached_block_mask`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_cached_block_mask)                                       | -                                                                                 |
| [`_collect_ring_kv_chunks`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_collect_ring_kv_chunks)                             | -                                                                                 |
| [`_compiled_flex_attention`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_compiled_flex_attention)                           | -                                                                                 |
| [`_detach_metadata`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_detach_metadata)                                           | -                                                                                 |
| [`_direct_exchange`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_direct_exchange)                                           | -                                                                                 |
| [`_duck_shape_disabled`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_duck_shape_disabled)                                   | Locally disable flex duck-shape specialization for the wrapped flex call.         |
| [`_gemma4_cp_manual_attention`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_gemma4_cp_manual_attention)                     | Gemma4-owned manual ring CP attention entry.                                      |
| [`_install_gemma4_cp_ring_sdpa`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_install_gemma4_cp_ring_sdpa)                   | Replace `F.scaled_dot_product_attention` with Gemma4 ring CP attention on this module. |
| [`_merge_flex_chunk`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_merge_flex_chunk)                                         | -                                                                                 |
| [`_metadata_like`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_metadata_like)                                               | -                                                                                 |
| [`_patch_fsdp_accumulated_grad_guard`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_patch_fsdp_accumulated_grad_guard)       | Guard `FSDPParam.to_accumulated_grad_if_needed` against uninitialized params.     |
| [`_ring_exchange`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_ring_exchange)                                               | -                                                                                 |
| [`_run_gemma4_cp_ring_attention`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_run_gemma4_cp_ring_attention)                 | Run Gemma4 local-query/ring-key CP attention with FlexAttention.                  |
| [`_run_gemma4_cp_ring_attention_forward`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_run_gemma4_cp_ring_attention_forward) | Run Gemma4 local-query/ring-key CP attention forward with FlexAttention.          |
| [`_run_gemma4_flex_chunk`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_run_gemma4_flex_chunk)                               | -                                                                                 |
| [`_zero_if_none`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_zero_if_none)                                                 | -                                                                                 |
| [`attach_gemma4_cp_ring_attention`](#nemo_automodel-components-models-gemma4_moe-cp_attention-attach_gemma4_cp_ring_attention)             | Register Gemma4's model-owned p2p ring CP attention on a self-attention module.   |
| [`gemma4_vision_group_ids`](#nemo_automodel-components-models-gemma4_moe-cp_attention-gemma4_vision_group_ids)                             | Return per-image-block ids for Gemma4 vision tokens, or -1 for text/padding.      |

### Data

[`_BLOCK_MASK_CACHE`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_BLOCK_MASK_CACHE)

[`_BLOCK_MASK_GEN`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_BLOCK_MASK_GEN)

[`_GEMMA4_CP_FLEX_RING_OK_LOGGED`](#nemo_automodel-components-models-gemma4_moe-cp_attention-_GEMMA4_CP_FLEX_RING_OK_LOGGED)

[`logger`](#nemo_automodel-components-models-gemma4_moe-cp_attention-logger)

### API

```python
class nemo_automodel.components.models.gemma4_moe.cp_attention.CPRingAttentionContext(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cp_mesh: typing.Any,
    cp_group: typing.Any,
    cp_size: int,
    cp_rank: int,
    seq_local: int,
    seq_full: int,
    seq_global_start: int,
    attn_mask: typing.Any,
    dropout_p: float,
    is_causal: bool,
    scale: typing.Any,
    enable_gqa: bool,
    kwargs: dict[str, typing.Any],
    metadata: dict[str, torch.Tensor | None],
    metadata_seq_dims: dict[str, int]
)
```

Dataclass

Inputs for Gemma4 manual ring CP attention (built by the `run_cp_manual_attention` seam).

```python
class nemo_automodel.components.models.gemma4_moe.cp_attention._Gemma4FlexRingAttention()
```

**Bases:** `Function`

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._Gemma4FlexRingAttention.backward(
    autograd_ctx,
    grad_output: torch.Tensor
)
```

staticmethod

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._Gemma4FlexRingAttention.forward(
    autograd_ctx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    ring_ctx: typing.Any
)
```

staticmethod

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._base_gemma4_cp_mask(
    attention_module: torch.nn.Module,
    ctx: typing.Any,
    q_idx,
    kv_idx,
    kv_global_start: int = 0
)
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._block_mask_set_generation(
    gen_tensor
) -> None
```

Reset the per-step block-mask cache when a new batch (new metadata) arrives.

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._cached_block_mask(
    key,
    build
)
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._collect_ring_kv_chunks(
    ctx: typing.Any
) -> list[tuple[int, torch.Tensor, torch.Tensor, dict[str, torch.Tensor | None]]]
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._compiled_flex_attention(
    attention_module: torch.nn.Module
)
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._detach_metadata(
    metadata: dict[str, torch.Tensor | None]
) -> dict[str, torch.Tensor | None]
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._direct_exchange(
    tensors: list[tuple[torch.Tensor, torch.Tensor]],
    cp_group: typing.Any,
    cp_rank: int,
    send_cp_rank: int,
    recv_cp_rank: int
) -> None
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._duck_shape_disabled()
```

Locally disable flex duck-shape specialization for the wrapped flex call.

With variable-length (unpacked) batches, the compiled flex kernel otherwise
guards on incidental dimension equalities (e.g.
`block_mask.kv_indices.size()[2] == key.size()[1]`) and recompiles on every new
sequence length, which drops throughput to roughly warmup speed. Dynamo reads
`use_duck_shape` at (re)trace time, which happens inside the flex call, so
scoping the override to the call window is sufficient. Unlike setting it once
at compile time, this does not leave the process-global `torch.fx` config
mutated for unrelated `torch.compile` users.
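
The scoping idea can be illustrated without torch. The sketch below is an assumption-laden stand-in: `fake_config`, `duck_shape_disabled`, and `traced_call` are hypothetical, and the real flag lives in torch's own config rather than in a `SimpleNamespace`.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Hypothetical stand-in for the process-global config that holds the flag.
fake_config = SimpleNamespace(use_duck_shape=True)


@contextmanager
def duck_shape_disabled(config=fake_config):
    """Set use_duck_shape to False for the duration of the with-block only."""
    previous = config.use_duck_shape
    config.use_duck_shape = False
    try:
        yield
    finally:
        # Restore the earlier value even if the wrapped call raises.
        config.use_duck_shape = previous


def traced_call() -> bool:
    # Stand-in for the flex call; reports the flag value seen at "trace time".
    return fake_config.use_duck_shape


with duck_shape_disabled():
    seen_inside = traced_call()
seen_after = fake_config.use_duck_shape
```

Because the flag is restored in a `finally` clause, other code that runs outside the `with` block sees the original value.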

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._gemma4_cp_manual_attention(
    attention_module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cp_mesh,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    enable_gqa,
    kwargs
) -> torch.Tensor
```

Gemma4-owned manual ring CP attention entry.

Plugs into the generic `run_cp_manual_attention` seam in `cp_utils`: receives
the raw local (un-gathered) Q/K/V plus `cp_mesh`, builds the ring context, and
runs the p2p ring FlexAttention. K/V are rotated across CP ranks inside the
ring autograd function; they are never all-gathered.
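
To make "rotated, never all-gathered" concrete, the following sketch simulates which rank's K/V chunk each rank holds at each ring step. It is illustrative only: `ring_schedule` is a hypothetical helper, and the send direction (rank `r` sends to `r + 1`) is an assumption rather than something the source states.

```python
def ring_schedule(cp_size: int) -> list[list[int]]:
    """For each rank, list the source rank of the K/V chunk it holds at each step."""
    held = list(range(cp_size))  # rank r starts with its own chunk
    seen: list[list[int]] = [[] for _ in range(cp_size)]
    for _ in range(cp_size):
        for rank in range(cp_size):
            seen[rank].append(held[rank])  # attend local queries to this chunk
        # Rotate: rank r receives the chunk previously held by rank (r - 1) % cp_size.
        held = [held[(rank - 1) % cp_size] for rank in range(cp_size)]
    return seen


schedule = ring_schedule(4)
# schedule[0] == [0, 3, 2, 1]: rank 0 sees every chunk, one at a time.
```

Each rank visits all `cp_size` chunks over `cp_size` steps while holding only one chunk at any moment, which is the memory advantage over an all-gather.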

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._install_gemma4_cp_ring_sdpa(
    attention_module: torch.nn.Module,
    cp_mesh
) -> None
```

Replace `F.scaled_dot_product_attention` with Gemma4 ring CP attention on this module.

Gemma4 owns its CP attention end-to-end and does not use the generic CP SDPA
hooks in `cp_utils`. It installs its own `@torch._dynamo.disable` SDPA wrapper
that runs the p2p ring FlexAttention. The wrapper is installed on the inner
attention module so that it also fires during gradient-checkpointing recompute.
The per-forward attention kwargs the ring needs (`mm_token_type_ids`, packed-seq
ids, padding/vision masks) are captured from the forward kwargs into
`_cp_manual_metadata` here, since the swapped SDPA only receives Q/K/V.

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._merge_flex_chunk(
    out_acc: torch.Tensor | None,
    lse_acc: torch.Tensor | None,
    out_step: torch.Tensor,
    lse_step: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
```
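
The source does not describe `_merge_flex_chunk`, but its signature (an accumulated output and log-sum-exp, plus a per-step output and log-sum-exp) matches the standard way of combining partial attention results. The sketch below shows that standard merge for a single query with scalar values, under the assumption that this is what the helper does; `attend` and `merge` are hypothetical names.

```python
import math


def attend(scores, values):
    """Softmax-weighted average of scalar values, plus the log-sum-exp of scores."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    out = sum(math.exp(s - lse) * v for s, v in zip(scores, values))
    return out, lse


def merge(out_acc, lse_acc, out_step, lse_step):
    """Combine two partial attention results using their log-sum-exp weights."""
    if out_acc is None or lse_acc is None:
        return out_step, lse_step  # first chunk: nothing to merge with yet
    m = max(lse_acc, lse_step)
    lse_new = m + math.log(math.exp(lse_acc - m) + math.exp(lse_step - m))
    out_new = (out_acc * math.exp(lse_acc - lse_new)
               + out_step * math.exp(lse_step - lse_new))
    return out_new, lse_new


scores = [0.5, -1.0, 2.0, 0.1]
values = [1.0, 2.0, 3.0, 4.0]

out, lse = None, None
for start in (0, 2):  # two K/V chunks of two keys each
    out_step, lse_step = attend(scores[start:start + 2], values[start:start + 2])
    out, lse = merge(out, lse, out_step, lse_step)

full_out, full_lse = attend(scores, values)  # reference: attention over all keys
```

Merging chunk by chunk reproduces the result of attending over all keys at once, which is why a ring step only needs the running output and log-sum-exp rather than the earlier K/V chunks.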

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._metadata_like(
    metadata: dict[str, torch.Tensor | None]
) -> dict[str, torch.Tensor | None]
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._patch_fsdp_accumulated_grad_guard() -> None
```

Guard `FSDPParam.to_accumulated_grad_if_needed` against uninitialized params.

On some torch builds, that method reads `self._unsharded_param` (the lazily
set unsharded tensor) without first checking that it exists. In FSDP2
post-backward under fp32 grad-reduce, frozen or never-unsharded params (e.g.
the frozen Gemma4 vision tower and embeddings) have no `_unsharded_param` yet,
so the read raises `AttributeError`. Such params carry no grad to upcast
anyway, so the patch wraps the method to skip them when uninitialized. The
patch is a no-op once applied and on fixed builds.
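
The guard pattern described above can be sketched on a dummy class. Nothing here touches torch: `FakeParam` and `patch_guard` are hypothetical, and the `_guarded` marker is just one assumed way of making a repeated patch harmless.

```python
import functools


class FakeParam:
    """Hypothetical stand-in for a parameter wrapper with a lazily set attribute."""

    def to_accumulated_grad_if_needed(self):
        # Mirrors the failure mode: reads the attribute without checking it exists.
        return self._unsharded_param


def patch_guard(cls) -> None:
    """Wrap the method so it returns early when _unsharded_param is unset."""
    original = cls.to_accumulated_grad_if_needed
    if getattr(original, "_guarded", False):
        return  # already patched: applying twice is a no-op

    @functools.wraps(original)
    def guarded(self, *args, **kwargs):
        if not hasattr(self, "_unsharded_param"):
            return None  # nothing to upcast for an uninitialized param
        return original(self, *args, **kwargs)

    guarded._guarded = True
    cls.to_accumulated_grad_if_needed = guarded


patch_guard(FakeParam)
patch_guard(FakeParam)  # second call leaves the single wrapper in place

frozen = FakeParam()
skipped = frozen.to_accumulated_grad_if_needed()

live = FakeParam()
live._unsharded_param = "tensor"
passed_through = live.to_accumulated_grad_if_needed()
```

Initialized params still reach the original method unchanged, so the wrapper only alters behavior in the case that would otherwise raise.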

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._ring_exchange(
    tensors: list[tuple[torch.Tensor, torch.Tensor]],
    cp_group: typing.Any,
    cp_rank: int,
    cp_size: int
) -> None
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._run_gemma4_cp_ring_attention(
    attention_module: torch.nn.Module,
    ctx: typing.Any
) -> torch.Tensor
```

Run Gemma4 local-query/ring-key CP attention with FlexAttention.

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._run_gemma4_cp_ring_attention_forward(
    attention_module: torch.nn.Module,
    ctx: typing.Any
) -> torch.Tensor
```

Run Gemma4 local-query/ring-key CP attention forward with FlexAttention.

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._run_gemma4_flex_chunk(
    attention_module: torch.nn.Module,
    ctx: typing.Any,
    key_chunk: torch.Tensor,
    value_chunk: torch.Tensor,
    metadata_chunk: dict[str, torch.Tensor | None],
    kv_global_start: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, int]
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._zero_if_none(
    grad: torch.Tensor | None,
    like: torch.Tensor
) -> torch.Tensor
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention.attach_gemma4_cp_ring_attention(
    attention_module: torch.nn.Module
) -> None
```

Register Gemma4's model-owned p2p ring CP attention on a self-attention module.

Declares the metadata keys the ring needs and exposes
`setup_cp_attention(cp_mesh)`, the model-owned CP-attention seam that the
parallelizer calls (with the CP mesh) instead of the generic SDPA hooks in
`cp_utils`. `run_cp_manual_attention` is also bound as the ring entry point.

```python
nemo_automodel.components.models.gemma4_moe.cp_attention.gemma4_vision_group_ids(
    mm_token_type_ids: torch.Tensor
) -> torch.Tensor
```

Return per-image-block ids for Gemma4 vision tokens, or -1 for text/padding.
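
As an illustration of the output shape, the sketch below assigns ids over a plain Python list. It rests on two assumptions not confirmed by the source: that a token type of `1` marks a vision token, and that each contiguous run of vision tokens forms one image block. `vision_group_ids` is a hypothetical list-based analogue, not the tensor function itself.

```python
def vision_group_ids(token_types: list[int], vision_type: int = 1) -> list[int]:
    """Give each contiguous run of vision tokens its own id; others get -1."""
    group_ids: list[int] = []
    current = -1
    prev_is_vision = False
    for token_type in token_types:
        is_vision = token_type == vision_type
        if is_vision and not prev_is_vision:
            current += 1  # a new image block starts here
        group_ids.append(current if is_vision else -1)
        prev_is_vision = is_vision
    return group_ids


ids = vision_group_ids([0, 1, 1, 0, 0, 1, 1, 1, 0])
# ids == [-1, 0, 0, -1, -1, 1, 1, 1, -1]
```

Ids of this form let a mask function test whether two positions belong to the same image block with a single equality check, while `-1` excludes text and padding.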

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._BLOCK_MASK_CACHE: dict = {}
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._BLOCK_MASK_GEN: list = [None, None]
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention._GEMMA4_CP_FLEX_RING_OK_LOGGED = False
```

```python
nemo_automodel.components.models.gemma4_moe.cp_attention.logger = logging.getLogger(__name__)
```