nemo_automodel.components.distributed.cp_utils#

Module Contents#

Functions#

_build_position_ids

Add position_ids to the batch only if they are missing.

get_train_context

Create a train context.

create_context_parallel_ctx

Create a context parallel context.

make_cp_batch_and_ctx

Build a CP context manager and shard a batch. If the input device_mesh is None or the size of the context_parallel submesh is 1, this function is effectively a no-op.

make_cp_batch_for_te

Build a CP batch for Transformer Engine using THD format.

_shard_thd_chunk_for_te

API#

nemo_automodel.components.distributed.cp_utils._build_position_ids(batch, device)#

Add position_ids to the batch only if they are missing.
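
The intended behavior can be sketched as follows; this is a hypothetical re-implementation for illustration, not the module source:

import torch

batch = {"input_ids": torch.tensor([[1, 2, 3, 4]])}
device = "cpu"

# position_ids are generated only when the batch does not already provide them
if "position_ids" not in batch:
    seq_len = batch["input_ids"].shape[1]
    # 0 .. seq_len - 1, broadcast over the batch dimension
    batch["position_ids"] = torch.arange(seq_len, device=device).unsqueeze(0)

print(batch["position_ids"])  # tensor([[0, 1, 2, 3]])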

nemo_automodel.components.distributed.cp_utils.get_train_context(
enable_loss_parallel: bool,
enable_compiled_autograd: bool,
cp_context=None,
)#

Create a train context.

Parameters:
  • enable_loss_parallel (bool) – Whether to enable loss parallelism.

  • enable_compiled_autograd (bool) – Whether to enable compiled autograd.
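
A minimal usage sketch, assuming the returned object is entered directly as a context manager around the forward/backward pass; model and batch are placeholders, and the optional cp_context argument would come from create_context_parallel_ctx:

from nemo_automodel.components.distributed.cp_utils import get_train_context

train_ctx = get_train_context(
    enable_loss_parallel=False,
    enable_compiled_autograd=False,
)

# Assumption: the return value behaves as a context manager.
with train_ctx:
    loss = model(**batch).loss  # `model` and `batch` are assumed to exist
    loss.backward()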

nemo_automodel.components.distributed.cp_utils.create_context_parallel_ctx(
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
cp_buffers: List[torch.Tensor],
cp_seq_dims: List[int],
cp_no_restore_buffers: Set[torch.Tensor],
cp_rotate_method: Optional[str] = None,
)#

Create a context parallel context.

Parameters:
  • cp_mesh (DeviceMesh) – The device mesh for context parallel.

  • cp_buffers (List[torch.Tensor]) – The buffers for context parallel.

  • cp_seq_dims (List[int]) – The sequence dimensions for context parallel.

  • cp_no_restore_buffers (Set[torch.Tensor]) – The no restore buffers for context parallel.

  • cp_rotate_method (str) – The rotation method for context parallel, such as “allgather” or “alltoall”.
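
A hedged sketch of typical arguments, assuming a device_mesh with a “context_parallel” submesh and a batch dict whose tensors are laid out as [batch, seq]:

cp_mesh = device_mesh["context_parallel"]

cp_ctx = create_context_parallel_ctx(
    cp_mesh=cp_mesh,
    cp_buffers=[batch["input_ids"], batch["labels"], batch["position_ids"]],
    cp_seq_dims=[1, 1, 1],  # sequence dimension of each buffer above
    cp_no_restore_buffers={batch["input_ids"], batch["labels"]},
    cp_rotate_method="allgather",  # or "alltoall"
)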

nemo_automodel.components.distributed.cp_utils.make_cp_batch_and_ctx(
device_mesh,
batch,
loss_mask=None,
use_te: bool = False,
padding_token_id: int = 0,
num_chunks: int = 1,
seq_lens_padding_value: int = -1000,
)#

Build a CP context manager and shard a batch. If the input device_mesh is None or the size of the context_parallel submesh is 1, this function is effectively a no-op.

Parameters:
  • device_mesh (DeviceMesh) – The device mesh containing the context_parallel submesh.

  • batch (Dict[str, torch.Tensor]) – The input batch containing (string, torch.Tensor) pairs.

Returns:

Returns a tuple with a context manager and a new batch. The context manager is either nullcontext (no CP) or the CP context manager returned by create_context_parallel_ctx. The batch is likewise passed through create_context_parallel_ctx and sharded accordingly.

Return type:

tuple (contextmanager, dict[str, torch.Tensor])
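
A minimal usage sketch; the exact way the returned context is entered may differ in the training recipes:

# `device_mesh`, `model`, and `batch` are assumed to exist.
cp_ctx, batch = make_cp_batch_and_ctx(device_mesh, batch)

# With no CP mesh this is effectively a no-op context and the batch is unchanged.
with cp_ctx:
    loss = model(**batch).loss
    loss.backward()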

nemo_automodel.components.distributed.cp_utils.make_cp_batch_for_te(
cp_mesh,
batch,
qkv_format='thd',
padding_token_id: int = 0,
num_chunks: int = 1,
seq_lens_padding_value: int = -1000,
)#

Build a CP batch for Transformer Engine using THD format.

This function converts BSHD format batches to THD format and shards them across context parallel ranks for use with Transformer Engine. It processes the batch in chunks if num_chunks > 1, allowing for better memory efficiency with large sequences.

The function performs three main steps:

  1. Converts BSHD format to THD format using split_batch_into_thd_chunks

  2. Optionally splits the batch into multiple chunks for memory efficiency

  3. Shards each chunk across CP ranks using Transformer Engine’s partitioning
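
The BSHD-to-THD packing in step 1 can be illustrated conceptually (a simplified sketch, not the actual split_batch_into_thd_chunks implementation):

import torch

input_ids = torch.tensor([[1, 2, 3, 0],   # sequence 0: length 3 plus one pad token
                          [4, 5, 6, 7]])  # sequence 1: length 4
seq_lens = torch.tensor([3, 4])

# THD: concatenate the valid tokens of every sequence into one token dimension
thd_input_ids = torch.cat([row[:n] for row, n in zip(input_ids, seq_lens)])

# Cumulative sequence lengths delimit each sequence inside the packed tensor
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)])

print(thd_input_ids)  # tensor([1, 2, 3, 4, 5, 6, 7])
print(cu_seqlens)     # tensor([0, 3, 7])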

Parameters:
  • cp_mesh (DeviceMesh or None) – The device mesh for context parallel. If None or size <= 1, returns the batch in THD format without sharding.

  • batch (Dict[str, torch.Tensor]) –

    The input batch in BSHD format containing:

    • input_ids: Input token IDs [batch_size, seq_len] or [batch_size, seq_len, hidden_dim]

    • labels: Label token IDs [batch_size, seq_len]

    • position_ids (optional): Position IDs [batch_size, seq_len]

    • seq_lens: Actual sequence lengths [batch_size, num_packs]

    • seq_lens_padded: Padded sequence lengths [batch_size, num_packs]

  • qkv_format (str) – Format for QKV tensors. Currently only “thd” is supported.

  • padding_token_id (int) – Token ID used for padding in input_ids (default: 0)

  • num_chunks (int) – Number of chunks to split the batch into. If > 1, the batch dimension is split and each chunk is processed separately (default: 1)

  • seq_lens_padding_value (int) – Sentinel value used to indicate padding in seq_lens/seq_lens_padded tensors (default: -1000)

Returns:

Processed batch in THD format with the following keys:

  • input_ids: Sharded input token IDs [total_tokens] or [num_chunks, chunk_tokens]

  • labels: Sharded labels [total_tokens] or [num_chunks, chunk_tokens]

  • position_ids: Generated and sharded position IDs [total_tokens] or [num_chunks, chunk_tokens]

  • cu_seqlens: Cumulative sequence lengths [num_seqs+1] or [num_chunks, max_seqs+1]

  • cu_seqlens_padded: Cumulative padded sequence lengths [num_seqs+1] or [num_chunks, max_seqs+1]

  • max_seqlen: Maximum sequence length (int32 tensor)

  • qkv_format: Format string (“thd”)

  • padding_mask: Boolean mask indicating padding tokens

Return type:

dict

Raises:
  • ValueError – If qkv_format is not “thd”

  • KeyError – If required fields (seq_lens, seq_lens_padded) are missing from batch

Example

Single chunk, no CP:

>>> batch = {
...     'input_ids': torch.tensor([[1, 2, 3, 4]]),
...     'labels': torch.tensor([[2, 3, 4, 5]]),
...     'seq_lens': torch.tensor([[4]]),
...     'seq_lens_padded': torch.tensor([[4]])
... }
>>> result = make_cp_batch_for_te(None, batch)
>>> result['input_ids'].shape  # [4] in THD format
torch.Size([4])

Multiple chunks with CP:

>>> batch = {
...     'input_ids': torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]),
...     'labels': torch.tensor([[2, 3, 4, 5], [6, 7, 8, 9]]),
...     'seq_lens': torch.tensor([[4], [4]]),
...     'seq_lens_padded': torch.tensor([[4], [4]])
... }
>>> result = make_cp_batch_for_te(cp_mesh, batch, num_chunks=2)
>>> result['input_ids'].shape  # [2, chunk_tokens] - 2 chunks
torch.Size([2, 2])  # Example: 2 chunks, 2 tokens each after sharding

nemo_automodel.components.distributed.cp_utils._shard_thd_chunk_for_te(
batch,
cp_mesh,
qkv_format,
seq_lens_padding_value,
padding_token_id,
)#