
# nemo_automodel.components.distributed.cp_utils

## Module Contents

### Functions

| Name                                                                                                             | Description                                                                                    |
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- |
| [`_build_position_ids`](#nemo_automodel-components-distributed-cp_utils-_build_position_ids)                     | Add position\_ids to the batch only if they are missing.                                       |
| [`_shard_thd_chunk_for_te`](#nemo_automodel-components-distributed-cp_utils-_shard_thd_chunk_for_te)             | -                                                                                              |
| [`attach_context_parallel_hooks`](#nemo_automodel-components-distributed-cp_utils-attach_context_parallel_hooks) | Attach forward pre-hooks to self\_attn modules to fix attention masks for context parallelism. |
| [`attach_cp_sdpa_hooks`](#nemo_automodel-components-distributed-cp_utils-attach_cp_sdpa_hooks)                   | Inject CP-aware SDPA into self\_attn modules for compile + CP>1 correctness.                   |
| [`create_context_parallel_ctx`](#nemo_automodel-components-distributed-cp_utils-create_context_parallel_ctx)     | Create a context parallel context.                                                             |
| [`gather_cp_seq`](#nemo_automodel-components-distributed-cp_utils-gather_cp_seq)                                 | Gather context-parallel sharded `tensors` back to the full sequence.                           |
| [`get_train_context`](#nemo_automodel-components-distributed-cp_utils-get_train_context)                         | Create a train context.                                                                        |
| [`make_cp_batch_and_ctx`](#nemo_automodel-components-distributed-cp_utils-make_cp_batch_and_ctx)                 | Build a CP context manager and shard a batch; effectively a no-op without an active CP submesh. |
| [`make_cp_batch_for_te`](#nemo_automodel-components-distributed-cp_utils-make_cp_batch_for_te)                   | Build a CP batch for Transformer Engine using THD format.                                      |
| [`make_target_cp_ctx`](#nemo_automodel-components-distributed-cp_utils-make_target_cp_ctx)                       | Build a context-parallel context for a frozen target forward.                                  |

### API

```python
nemo_automodel.components.distributed.cp_utils._build_position_ids(
    batch,
    device
)
```

Add position\_ids to the batch only if they are missing.

```python
nemo_automodel.components.distributed.cp_utils._shard_thd_chunk_for_te(
    batch,
    cp_mesh,
    qkv_format,
    seq_lens_padding_value,
    padding_token_id
)
```

```python
nemo_automodel.components.distributed.cp_utils.attach_context_parallel_hooks(
    model: torch.nn.Module
)
```

Attach forward pre-hooks to self\_attn modules to fix attention masks for context parallelism.

Context parallelism shards Q/K/V on the sequence dimension as DTensors,
so explicit 4D attention masks would have mismatched shapes.  This function
registers a hook on every `self_attn` sub-module that strips the
`attention_mask` kwarg and sets `is_causal=True` instead, letting
SDPA handle causal masking internally.

Based on `accelerate.big_modeling._attach_context_parallel_hooks`.
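The kwarg rewrite described above can be sketched without any real model. The function name `strip_mask_pre_hook` and the stand-in values are hypothetical; this is a minimal illustration of the idea, not the library's actual hook. With PyTorch, a function of this shape would be registered through `torch.nn.Module.register_forward_pre_hook(hook, with_kwargs=True)`.

```python
from typing import Any, Dict, Tuple


def strip_mask_pre_hook(
    module: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Hypothetical pre-hook: drop the explicit mask and request causal SDPA."""
    new_kwargs = dict(kwargs)  # copy so the caller's dict is not mutated
    new_kwargs.pop("attention_mask", None)  # strip the mismatched 4D mask
    new_kwargs["is_causal"] = True  # let SDPA handle causal masking itself
    return args, new_kwargs


# Stand-in call: no real module or tensors are needed to see the kwarg rewrite.
_, rewritten = strip_mask_pre_hook(
    None, (), {"attention_mask": "4d-mask", "dropout": 0.0}
)
print(rewritten)
```

Other kwargs pass through unchanged; only the mask and the causal flag are touched.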

```python
nemo_automodel.components.distributed.cp_utils.attach_cp_sdpa_hooks(
    model: torch.nn.Module,
    cp_mesh
) -> None
```

Inject CP-aware SDPA into self\_attn modules for compile + CP>1 correctness.

Problem: when per-layer torch.compile is active, Dynamo traces through the decoder
layer including Q/K/V projections.  At the F.scaled\_dot\_product\_attention call site,
Q/K/V are already local tensors (DTensor metadata was never propagated through the
compiled graph).  The DTensor SDPA dispatch — which triggers the CP allgather — never
fires, so each rank silently attends only to its local sequence shard.

Fix: swap F.scaled\_dot\_product\_attention with a @torch.\_dynamo.disable wrapper for
the duration of each self\_attn forward.  Dynamo sees the disabled function and creates
a graph break there, so:

* Everything before (Q/K/V proj + RoPE) is compiled and fused.
* The disabled wrapper runs eagerly: re-wraps local Q/K/V as DTensors with
  Shard(2) on the CP mesh so the DTensor SDPA dispatch fires the allgather.
* Everything after (O proj + residual + MLP) is compiled and fused.

Seq dim at the SDPA call is 2: tensors are \[B, nH, S/cp\_size, D] after HF reshape.
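The "swap for the duration of each forward" pattern can be sketched in plain Python. The source does not say how the swap is implemented, so everything here is an assumption: `F` is a `SimpleNamespace` standing in for `torch.nn.functional`, and `cp_aware_sdpa` is a placeholder for the eager wrapper that re-wraps Q/K/V as DTensors.

```python
import contextlib
import types

# Stand-in for torch.nn.functional; a real setup would patch the actual module.
F = types.SimpleNamespace(scaled_dot_product_attention=lambda q, k, v: "plain")


def cp_aware_sdpa(q, k, v):
    """Placeholder for the eager wrapper that re-wraps Q/K/V before dispatch."""
    return "cp-aware"


@contextlib.contextmanager
def swapped_sdpa(namespace, replacement):
    """Temporarily replace namespace.scaled_dot_product_attention."""
    original = namespace.scaled_dot_product_attention
    namespace.scaled_dot_product_attention = replacement
    try:
        yield
    finally:
        # Always restore, even if the wrapped forward raises.
        namespace.scaled_dot_product_attention = original


def self_attn_forward(q, k, v):
    """Toy forward that looks up SDPA at call time, as module code does."""
    return F.scaled_dot_product_attention(q, k, v)


before = self_attn_forward(1, 2, 3)
with swapped_sdpa(F, cp_aware_sdpa):
    during = self_attn_forward(1, 2, 3)
after = self_attn_forward(1, 2, 3)
print(before, during, after)
```

Restoring in a `finally` block keeps the patch scoped to a single forward even when that forward fails.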

```python
nemo_automodel.components.distributed.cp_utils.create_context_parallel_ctx(
    cp_mesh: torch.distributed.device_mesh.DeviceMesh,
    cp_buffers: typing.List[torch.Tensor],
    cp_seq_dims: typing.List[int],
    cp_no_restore_buffers: typing.Set[torch.Tensor],
    cp_rotate_method: typing.Optional[str] = None
)
```

Create a context parallel context.

**Parameters:**

* `cp_mesh`: The device mesh for context parallel.
* `cp_buffers`: The buffers for context parallel.
* `cp_seq_dims`: The sequence dimensions for context parallel.
* `cp_no_restore_buffers`: The no restore buffers for context parallel.
* `cp_rotate_method`: The rotation method for context parallel, such as "allgather" or "alltoall".

```python
nemo_automodel.components.distributed.cp_utils.gather_cp_seq(
    cp_mesh: torch.distributed.device_mesh.DeviceMesh,
    tensors: typing.List[torch.Tensor],
    seq_dim: int,
    orig_len: int
)
```

Gather context-parallel sharded `tensors` back to the full sequence.

Inverse of the sharding done by `make_target_cp_ctx`. Uses torch's
`context_parallel_unshard` with `load_balancer=None` (matching the
load-balancing-disabled sharding) and slices the right-pad back off.

**Parameters:**

* `cp_mesh`: The context-parallel device (sub)mesh used to shard.
* `tensors`: Local-shard tensors (e.g. captured aux hidden states, logits), each sharded to `T/cp` along `seq_dim`.
* `seq_dim`: The sequence dimension to gather along.
* `orig_len`: The pre-pad sequence length to slice back to.

**Returns:**

A list of full-sequence tensors of length `orig_len` along `seq_dim`.
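With load balancing disabled, the gather reduces to an ordered concatenation followed by a slice. The sketch below uses plain Python lists in place of tensors and a hypothetical helper name, `gather_contiguous`; it illustrates the arithmetic only and does not call torch's `context_parallel_unshard`.

```python
from typing import List


def gather_contiguous(shards: List[List[int]], orig_len: int) -> List[int]:
    """Concatenate per-rank shards in rank order, then drop the right-padding."""
    full = [tok for shard in shards for tok in shard]
    return full[:orig_len]


# Four ranks, each holding 2 positions of an 8-long padded sequence.
# The original sequence had 7 tokens, so the final position is padding (0).
shards = [[10, 11], [12, 13], [14, 15], [16, 0]]
restored = gather_contiguous(shards, orig_len=7)
print(restored)
```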

```python
nemo_automodel.components.distributed.cp_utils.get_train_context(
    enable_loss_parallel: bool,
    enable_compiled_autograd: bool,
    cp_context = None
)
```

Create a train context.

**Parameters:**

* `enable_loss_parallel`: Whether to enable loss parallelism.
* `enable_compiled_autograd`: Whether to enable compiled autograd.
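One common way to combine several optional contexts is `contextlib.ExitStack`. The sketch below is an assumption about the general pattern, not the library's implementation: `make_train_context` and `tracked` are hypothetical names, and the dummy context managers stand in for the real loss-parallel, compiled-autograd, and CP contexts.

```python
import contextlib

events = []


@contextlib.contextmanager
def tracked(name):
    """Dummy context manager standing in for loss-parallel, CP, etc."""
    events.append(f"enter {name}")
    try:
        yield
    finally:
        events.append(f"exit {name}")


def make_train_context(enable_loss_parallel, enable_compiled_autograd, cp_context=None):
    """Hypothetical composition; the real function may build these differently."""

    @contextlib.contextmanager
    def context():
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                stack.enter_context(tracked("loss_parallel"))
            if enable_compiled_autograd:
                stack.enter_context(tracked("compiled_autograd"))
            if cp_context is not None:
                stack.enter_context(cp_context)
            yield

    return context


train_context = make_train_context(True, False, cp_context=tracked("cp"))
with train_context():
    events.append("step")
print(events)
```

`ExitStack` unwinds the entered contexts in reverse order, so the CP context exits before the loss-parallel one in this sketch.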

```python
nemo_automodel.components.distributed.cp_utils.make_cp_batch_and_ctx(
    device_mesh,
    batch,
    loss_mask = None,
    use_te: bool = False,
    padding_token_id: int = 0,
    num_chunks: int = 1,
    seq_lens_padding_value: int = -1000
)
```

Build a CP context manager and shard a batch. If the input device\_mesh is None or the size
of the context\_parallel submesh is 1, this function is effectively a no-op.

**Parameters:**

* `device_mesh`: The device mesh for context parallel.
* `batch`: The input batch, a dict mapping string keys to `torch.Tensor` values.

**Returns:** `(contextmanager, dict[str, torch.Tensor])`

A tuple of a context manager and the batch dict.

```python
nemo_automodel.components.distributed.cp_utils.make_cp_batch_for_te(
    cp_mesh,
    batch,
    qkv_format = 'thd',
    padding_token_id: int = 0,
    num_chunks: int = 1,
    seq_lens_padding_value: int = -1000
)
```

Build a CP batch for Transformer Engine using THD format.

This function converts BSHD format batches to THD format and shards them across
context parallel ranks for use with Transformer Engine. It processes the batch
in chunks if num\_chunks > 1, allowing for better memory efficiency with large
sequences.

The function performs three main steps:

1. Converts BSHD format to THD format using split\_batch\_into\_thd\_chunks
2. Optionally splits the batch into multiple chunks for memory efficiency
3. Shards each chunk across CP ranks using Transformer Engine's partitioning

**Parameters:**

* `cp_mesh`: The device mesh for context parallel. If None or size \<= 1, returns the batch in THD format without sharding.
* `batch`: The input batch in BSHD format containing:
  * input\_ids: Input token IDs \[batch\_size, seq\_len] or \[batch\_size, seq\_len, hidden\_dim]
  * labels: Label token IDs \[batch\_size, seq\_len]
  * position\_ids (optional): Position IDs \[batch\_size, seq\_len]
  * seq\_lens: Actual sequence lengths \[batch\_size, num\_packs]
  * seq\_lens\_padded: Padded sequence lengths \[batch\_size, num\_packs]
* `qkv_format`: Format for QKV tensors. Currently only "thd" is supported.
* `padding_token_id`: Token ID used for padding in input\_ids (default: 0).
* `num_chunks`: Number of chunks to split the batch into. If > 1, the batch dimension is split and each chunk is processed separately (default: 1).
* `seq_lens_padding_value`: Sentinel value used to indicate padding in seq\_lens/seq\_lens\_padded tensors (default: -1000).

**Returns:**

Processed batch in THD format with the following keys:

* input\_ids: Sharded input token IDs \[total\_tokens] or \[num\_chunks, chunk\_tokens]
* labels: Sharded labels \[total\_tokens] or \[num\_chunks, chunk\_tokens]
* position\_ids: Generated and sharded position IDs \[total\_tokens] or \[num\_chunks, chunk\_tokens]
* cu\_seqlens: Cumulative sequence lengths \[num\_seqs+1] or \[num\_chunks, max\_seqs+1]
* cu\_seqlens\_padded: Cumulative padded sequence lengths \[num\_seqs+1] or \[num\_chunks, max\_seqs+1]
* max\_seqlen: Maximum sequence length (int32 tensor)
* qkv\_format: Format string ("thd")
* padding\_mask: Boolean mask indicating padding tokens

**Raises:**

* `ValueError`: If qkv\_format is not "thd"
* `KeyError`: If required fields (seq\_lens, seq\_lens\_padded) are missing from batch
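The relationship between `seq_lens`, the padding sentinel, and `cu_seqlens` can be illustrated with plain lists. This is a sketch under stated assumptions: `cu_seqlens_from_row` is a hypothetical helper, and the real function works on tensors and may treat the sentinel differently.

```python
from itertools import accumulate
from typing import List

SEQ_LENS_PADDING_VALUE = -1000  # default sentinel from the signature above


def cu_seqlens_from_row(
    seq_lens_row: List[int], sentinel: int = SEQ_LENS_PADDING_VALUE
) -> List[int]:
    """Drop sentinel entries, then prefix-sum the real lengths starting at 0."""
    real = [n for n in seq_lens_row if n != sentinel]
    return [0] + list(accumulate(real))


# One packed row with three real sequences and one sentinel slot.
row = [3, 5, 2, -1000]
cu = cu_seqlens_from_row(row)
print(cu)
```

The result has `num_seqs + 1` entries, matching the `[num_seqs+1]` shape listed for `cu_seqlens`.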

```python
nemo_automodel.components.distributed.cp_utils.make_target_cp_ctx(
    cp_mesh: torch.distributed.device_mesh.DeviceMesh,
    input_ids,
    position_ids = None
)
```

Build a context-parallel context for a frozen target forward.

Shards `input_ids` (and `position_ids`) along the sequence dim across
`cp_mesh` so the target's self-attention runs as ring attention. Unlike
`make_cp_batch_and_ctx`, this does not require `labels` and is meant
for the EAGLE-3 target wrapper, which gathers the aux/logits back to the full
sequence (see `gather_cp_seq`) before handing them to the draft.

Load balancing is disabled (`_cp_options.enable_load_balance = False`) so
each rank holds a contiguous sequence chunk and the gather is a plain ordered
concat (no round-robin un-permute). The sharding is thrown away right after
the forward, so load balancing buys nothing here, and the ordered shard makes
the gather deterministic. This is a process-global torch flag; the EAGLE-3
recipe is the only context-parallel user in its process.

The sequence is right-padded to a multiple of `cp_size`; the returned
`orig_len` lets the caller slice the gathered outputs back down.
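The pad-then-shard arithmetic can be sketched with plain lists. `pad_and_shard` and the `pad_id` value are hypothetical (the source does not say which value is used for padding); the sketch only shows how right-padding to a multiple of `cp_size` yields equal contiguous chunks and why `orig_len` must be kept.

```python
from typing import List, Tuple


def pad_and_shard(
    tokens: List[int], cp_size: int, pad_id: int = 0
) -> Tuple[List[List[int]], int]:
    """Right-pad to a multiple of cp_size, then cut into contiguous rank chunks."""
    orig_len = len(tokens)
    padded_len = -(-orig_len // cp_size) * cp_size  # ceil to a multiple of cp_size
    padded = tokens + [pad_id] * (padded_len - orig_len)
    chunk = padded_len // cp_size
    shards = [padded[r * chunk:(r + 1) * chunk] for r in range(cp_size)]
    return shards, orig_len


# 7 tokens across 4 ranks: padded to 8, so each rank holds 2 positions.
shards, orig_len = pad_and_shard(list(range(10, 17)), cp_size=4)
print(shards, orig_len)
```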

**Parameters:**

* `cp_mesh`: The context-parallel device (sub)mesh.
* `input_ids`: `[B, T]` token ids.
* `position_ids`: Optional `[B, T]` (or `[1, T]`) position ids; an arange is injected when omitted.

**Returns:**

`(cp_ctx, sharded_input_ids, sharded_position_ids, orig_len)`. Enter