nemo_automodel.components.distributed.cp_utils#
Module Contents#
Functions#
- _build_position_ids – Add position_ids to the batch only if they are missing.
- get_train_context – Create a train context.
- create_context_parallel_ctx – Create a context parallel context.
- make_cp_batch_and_ctx – Build a CP context manager and shard a batch. If the input device_mesh is None or the size of the context_parallel submesh is 1, this function is effectively a no-op.
API#
- nemo_automodel.components.distributed.cp_utils._build_position_ids(batch, device)[source]#
Add position_ids to the batch only if they are missing.
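For illustration, a minimal sketch of the behavior described above, assuming the batch is a dict of tensors keyed by input_ids (the actual implementation may differ):

```python
import torch

def build_position_ids_sketch(batch: dict, device: torch.device) -> dict:
    # Only add position_ids when the batch does not already carry them.
    if "position_ids" not in batch:
        seq_len = batch["input_ids"].shape[1]
        # One row [0, 1, ..., seq_len - 1], broadcastable across the batch dim.
        batch["position_ids"] = torch.arange(seq_len, device=device).unsqueeze(0)
    return batch
```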
- nemo_automodel.components.distributed.cp_utils.get_train_context(
- enable_loss_parallel: bool,
- enable_compiled_autograd: bool,
- cp_context=None,
- )[source]#
Create a train context.
- Parameters:
enable_loss_parallel (bool) – Whether to enable loss parallelism.
enable_compiled_autograd (bool) – Whether to enable compiled autograd.
cp_context (optional) – An optional context parallel context to enter within the train context.
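A hedged usage sketch. Whether get_train_context returns the combined context manager directly, as assumed here, should be checked against the source:

```python
from nemo_automodel.components.distributed.cp_utils import get_train_context

# Assumption: the returned object is a context manager combining the
# requested features (loss parallelism, compiled autograd, optional CP).
train_ctx = get_train_context(
    enable_loss_parallel=True,       # wrap loss computation in loss-parallel mode
    enable_compiled_autograd=False,  # leave compiled autograd off
    cp_context=None,                 # or a context from create_context_parallel_ctx
)
with train_ctx:
    ...  # forward pass, loss computation, and backward pass go here
```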
- nemo_automodel.components.distributed.cp_utils.create_context_parallel_ctx(
- cp_mesh: torch.distributed.device_mesh.DeviceMesh,
- cp_buffers: List[torch.Tensor],
- cp_seq_dims: List[int],
- cp_no_restore_buffers: Set[torch.Tensor],
- cp_rotate_method: Optional[str] = None,
- )[source]#
Create a context parallel context.
- Parameters:
cp_mesh (DeviceMesh) – The device mesh for context parallel.
cp_buffers (List[torch.Tensor]) – The buffers for context parallel.
cp_seq_dims (List[int]) – The sequence dimensions for context parallel.
cp_no_restore_buffers (Set[torch.Tensor]) – The no restore buffers for context parallel.
cp_rotate_method (Optional[str]) – The rotation method for context parallel, such as “allgather” or “alltoall”.
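A sketch of how these arguments fit together, assuming a one-dimensional device mesh named “context_parallel” and buffers whose sequence dimension is dim 1 (the mesh name, shapes, and buffer choices here are illustrative):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from nemo_automodel.components.distributed.cp_utils import create_context_parallel_ctx

# Assumption: distributed is initialized and one CP group spans all local GPUs.
device_mesh = init_device_mesh(
    "cuda", (torch.cuda.device_count(),), mesh_dim_names=("context_parallel",)
)
input_ids = torch.randint(0, 100, (1, 4096), device="cuda")
position_ids = torch.arange(4096, device="cuda").unsqueeze(0)
labels = torch.randint(0, 100, (1, 4096), device="cuda")

cp_ctx = create_context_parallel_ctx(
    cp_mesh=device_mesh["context_parallel"],
    cp_buffers=[input_ids, position_ids, labels],
    cp_seq_dims=[1, 1, 1],                      # sequence dim of each buffer, in order
    cp_no_restore_buffers={input_ids, labels},  # buffers left sharded on exit
    cp_rotate_method="allgather",
)
```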
- nemo_automodel.components.distributed.cp_utils.make_cp_batch_and_ctx(device_mesh, batch, labels, loss_mask)[source]#
Build a CP context manager and shard a batch. If the input device_mesh is None or the size of the context_parallel submesh is 1, this function is effectively a no-op.
- Parameters:
device_mesh (DeviceMesh) – The device mesh containing the context_parallel submesh.
batch (Dict[str, torch.Tensor]) – The input batch containing (string, torch.Tensor) pairs.
labels (torch.Tensor) – The labels tensor associated with the batch.
loss_mask (torch.Tensor) – The loss mask for the batch.
- Returns:
A tuple of (context manager, new batch). The context manager is either nullcontext (no CP) or the CP context manager returned by create_context_parallel_ctx. The batch has likewise been passed through create_context_parallel_ctx and is sharded accordingly.
- Return type:
tuple (contextmanager, dict[str, torch.Tensor])
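Putting it together, a hedged end-to-end sketch; device_mesh, batch, labels, loss_mask, and model are assumed to be set up by the surrounding training loop:

```python
from nemo_automodel.components.distributed.cp_utils import make_cp_batch_and_ctx

ctx, cp_batch = make_cp_batch_and_ctx(device_mesh, batch, labels, loss_mask)
with ctx:  # nullcontext when device_mesh is None or the CP submesh has size 1
    outputs = model(**cp_batch)
    ...  # compute the loss over the sharded sequence and call backward()
```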