nemo_automodel.components.distributed.cp_utils


Module Contents

Functions

Name | Description
_build_position_ids | Add position_ids to the batch only if they are missing.
_shard_thd_chunk_for_te | -
attach_context_parallel_hooks | Attach forward pre-hooks to self_attn modules to fix attention masks for context parallelism.
attach_cp_sdpa_hooks | Inject CP-aware SDPA into self_attn modules for compile + CP>1 correctness.
create_context_parallel_ctx | Create a context parallel context.
gather_cp_seq | Gather context-parallel sharded tensors back to the full sequence.
get_train_context | Create a train context.
make_cp_batch_and_ctx | Build a CP context manager and shard a batch. If the input device_mesh is None or the size of the context_parallel submesh is 1, this function is effectively a no-op.
make_cp_batch_for_te | Build a CP batch for Transformer Engine using THD format.
make_target_cp_ctx | Build a context-parallel context for a frozen target forward.

API

nemo_automodel.components.distributed.cp_utils._build_position_ids(
batch,
device
)

Add position_ids to the batch only if they are missing.
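A minimal sketch of the "only if missing" rule, assuming the generated ids are 0..S-1 along the sequence dimension of each row. Plain Python lists stand in for tensors, the device argument is dropped, and the helper name is hypothetical.

```python
from typing import Any, Dict, List


def build_position_ids_sketch(batch: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical stand-in: add 0..S-1 position ids per row only if absent."""
    if "position_ids" not in batch:
        input_ids: List[List[int]] = batch["input_ids"]
        # One arange per row, matching that row's sequence length.
        batch["position_ids"] = [list(range(len(row))) for row in input_ids]
    return batch


# Missing key: position ids are generated.
generated = build_position_ids_sketch({"input_ids": [[11, 12, 13], [21, 22, 23]]})
print(generated["position_ids"])  # [[0, 1, 2], [0, 1, 2]]

# Existing key: left untouched (e.g. packed sequences that restart at 0).
kept = build_position_ids_sketch({"input_ids": [[1, 2, 3]], "position_ids": [[0, 1, 0]]})
print(kept["position_ids"])  # [[0, 1, 0]]
```

Leaving existing position ids alone matters for packed batches, where positions restart at each sequence boundary and a fresh arange would be wrong.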

nemo_automodel.components.distributed.cp_utils._shard_thd_chunk_for_te(
batch,
cp_mesh,
qkv_format,
seq_lens_padding_value,
padding_token_id
)

nemo_automodel.components.distributed.cp_utils.attach_context_parallel_hooks(
model: torch.nn.Module
)

Attach forward pre-hooks to self_attn modules to fix attention masks for context parallelism.

Context parallelism shards Q/K/V on the sequence dimension as DTensors, so explicit 4D attention masks would have mismatched shapes. This function registers a hook on every self_attn sub-module that strips the attention_mask kwarg and sets is_causal=True instead, letting SDPA handle causal masking internally.

Based on accelerate.big_modeling._attach_context_parallel_hooks.
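The hook body can be sketched as a plain function with the (module, args, kwargs) signature that PyTorch uses for forward pre-hooks registered with with_kwargs=True. This illustrates the described behavior rather than reproducing the module's code; it assumes the attention forward accepts is_causal as a keyword argument, and the function name is hypothetical.

```python
from typing import Any, Dict, Tuple


def strip_mask_pre_hook(
    module: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Drop the explicit 4D mask and ask SDPA for built-in causal masking."""
    kwargs = dict(kwargs)  # copy so the caller's dict is not mutated
    kwargs.pop("attention_mask", None)
    kwargs["is_causal"] = True
    # Returning (args, kwargs) replaces the forward's inputs.
    return args, kwargs


# With PyTorch, registration on each attention block would look like:
#   for name, sub in model.named_modules():
#       if name.endswith("self_attn"):
#           sub.register_forward_pre_hook(strip_mask_pre_hook, with_kwargs=True)
new_args, new_kwargs = strip_mask_pre_hook(
    None, ("hidden",), {"attention_mask": "4d-mask", "use_cache": False}
)
print(new_kwargs)  # {'use_cache': False, 'is_causal': True}
```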

nemo_automodel.components.distributed.cp_utils.attach_cp_sdpa_hooks(
model: torch.nn.Module,
cp_mesh
) -> None

Inject CP-aware SDPA into self_attn modules for compile + CP>1 correctness.

Problem: when per-layer torch.compile is active, Dynamo traces through the decoder layer including Q/K/V projections. At the F.scaled_dot_product_attention call site, Q/K/V are already local tensors (DTensor metadata was never propagated through the compiled graph). The DTensor SDPA dispatch — which triggers the CP allgather — never fires, so each rank silently attends only to its local sequence shard.

Fix: swap F.scaled_dot_product_attention with a @torch._dynamo.disable wrapper for the duration of each self_attn forward. Dynamo sees the disabled function and creates a graph break there, so:

  • Everything before (Q/K/V proj + RoPE) is compiled and fused.
  • The disabled wrapper runs eagerly: re-wraps local Q/K/V as DTensors with Shard(2) on the CP mesh so the DTensor SDPA dispatch fires the allgather.
  • Everything after (O proj + residual + MLP) is compiled and fused.

The sequence dimension at the SDPA call is 2: after the Hugging Face reshape, tensors have shape [B, nH, S/cp_size, D].
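The swap-for-the-duration pattern can be sketched without PyTorch. Here a SimpleNamespace stands in for torch.nn.functional and a stub replaces the eager DTensor re-wrap; all names are hypothetical, and the real wrapper is additionally decorated with @torch._dynamo.disable so that Dynamo breaks the graph at the call.

```python
import contextlib
import types
from typing import Any, Callable, Iterator


@contextlib.contextmanager
def swapped_attr(namespace: Any, name: str, replacement: Callable) -> Iterator[None]:
    """Temporarily replace namespace.<name>, restoring it even on error."""
    original = getattr(namespace, name)
    setattr(namespace, name, replacement)
    try:
        yield
    finally:
        setattr(namespace, name, original)


# Stand-in for torch.nn.functional with a "local shard only" SDPA.
F = types.SimpleNamespace(scaled_dot_product_attention=lambda q, k, v: "local-only")


def cp_aware_sdpa(q, k, v):
    # The real wrapper would re-wrap q/k/v as DTensors with Shard(2)
    # on the CP mesh and dispatch SDPA, which triggers the allgather.
    return "cp-aware"


def attn_forward(q, k, v):
    # Looks the function up at call time, so the swap is visible here.
    return F.scaled_dot_product_attention(q, k, v)


with swapped_attr(F, "scaled_dot_product_attention", cp_aware_sdpa):
    inside = attn_forward(1, 2, 3)
after = attn_forward(1, 2, 3)
print(inside, after)  # cp-aware local-only
```

The try/finally restores the original function even if the forward raises, so code outside the attention forward never sees the patched SDPA.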

nemo_automodel.components.distributed.cp_utils.create_context_parallel_ctx(
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
cp_buffers: typing.List[torch.Tensor],
cp_seq_dims: typing.List[int],
cp_no_restore_buffers: typing.Set[torch.Tensor],
cp_rotate_method: typing.Optional[str] = None
)

Create a context parallel context.

Parameters:

cp_mesh
DeviceMesh

The device mesh for context parallelism.

cp_buffers
List[torch.Tensor]

The tensors to shard along their sequence dimension across cp_mesh.

cp_seq_dims
List[int]

The sequence dimension of each tensor in cp_buffers, in the same order.

cp_no_restore_buffers
Set[torch.Tensor]

Buffers from cp_buffers that do not need to be restored to their original, unsharded contents when the context exits.

cp_rotate_method
str, defaults to None

The rotation method for context parallelism, such as “allgather” or “alltoall”.
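For intuition about what sharding a buffer along its sequence dimension gives each rank, the sketch below assumes the head-tail load-balancing scheme commonly used for causal ring attention: the sequence is cut into 2 * cp_size chunks and rank r keeps chunks r and 2 * cp_size - 1 - r. Whether a given torch version uses exactly this layout is an assumption, and the helper name is hypothetical.

```python
from typing import List


def head_tail_shard(seq: List[int], cp_size: int, rank: int) -> List[int]:
    """Tokens held by `rank` under an assumed head-tail load-balanced split."""
    num_chunks = 2 * cp_size
    if len(seq) % num_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * cp_size")
    chunk = len(seq) // num_chunks
    head = seq[rank * chunk:(rank + 1) * chunk]
    tail_idx = num_chunks - 1 - rank
    tail = seq[tail_idx * chunk:(tail_idx + 1) * chunk]
    # Pairing an early chunk with a late one evens out causal-attention work.
    return head + tail


tokens = list(range(8))
for r in range(2):
    print(r, head_tail_shard(tokens, cp_size=2, rank=r))
# 0 [0, 1, 6, 7]
# 1 [2, 3, 4, 5]
```

Under causal attention, later query tokens attend to more keys, so each rank receives a mix of cheap and expensive positions rather than one contiguous block.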

nemo_automodel.components.distributed.cp_utils.gather_cp_seq(
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
tensors: typing.List[torch.Tensor],
seq_dim: int,
orig_len: int
)

Gather context-parallel sharded tensors back to the full sequence.

Inverse of the sharding done by make_target_cp_ctx. Uses torch’s context_parallel_unshard with load_balancer=None (matching the sharding, which has load balancing disabled) and then slices the right padding back off.

Parameters:

cp_mesh
DeviceMesh

The context-parallel device (sub)mesh used to shard.

tensors
List[torch.Tensor]

Local-shard tensors (e.g. captured aux hidden states, logits), each holding T/cp_size positions along seq_dim, where T is the padded sequence length.

seq_dim
int

The sequence dimension to gather along.

orig_len
int

The pre-pad sequence length to slice back to.

Returns:

A list of full-sequence tensors of length orig_len along seq_dim.
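A single-process sketch of the gather, with lists standing in for the per-rank shards. It assumes the contiguous, non-load-balanced layout, so rank order is sequence order; in the real function the concatenation happens through a collective rather than a local loop, and the helper name is hypothetical.

```python
from typing import List, Sequence


def gather_contiguous_shards(shards: Sequence[List[int]], orig_len: int) -> List[int]:
    """Concatenate per-rank shards in rank order, then drop the right padding."""
    full: List[int] = []
    for shard in shards:  # shards[r] is rank r's contiguous chunk
        full.extend(shard)
    return full[:orig_len]


# 5 real tokens right-padded to 6 for cp_size=3 (pad id 0), then gathered.
shards = [[10, 11], [12, 13], [14, 0]]
gathered = gather_contiguous_shards(shards, orig_len=5)
print(gathered)  # [10, 11, 12, 13, 14]
```

With round-robin or head-tail sharding, the same gather would additionally need an un-permute step; the contiguous layout avoids it.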

nemo_automodel.components.distributed.cp_utils.get_train_context(
enable_loss_parallel: bool,
enable_compiled_autograd: bool,
cp_context = None
)

Create a train context.

Parameters:

enable_loss_parallel
bool

Whether to enable loss parallelism.

enable_compiled_autograd
bool

Whether to enable compiled autograd.
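One plausible way to compose the optional contexts is contextlib.ExitStack. The sketch assumes cp_context is entered innermost and uses hypothetical factories (loss_parallel_ctx, compiled_autograd_ctx) in place of PyTorch's loss-parallel and compiled-autograd context managers, so it runs without a GPU or process group.

```python
import contextlib
from typing import Callable, ContextManager, Iterator, List, Optional


def make_train_context_sketch(
    enable_loss_parallel: bool,
    enable_compiled_autograd: bool,
    cp_context: Optional[ContextManager] = None,
    *,
    loss_parallel_ctx: Callable[[], ContextManager] = contextlib.nullcontext,
    compiled_autograd_ctx: Callable[[], ContextManager] = contextlib.nullcontext,
) -> ContextManager:
    """Hypothetical: enter only the enabled contexts, exit them in reverse."""
    @contextlib.contextmanager
    def _ctx() -> Iterator[None]:
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                stack.enter_context(loss_parallel_ctx())
            if enable_compiled_autograd:
                stack.enter_context(compiled_autograd_ctx())
            if cp_context is not None:
                stack.enter_context(cp_context)
            yield
    return _ctx()


events: List[str] = []


@contextlib.contextmanager
def tracer(name: str) -> Iterator[None]:
    # Records entry and exit so the nesting order is visible.
    events.append(f"enter {name}")
    try:
        yield
    finally:
        events.append(f"exit {name}")


with make_train_context_sketch(
    True, False, tracer("cp"), loss_parallel_ctx=lambda: tracer("loss_parallel")
):
    events.append("forward/backward")
print(events)
# ['enter loss_parallel', 'enter cp', 'forward/backward', 'exit cp', 'exit loss_parallel']
```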

nemo_automodel.components.distributed.cp_utils.make_cp_batch_and_ctx(
device_mesh,
batch,
loss_mask = None,
use_te: bool = False,
padding_token_id: int = 0,
num_chunks: int = 1,
seq_lens_padding_value: int = -1000
)

Build a CP context manager and shard a batch. If the input device_mesh is None or the size of the context_parallel submesh is 1, this function is effectively a no-op.

Parameters:

device_mesh
DeviceMesh or None

The full device mesh; its context_parallel submesh is used for sharding.

batch
Dict[str, torch.Tensor]

The input batch, mapping field names to tensors.

Returns: (contextmanager, dict[str, torch.Tensor])

A tuple of (context manager, batch), where the batch is sharded along the sequence dimension when context parallelism is active.
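The documented no-op contract can be sketched as an early return. This hypothetical version takes the CP size as a plain integer (None standing for a missing mesh) and leaves out the real sharding path; the point is that callers can use the same with-block whether or not context parallelism is active.

```python
import contextlib
from typing import Any, ContextManager, Dict, Optional, Tuple


def make_cp_batch_and_ctx_sketch(
    cp_size: Optional[int], batch: Dict[str, Any]
) -> Tuple[ContextManager, Dict[str, Any]]:
    """Hypothetical no-op fast path: no mesh, or a CP submesh of size 1."""
    if cp_size is None or cp_size <= 1:
        return contextlib.nullcontext(), batch
    # The real function would shard `batch` and return a CP context here.
    raise NotImplementedError("sharding needs a DeviceMesh; outside this sketch")


batch = {"input_ids": [[1, 2, 3]]}
ctx, out = make_cp_batch_and_ctx_sketch(None, batch)
with ctx:
    # The training step runs unchanged on the very same batch object.
    print(out is batch)  # True
```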

nemo_automodel.components.distributed.cp_utils.make_cp_batch_for_te(
cp_mesh,
batch,
qkv_format = 'thd',
padding_token_id: int = 0,
num_chunks: int = 1,
seq_lens_padding_value: int = -1000
)

Build a CP batch for Transformer Engine using THD format.

This function converts BSHD format batches to THD format and shards them across context parallel ranks for use with Transformer Engine. It processes the batch in chunks if num_chunks > 1, allowing for better memory efficiency with large sequences.

The function performs three main steps:

  1. Converts BSHD format to THD format using split_batch_into_thd_chunks
  2. Optionally splits the batch into multiple chunks for memory efficiency
  3. Shards each chunk across CP ranks using Transformer Engine’s partitioning
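Step 3 can be illustrated at the index level. The sketch assumes Transformer Engine's load-balanced THD split, in which every packed sequence (padded to a multiple of 2 * cp_size) is cut into 2 * cp_size chunks and rank r keeps chunks r and 2 * cp_size - 1 - r of each sequence. The helper name is hypothetical and TE's own routine is not called.

```python
from typing import List


def thd_partition_indices(cu_seqlens_padded: List[int], cp_size: int, rank: int) -> List[int]:
    """Token indices (into the packed THD stream) kept by `rank`."""
    num_chunks = 2 * cp_size
    kept: List[int] = []
    for start, end in zip(cu_seqlens_padded[:-1], cu_seqlens_padded[1:]):
        length = end - start
        if length % num_chunks != 0:
            raise ValueError("padded sequence length must be divisible by 2 * cp_size")
        chunk = length // num_chunks
        # Each sequence is split on its own; rank r takes a head and a tail chunk.
        for c in (rank, num_chunks - 1 - rank):
            kept.extend(range(start + c * chunk, start + (c + 1) * chunk))
    return kept


# Two packed sequences padded to 4 and 8 tokens, split over cp_size=2 ranks.
cu = [0, 4, 12]
rank0 = thd_partition_indices(cu, cp_size=2, rank=0)
rank1 = thd_partition_indices(cu, cp_size=2, rank=1)
print(rank0)  # [0, 3, 4, 5, 10, 11]
print(rank1)  # [1, 2, 6, 7, 8, 9]
```

Splitting per sequence, rather than across the whole packed stream, keeps every rank's share of each sequence balanced even when sequence lengths differ widely.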

Parameters:

cp_mesh
DeviceMesh or None

The device mesh for context parallel. If None or size <= 1, returns the batch in THD format without sharding.

batch
Dict[str, torch.Tensor]

The input batch in BSHD format containing:

  • input_ids: Input token IDs [batch_size, seq_len] or [batch_size, seq_len, hidden_dim]
  • labels: Label token IDs [batch_size, seq_len]
  • position_ids (optional): Position IDs [batch_size, seq_len]
  • seq_lens: Actual sequence lengths [batch_size, num_packs]
  • seq_lens_padded: Padded sequence lengths [batch_size, num_packs]

qkv_format
str, defaults to 'thd'

Format for QKV tensors. Currently only “thd” is supported.

padding_token_id
int, defaults to 0

Token ID used for padding in input_ids.

num_chunks
int, defaults to 1

Number of chunks to split the batch into. If > 1, the batch dimension is split and each chunk is processed separately.

seq_lens_padding_value
int, defaults to -1000

Sentinel value used to mark padding in the seq_lens and seq_lens_padded tensors.

Returns:

Processed batch in THD format with the following keys:

  • input_ids: Sharded input token IDs [total_tokens] or [num_chunks, chunk_tokens]
  • labels: Sharded labels [total_tokens] or [num_chunks, chunk_tokens]
  • position_ids: Generated and sharded position IDs [total_tokens] or [num_chunks, chunk_tokens]
  • cu_seqlens: Cumulative sequence lengths [num_seqs+1] or [num_chunks, max_seqs+1]
  • cu_seqlens_padded: Cumulative padded sequence lengths [num_seqs+1] or [num_chunks, max_seqs+1]
  • max_seqlen: Maximum sequence length (int32 tensor)
  • qkv_format: Format string (“thd”)
  • padding_mask: Boolean mask indicating padding tokens
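The cumulative-length keys can be sketched from the per-row seq_lens and seq_lens_padded fields. The example assumes the sentinel marks unused pack slots and that each row's padded lengths sum to the row length S, so the padded offsets index into the flattened [B * S] token stream; the helper name is hypothetical.

```python
from itertools import accumulate
from typing import List

PAD = -1000  # matches the default seq_lens_padding_value


def cu_seqlens_from_rows(seq_lens_rows: List[List[int]], pad: int = PAD) -> List[int]:
    """Flatten per-row pack lengths into THD cumulative offsets, skipping sentinels."""
    lengths = [n for row in seq_lens_rows for n in row if n != pad]
    return [0] + list(accumulate(lengths))


# B=2 rows of S=10 tokens: row 0 packs two sequences, row 1 packs one.
seq_lens = [[3, 5, PAD], [7, PAD, PAD]]
seq_lens_padded = [[4, 6, PAD], [10, PAD, PAD]]
cu_seqlens = cu_seqlens_from_rows(seq_lens)
cu_seqlens_padded = cu_seqlens_from_rows(seq_lens_padded)
print(cu_seqlens)         # [0, 3, 8, 15]
print(cu_seqlens_padded)  # [0, 4, 10, 20]
```

The last padded offset equals B * S (here 20), which is a quick consistency check that every token of the flattened batch belongs to exactly one padded sequence.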

Raises:

  • ValueError: If qkv_format is not “thd”
  • KeyError: If required fields (seq_lens, seq_lens_padded) are missing from batch

nemo_automodel.components.distributed.cp_utils.make_target_cp_ctx(
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
input_ids,
position_ids = None
)

Build a context-parallel context for a frozen target forward.

Shards input_ids (and position_ids) along the sequence dim across cp_mesh so the target’s self-attention runs as ring attention. Unlike make_cp_batch_and_ctx, this does not require labels and is meant for the EAGLE-3 target wrapper, which gathers the aux hidden states and logits back to the full sequence (see gather_cp_seq) before handing them to the draft.

Load balancing is disabled (_cp_options.enable_load_balance = False), so each rank holds a contiguous sequence chunk and the gather is a plain ordered concatenation (no round-robin un-permute). The sharding is discarded right after the forward, so load balancing buys nothing here, and the ordered shard makes the gather deterministic. Because this is a process-global torch flag, it relies on the EAGLE-3 recipe being the only context-parallel user in its process.

The sequence is right-padded to a multiple of cp_size; the returned orig_len lets the caller slice the gathered outputs back down.
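A list-based sketch of the right-padding and contiguous split, assuming a pad id of 0 and a hypothetical helper name; lists stand in for the [B, T] tensors along the sequence dimension.

```python
from typing import List, Tuple


def pad_and_shard_contiguous(
    ids: List[int], cp_size: int, rank: int, pad_id: int = 0
) -> Tuple[List[int], int]:
    """Right-pad to a multiple of cp_size; return (rank's chunk, orig_len)."""
    orig_len = len(ids)
    padded_len = -(-orig_len // cp_size) * cp_size  # ceil to a multiple of cp_size
    padded = ids + [pad_id] * (padded_len - orig_len)
    chunk = padded_len // cp_size
    return padded[rank * chunk:(rank + 1) * chunk], orig_len


ids = [10, 11, 12, 13, 14]
shards = [pad_and_shard_contiguous(ids, cp_size=2, rank=r)[0] for r in range(2)]
orig_len = pad_and_shard_contiguous(ids, cp_size=2, rank=0)[1]
print(shards)    # [[10, 11, 12], [13, 14, 0]]
print(orig_len)  # 5
```

Only the last rank ever carries padding, and orig_len is all the caller needs to trim it after the gather.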

Parameters:

cp_mesh
DeviceMesh

The context-parallel device (sub)mesh.

input_ids

[B, T] token ids.

position_ids
Defaults to None

Optional [B, T] (or [1, T]) position ids; an arange is injected when omitted.

Returns:

(cp_ctx, sharded_input_ids, sharded_position_ids, orig_len). Enter