nemo_automodel.distributed.cp_utils#

Module Contents#

Functions#

  • _build_position_ids – Add position_ids to the batch only if they are missing.

  • get_train_context – Create a train context.

  • create_context_parallel_ctx – Create a context parallel context.

  • make_cp_batch_and_ctx – Build a CP context manager and shard a batch; effectively a no-op when device_mesh is None or the context_parallel submesh has size 1.

API#

nemo_automodel.distributed.cp_utils._build_position_ids(batch, device)[source]#

Add position_ids to the batch only if they are missing.
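
The documented behavior can be illustrated with a minimal sketch. This is not the library source; the (batch_size, seq_len) layout of input_ids and the broadcastable (1, seq_len) shape of position_ids are assumptions:

```python
import torch

# Minimal sketch of the documented behavior (not the library source).
# Assumes batch["input_ids"] has shape (batch_size, seq_len).
def build_position_ids_sketch(batch: dict, device: torch.device) -> dict:
    if "position_ids" not in batch:
        seq_len = batch["input_ids"].shape[1]
        # One row [0, 1, ..., seq_len - 1], broadcastable across the batch.
        batch["position_ids"] = torch.arange(seq_len, device=device).unsqueeze(0)
    return batch
```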

nemo_automodel.distributed.cp_utils.get_train_context(
enable_loss_parallel: bool,
enable_compiled_autograd: bool,
cp_context=None,
)[source]#

Create a train context.

Parameters:
  • enable_loss_parallel (bool) – Whether to enable loss parallelism.

  • enable_compiled_autograd (bool) – Whether to enable compiled autograd.

  • cp_context (optional) – An optional context parallel context (e.g., one returned by create_context_parallel_ctx) to enter as part of the train context.
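
A hypothetical usage sketch follows. The model and batch objects are assumed to exist, and treating the return value as a factory that produces a fresh context manager per step is an assumption, not something this page states:

```python
from nemo_automodel.distributed.cp_utils import get_train_context

# Hypothetical usage; `model` and `batch` are assumed to exist.
train_context = get_train_context(
    enable_loss_parallel=False,
    enable_compiled_autograd=False,
    cp_context=None,  # optionally, a context from create_context_parallel_ctx
)

# Assumption: the return value is callable and yields a context manager.
with train_context():
    loss = model(**batch).loss
    loss.backward()
```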

nemo_automodel.distributed.cp_utils.create_context_parallel_ctx(
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
cp_buffers: List[torch.Tensor],
cp_seq_dims: List[int],
cp_no_restore_buffers: Set[torch.Tensor],
cp_rotate_method: str,
)[source]#

Create a context parallel context.

Parameters:
  • cp_mesh (DeviceMesh) – The device mesh for context parallel.

  • cp_buffers (List[torch.Tensor]) – The tensors to shard along their sequence dimensions for context parallel.

  • cp_seq_dims (List[int]) – The sequence dimension of each buffer in cp_buffers.

  • cp_no_restore_buffers (Set[torch.Tensor]) – Buffers whose original (unsharded) contents do not need to be restored when the context exits.

  • cp_rotate_method (str) – The rotation method for context parallel, such as “allgather” or “alltoall”.
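
A hypothetical call is sketched below; the "context_parallel" submesh name, the choice of buffers, and their sequence dimensions are assumptions rather than requirements stated on this page:

```python
from nemo_automodel.distributed.cp_utils import create_context_parallel_ctx

# Hypothetical usage; `device_mesh`, `batch`, and `labels` are assumed to exist.
cp_ctx = create_context_parallel_ctx(
    cp_mesh=device_mesh["context_parallel"],  # assumed submesh name
    cp_buffers=[batch["input_ids"], batch["position_ids"], labels],
    cp_seq_dims=[1, 1, 1],  # sequence dimension of each buffer above
    cp_no_restore_buffers={batch["input_ids"], labels},
    cp_rotate_method="allgather",  # or "alltoall"
)
```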

nemo_automodel.distributed.cp_utils.make_cp_batch_and_ctx(device_mesh, batch, labels, loss_mask)[source]#

Build a CP context manager and shard a batch. If the input device_mesh is None, or if the size of the context_parallel submesh is 1, this function is effectively a no-op.

Parameters:
  • device_mesh (DeviceMesh) – The device mesh containing the context parallel submesh.

  • batch (Dict[str, torch.Tensor]) – The input batch of (string, torch.Tensor) pairs.

  • labels (torch.Tensor) – The labels associated with the batch.

  • loss_mask (torch.Tensor) – The loss mask associated with the batch.

Returns:

Returns a tuple of a context manager and a new batch. The context manager is either nullcontext (no CP) or the CP context manager returned by create_context_parallel_ctx. The batch is likewise passed through create_context_parallel_ctx and sharded accordingly.

Return type:

tuple (contextmanager, dict[str, torch.Tensor])
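
Putting the pieces together, a hypothetical training-step sketch; `model`, `device_mesh`, and the batch tensors are assumed to exist, and entering the returned context directly follows the documented return type above:

```python
from nemo_automodel.distributed.cp_utils import make_cp_batch_and_ctx

# Hypothetical training step; `model`, `device_mesh`, `batch`, `labels`,
# and `loss_mask` are assumed to exist.
cp_ctx, batch = make_cp_batch_and_ctx(device_mesh, batch, labels, loss_mask)

# Per the documented return type, the first element is a context manager
# (nullcontext when CP is disabled), so it is entered directly here.
with cp_ctx:
    loss = model(**batch).loss
    loss.backward()
```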