nemo_automodel.distributed.parallelizer#

Module Contents#

Functions#

fsdp2_strategy_parallelize

Apply parallelisms and activation checkpointing to the model.

import_classes_from_paths

Helper function to import classes from string paths.

nvfsdp_strategy_parallelize

Apply tensor/data parallelism (nvFSDP) and optional activation-checkpointing to the model.

get_hf_tp_shard_plan

Get the tensor parallel sharding plan from the model.

translate_to_torch_parallel_style

Translates string descriptions to parallelism plans.

to_cpu

Move a tensor or distributed tensor to the CPU.

_destroy_dist_connection

Destroy process group.

Data#

API#

nemo_automodel.distributed.parallelizer.HAVE_NVFSDP#

False

nemo_automodel.distributed.parallelizer.fsdp2_strategy_parallelize(
model,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: torch.distributed.fsdp.MixedPrecisionPolicy = None,
tp_shard_plan: Optional[Dict[str, Union[torch.distributed.tensor.parallel.RowwiseParallel, torch.distributed.tensor.parallel.ColwiseParallel, torch.distributed.tensor.parallel.SequenceParallel]]] = None,
offload_policy: torch.distributed.fsdp.CPUOffloadPolicy = None,
)[source]#

Apply parallelisms and activation checkpointing to the model.

Parameters:
  • model – The model to be parallelized.

  • device_mesh (DeviceMesh) – The device mesh for distributed training.

  • mp_policy (MixedPrecisionPolicy) – Mixed precision policy for model parallelism.

  • tp_shard_plan (Optional[Dict[str, Union[RowwiseParallel, ColwiseParallel, SequenceParallel]]]) – A tensor parallel sharding plan. The keys should be the module names and the values should be the corresponding parallel styles (e.g., RowwiseParallel, ColwiseParallel, SequenceParallel).

  • offload_policy (CPUOffloadPolicy) – The offload policy for FSDP. If None, it will use the default policy.

NOTE: The passed-in model should preferably be on the meta device; otherwise, the model must fit in GPU or CPU memory.

NOTE: Currently, the user is required to manually handle precision settings such as mp_policy here, because the model parallel strategy does not yet respect all settings of Fabric(precision=...).

NOTE: Currently, the user should make sure that the provided tp_shard_plan is compatible with the model architecture.
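A minimal usage sketch is shown below. The mesh layout, mesh dimension names, the toy model, and the shard-plan keys are illustrative assumptions rather than values required by the API; in practice the plan keys must match the fully qualified module names of your model.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

from nemo_automodel.distributed.parallelizer import fsdp2_strategy_parallelize

# Assumed: launched with torchrun on a single 8-GPU node, split into
# 2-way data parallel x 4-way tensor parallel. The dimension names are illustrative.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("data_parallel", "tensor_parallel"))

# Build the model on the meta device, as recommended in the note above.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# Hypothetical shard plan: the keys are the submodule names of the Sequential above.
tp_shard_plan = {
    "0": ColwiseParallel(),  # shard the first Linear column-wise
    "2": RowwiseParallel(),  # shard the second Linear row-wise
}

fsdp2_strategy_parallelize(
    model,
    device_mesh=mesh,
    mp_policy=mp_policy,
    tp_shard_plan=tp_shard_plan,
)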

nemo_automodel.distributed.parallelizer.import_classes_from_paths(class_paths: List[str])[source]#

Helper function to import classes from string paths.

Parameters:

class_paths (List[str]) – The list of string paths to the classes.
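A minimal sketch of the expected call pattern; the dotted paths below, and the assumption that the helper returns the imported class objects in input order, are illustrative rather than guaranteed by this reference.

from nemo_automodel.distributed.parallelizer import import_classes_from_paths

# Hypothetical inputs: any importable "module.ClassName" paths follow the same pattern.
classes = import_classes_from_paths(["torch.nn.Linear", "torch.nn.LayerNorm"])
# Assumed result: [torch.nn.Linear, torch.nn.LayerNorm]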

nemo_automodel.distributed.parallelizer.nvfsdp_strategy_parallelize(
model,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
nvfsdp_unit_modules: Optional[List[str]] = None,
tp_shard_plan: Optional[Dict[str, Union[torch.distributed.tensor.parallel.RowwiseParallel, torch.distributed.tensor.parallel.ColwiseParallel, torch.distributed.tensor.parallel.SequenceParallel]]] = None,
data_parallel_sharding_strategy: str = 'optim_grads_params',
init_nvfsdp_with_meta_device: bool = False,
grad_reduce_in_fp32: bool = False,
preserve_fp32_weights: bool = False,
overlap_grad_reduce: bool = True,
overlap_param_gather: bool = True,
check_for_nan_in_grad: bool = True,
average_in_collective: bool = False,
disable_bucketing: bool = False,
calculate_per_token_loss: bool = False,
keep_fp8_transpose_cache_when_using_custom_fsdp: bool = False,
nccl_ub: bool = False,
fsdp_double_buffer: bool = False,
)[source]#

Apply tensor/data parallelism (nvFSDP) and optional activation-checkpointing to the model.

Parameters:
  • model – The model to be parallelized.

  • device_mesh (DeviceMesh) – The device mesh describing the physical devices used for distributed training.

  • nvfsdp_unit_modules (Optional[List[str]]) – Names of sub-modules that should become individual nvFSDP units. If None, the full model is wrapped as a single unit.

  • tp_shard_plan (Optional[Dict[str, Union[RowwiseParallel, ColwiseParallel, SequenceParallel]]]) – A tensor-parallel sharding plan. Keys are module names; values specify the parallel style to apply (e.g., RowwiseParallel, ColwiseParallel, SequenceParallel).

  • data_parallel_sharding_strategy (str) – Strategy for sharding parameters, gradients, and optimizer states across data-parallel ranks. Valid options include "params", "grads_params", and "optim_grads_params" (default).

  • init_nvfsdp_with_meta_device (bool) – If True, construct the model on a meta device first and materialize weights lazily to reduce memory fragmentation.

  • grad_reduce_in_fp32 (bool) – Reduce gradients in FP32 irrespective of the parameter precision to improve numerical stability.

  • preserve_fp32_weights (bool) – Keep a master FP32 copy of weights when training in reduced precision (e.g., FP16/BF16).

  • overlap_grad_reduce (bool) – If True, overlap gradient reduction with backward computation.

  • overlap_param_gather (bool) – If True, overlap parameter gathering with forward computation.

  • check_for_nan_in_grad (bool) – Whether to check gradients for NaNs/Infs before applying the optimizer step.

  • average_in_collective (bool) – Perform gradient averaging inside the collective operation instead of dividing afterward.

  • disable_bucketing (bool) – Disable gradient bucketing; gradients are reduced immediately as they are produced.

  • calculate_per_token_loss (bool) – Compute loss normalized by the number of tokens instead of the number of sequences.

  • keep_fp8_transpose_cache_when_using_custom_fsdp (bool) – Retain the FP8 transpose cache when using a custom nvFSDP wrapper.

  • nccl_ub (bool) – Enable NCCL user-buffer API (experimental) for reduced latency on some networks.

  • fsdp_double_buffer (bool) – Enable double buffering of parameters to overlap communication and computation in nvFSDP.

NOTE: The passed-in model should preferably reside on the meta device. Otherwise, ensure the model fits into available GPU or CPU memory.

NOTE: The user must ensure that the provided tp_shard_plan is compatible with the model architecture.
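A minimal sketch, guarded by the HAVE_NVFSDP flag defined above. The mesh, the toy model, and the unit-module entry are illustrative assumptions; in particular, the exact form expected for nvfsdp_unit_modules (class names vs. dotted paths) is not specified by this reference.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from nemo_automodel.distributed import parallelizer

# Assumed: launched with torchrun across 8 GPUs; the dimension names are illustrative.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("data_parallel", "tensor_parallel"))

# Build the model on the meta device, as recommended in the note above.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

if parallelizer.HAVE_NVFSDP:
    parallelizer.nvfsdp_strategy_parallelize(
        model,
        device_mesh=mesh,
        nvfsdp_unit_modules=["torch.nn.Linear"],  # illustrative; see import_classes_from_paths above
        data_parallel_sharding_strategy="optim_grads_params",
        overlap_grad_reduce=True,
        overlap_param_gather=True,
    )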

nemo_automodel.distributed.parallelizer.get_hf_tp_shard_plan(model)[source]#

Get the tensor parallel sharding plan from the model.
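A hedged sketch of how the returned plan might feed into fsdp2_strategy_parallelize. The assumption that the helper accepts a Hugging Face transformers model and returns a dict mapping module names to parallel styles is inferred from the tp_shard_plan parameters above, not stated by this reference.

from transformers import AutoModelForCausalLM

from nemo_automodel.distributed.parallelizer import (
    fsdp2_strategy_parallelize,
    get_hf_tp_shard_plan,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
tp_shard_plan = get_hf_tp_shard_plan(model)  # assumed: {module_name: parallel_style, ...}

# The plan can then be passed straight through, e.g.:
# fsdp2_strategy_parallelize(model, device_mesh=mesh, tp_shard_plan=tp_shard_plan)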

nemo_automodel.distributed.parallelizer.translate_to_torch_parallel_style(style: str)[source]#

Translates string descriptions to parallelism plans.

In model configurations, we use a neutral type (string) to specify parallel styles; here we translate them into torch.distributed tensor-parallel types.
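For illustration only: the exact set of accepted strings is defined by the implementation, and "colwise"/"rowwise" are assumed here.

from nemo_automodel.distributed.parallelizer import translate_to_torch_parallel_style

# Assumed accepted strings; substitute whatever your model configuration actually uses.
colwise = translate_to_torch_parallel_style("colwise")
rowwise = translate_to_torch_parallel_style("rowwise")
# Expected: instances of torch.distributed.tensor.parallel ColwiseParallel / RowwiseParallel.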

nemo_automodel.distributed.parallelizer.to_cpu(v)[source]#

Move a tensor or distributed tensor to the CPU.

This function takes an input tensor, which can be either a DTensor (distributed tensor) or a standard Tensor, and ensures that it is moved to the CPU.

Parameters:

v (DTensor | Tensor | any) – The input value, which can be a DTensor, Tensor, or any other object. If v is a DTensor, the function checks its device and moves the tensor accordingly.

Returns:

The corresponding CPU tensor if v is a DTensor or Tensor, otherwise returns v unchanged.

Return type:

Tensor | any

Raises:

ValueError – If v is a DTensor but its device is neither 'cuda' nor 'cpu'.

Example

>>> t = torch.tensor([1, 2, 3], device='cuda')
>>> to_cpu(t)  # Moves tensor to CPU
tensor([1, 2, 3])

>>> dt = DTensor(torch.tensor([4, 5, 6], device='cuda'))
>>> to_cpu(dt)  # Moves DTensor to CPU
tensor([4, 5, 6])

nemo_automodel.distributed.parallelizer._destroy_dist_connection() → None[source]#

Destroy process group.