nemo_automodel.distributed.cp_utils#
Module Contents#
Functions#
| Function | Description |
|---|---|
| _build_position_ids | Add position_ids to the batch only if they are missing. |
| get_train_context | Create a train context. |
| create_context_parallel_ctx | Create a context parallel context. |
| make_cp_batch_and_ctx | Build a CP context manager and shard a batch. If the input device_mesh is None or the size of the context_parallel submesh is 1, this function is effectively a no-op. |
API#
- nemo_automodel.distributed.cp_utils._build_position_ids(batch, device)[source]#
Add position_ids to the batch only if they are missing.
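A minimal sketch of the idea, not the actual implementation: the shape and dtype of the generated position_ids are assumptions here.

```python
import torch

# Hypothetical illustration: derive position_ids from the sequence length of
# input_ids when the batch does not already carry them.
batch = {"input_ids": torch.randint(0, 1000, (2, 16))}
device = batch["input_ids"].device

if "position_ids" not in batch:
    batch_size, seq_len = batch["input_ids"].shape
    batch["position_ids"] = (
        torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
    )
```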
- nemo_automodel.distributed.cp_utils.get_train_context(
- enable_loss_parallel: bool,
- enable_compiled_autograd: bool,
- cp_context=None,
- )[source]#
Create a train context.
- Parameters:
enable_loss_parallel (bool) – Whether to enable loss parallelism.
enable_compiled_autograd (bool) – Whether to enable compiled autograd.
cp_context (optional) – The context parallel context to apply inside the train context, if any.
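A hedged usage sketch: it assumes the returned object can be entered directly as a context manager for a training step, and that cp_context accepts the context produced for context parallelism (both assumptions, not documented above).

```python
from nemo_automodel.distributed.cp_utils import get_train_context

# Assumption: the return value is a context manager wrapping one training step.
train_ctx = get_train_context(
    enable_loss_parallel=True,
    enable_compiled_autograd=False,
    cp_context=None,  # optionally a context from create_context_parallel_ctx
)

with train_ctx:
    # forward / backward for one training step would run here
    pass
```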
- nemo_automodel.distributed.cp_utils.create_context_parallel_ctx(
- cp_mesh: torch.distributed.device_mesh.DeviceMesh,
- cp_buffers: List[torch.Tensor],
- cp_seq_dims: List[int],
- cp_no_restore_buffers: Set[torch.Tensor],
- cp_rotate_method: str,
- )[source]#
Create a context parallel context.
- Parameters:
cp_mesh (DeviceMesh) – The device mesh for context parallel.
cp_buffers (List[torch.Tensor]) – The buffers for context parallel.
cp_seq_dims (List[int]) – The sequence dimensions for context parallel.
cp_no_restore_buffers (Set[torch.Tensor]) – The buffers that are not restored after the context exits.
cp_rotate_method (str) – The rotation method for context parallel, such as "allgather" or "alltoall".
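An illustrative sketch of assembling the arguments. The submesh name "context_parallel", the buffer keys (input_ids, labels, position_ids), and the sequence dimension of 1 are assumptions about a typical [batch, seq] layout; device_mesh and batch are assumed to come from the surrounding training setup.

```python
from nemo_automodel.distributed.cp_utils import create_context_parallel_ctx

# Assumed layout: each buffer is [batch, seq], so the shard dim is 1.
cp_ctx = create_context_parallel_ctx(
    cp_mesh=device_mesh["context_parallel"],
    cp_buffers=[batch["input_ids"], batch["labels"], batch["position_ids"]],
    cp_seq_dims=[1, 1, 1],
    cp_no_restore_buffers={batch["input_ids"], batch["labels"]},
    cp_rotate_method="allgather",
)
```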
- nemo_automodel.distributed.cp_utils.make_cp_batch_and_ctx(device_mesh, batch, labels, loss_mask)[source]#
Build a CP context manager and shard a batch. If the input device_mesh is None or the size of the context_parallel submesh is 1, this function is effectively a no-op.
- Parameters:
device_mesh (DeviceMesh) – The device mesh containing the context parallel submesh.
batch (Dict[str, torch.Tensor]) – The input batch of (str, torch.Tensor) items.
- Returns:
A tuple with a context manager and a new batch. The context manager is either nullcontext (no CP) or the CP context manager returned by create_context_parallel_ctx. The batch has also been passed to create_context_parallel_ctx and is sharded accordingly.
- Return type:
tuple (contextmanager, dict[str, torch.Tensor])
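A hedged end-to-end sketch of the call pattern. device_mesh, batch, labels, loss_mask, and model are assumed to come from the surrounding training loop; entering the first return value directly with `with` follows the return description above.

```python
from nemo_automodel.distributed.cp_utils import make_cp_batch_and_ctx

# No-op path: if device_mesh is None or the context_parallel submesh has size 1,
# cp_ctx is a nullcontext and the batch is returned unchanged.
cp_ctx, batch = make_cp_batch_and_ctx(device_mesh, batch, labels, loss_mask)

with cp_ctx:
    outputs = model(**batch)
    # loss computation / backward would follow here
```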