nemo_automodel.components.distributed.parallelizer#

Module Contents#

Classes#

ParallelizationStrategy

Abstract base class for model parallelization strategies.

DefaultParallelizationStrategy

Default parallelization strategy used by most models.

NemotronHParallelizationStrategy

Specialized parallelization strategy for NemotronH models.

WanParallelizationStrategy

Parallelization strategy for Wan-style transformer modules used in Diffusers.

Functions#

get_parallelization_strategy

Get the appropriate parallelization strategy for the given model.

apply_fsdp2_sharding_recursively

Recursively apply FSDP2 sharding to modules, with optimizations for ModuleList.

get_hf_tp_shard_plan

Get the Hugging Face tensor parallel plan from the model.

import_class_from_path

Import a class from a string path (e.g. ‘torch.optim.AdamW’).

import_classes_from_paths

Helper function to import classes from string paths.

translate_to_torch_parallel_style

Translates string descriptions to parallelism plans.

validate_tp_mesh_for_nemotron_nas

validate_tp_mesh

Validate that the attention head and key-value head counts are divisible by the TP size.

_find_largest_module_list

Heuristic function to find the largest nn.ModuleList in a model.

_extract_model_layers

Extract layers from different model architectures for parallelization.

_get_parallel_plan

Select the tensor-parallel plan for the given model.

fsdp2_strategy_parallelize

Apply parallelisms and activation checkpointing to the model.

megatron_fsdp_strategy_parallelize

Apply tensor/data parallelism (MegatronFSDP) and optional activation-checkpointing to the model.

unshard_fsdp2_model

Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference.

Data#

API#

nemo_automodel.components.distributed.parallelizer.HAVE_MEGATRON_FSDP#

False

nemo_automodel.components.distributed.parallelizer.logger#

‘getLogger(…)’

class nemo_automodel.components.distributed.parallelizer.ParallelizationStrategy#

Bases: abc.ABC

Abstract base class for model parallelization strategies.

abstractmethod parallelize(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
tp_shard_plan: Optional[Union[Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
use_hf_tp_plan: bool = False,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
) torch.nn.Module#

Apply parallelization strategy to the model.

class nemo_automodel.components.distributed.parallelizer.DefaultParallelizationStrategy#

Bases: nemo_automodel.components.distributed.parallelizer.ParallelizationStrategy

Default parallelization strategy used by most models.

parallelize(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
tp_shard_plan: Optional[Union[Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
use_hf_tp_plan: bool = False,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
) torch.nn.Module#

Apply the default parallelization flow.

class nemo_automodel.components.distributed.parallelizer.NemotronHParallelizationStrategy#

Bases: nemo_automodel.components.distributed.parallelizer.ParallelizationStrategy

Specialized parallelization strategy for NemotronH models.

parallelize(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
tp_shard_plan: Optional[Union[Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
) torch.nn.Module#

Apply NemotronH-specific parallelization.

class nemo_automodel.components.distributed.parallelizer.WanParallelizationStrategy#

Bases: nemo_automodel.components.distributed.parallelizer.ParallelizationStrategy

Parallelization strategy for Wan-style transformer modules used in Diffusers.

Applies TP to the condition embedders, the FFN projections in each block, and the final projection, then applies FSDP sharding in the same way as the other strategies.

parallelize(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
tp_shard_plan: Optional[Union[Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
) torch.nn.Module#

nemo_automodel.components.distributed.parallelizer.PARALLELIZATION_STRATEGIES: Dict[str, nemo_automodel.components.distributed.parallelizer.ParallelizationStrategy]#

None

nemo_automodel.components.distributed.parallelizer._DEFAULT_STRATEGY#

‘DefaultParallelizationStrategy(…)’

nemo_automodel.components.distributed.parallelizer.get_parallelization_strategy(
model: torch.nn.Module,
) nemo_automodel.components.distributed.parallelizer.ParallelizationStrategy#

Get the appropriate parallelization strategy for the given model.
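A minimal usage sketch (assumes the script runs under torchrun with the distributed environment already set up; the checkpoint name and mesh layout are illustrative):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from nemo_automodel.components.distributed.parallelizer import get_parallelization_strategy

# 2-way FSDP sharding x 4-way tensor parallelism; names match the parallelize() defaults.
device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_shard_cp", "tp"))

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # illustrative checkpoint

# Pick the strategy registered for this model type (falls back to the default strategy).
strategy = get_parallelization_strategy(model)
model = strategy.parallelize(model, device_mesh)
```

fsdp2_strategy_parallelize performs this strategy selection automatically, so calling a strategy directly is rarely needed.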

nemo_automodel.components.distributed.parallelizer.apply_fsdp2_sharding_recursively(
module: torch.nn.Module,
mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
offload_policy: Optional[torch.distributed.fsdp.OffloadPolicy] = None,
) None#

Recursively apply FSDP2 sharding to modules, with optimizations for ModuleList.

This utility function traverses a model hierarchy and applies FSDP2 sharding to each module. For ModuleList instances (commonly used for transformer layers), it applies an optimization where the last layer doesn’t reshard after forward since FSDP2 will prefetch it immediately.

Parameters:
  • module (nn.Module) – The module to apply FSDP sharding to.

  • mesh (DeviceMesh) – The device mesh for FSDP sharding.

  • mp_policy (Optional[MixedPrecisionPolicy]) – Mixed precision policy for FSDP.

  • offload_policy (Optional[OffloadPolicy]) – CPU offload policy for FSDP. Defaults to None.

Note: This function modifies the module in place by replacing modules with their FSDP2-subclassed versions.
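A hedged usage sketch with a toy ModuleList (in practice this helper is typically invoked by the parallelization strategies on real transformer layer stacks):

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy

from nemo_automodel.components.distributed.parallelizer import apply_fsdp2_sharding_recursively

# Toy stack standing in for transformer layers.
layers = nn.ModuleList([nn.Linear(512, 512) for _ in range(4)])

mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard_cp",))
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# Shards each layer in place; the last entry of the ModuleList is not resharded
# after forward because FSDP2 prefetches it immediately.
apply_fsdp2_sharding_recursively(layers, mesh, mp_policy)
```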

nemo_automodel.components.distributed.parallelizer.get_hf_tp_shard_plan(model)#

Get the Hugging Face tensor parallel plan from the model.

This function:

  • Retrieves TP strategies from model class, instance, and inner model levels.

  • Handles embed_tokens and lm_head specially for better performance.

  • Converts string-based parallel styles to DTensor parallelization strategies.

Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532

Parameters:

model – A Hugging Face model instance

Returns:

A dictionary mapping model component paths to their parallelization strategies

Return type:

dict

Raises:

AssertionError – If no TP plan is found
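A brief usage sketch (the checkpoint name is illustrative; the exact keys and parallel styles in the returned plan vary by model):

```python
from transformers import AutoModelForCausalLM

from nemo_automodel.components.distributed.parallelizer import get_hf_tp_shard_plan

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # illustrative checkpoint

# Raises AssertionError if the model ships no TP plan at any level.
tp_plan = get_hf_tp_shard_plan(model)

# Maps module-path patterns to DTensor ParallelStyle objects,
# e.g. {"model.layers.*.self_attn.q_proj": ColwiseParallel(), ...}.
for path, style in tp_plan.items():
    print(path, type(style).__name__)
```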

nemo_automodel.components.distributed.parallelizer.import_class_from_path(name: str) Any#

Import a class from a string path (e.g. ‘torch.optim.AdamW’).

Parameters:

name (str) – Full path to the class, including the module path and class name

Returns:

The imported class object
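A brief usage sketch:

```python
import torch.nn as nn

from nemo_automodel.components.distributed.parallelizer import import_class_from_path

# Equivalent to `from torch.optim import AdamW`.
AdamW = import_class_from_path("torch.optim.AdamW")
optimizer = AdamW(nn.Linear(4, 4).parameters(), lr=1e-4)
```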

nemo_automodel.components.distributed.parallelizer.import_classes_from_paths(class_paths: List[str])#

Helper function to import classes from string paths.

Parameters:

class_paths (List[str]) – The list of string paths to the classes.

Returns:

List of imported classes.
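A brief usage sketch of the list variant:

```python
from nemo_automodel.components.distributed.parallelizer import import_classes_from_paths

linear_cls, layer_norm_cls = import_classes_from_paths(["torch.nn.Linear", "torch.nn.LayerNorm"])
```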

nemo_automodel.components.distributed.parallelizer.translate_to_torch_parallel_style(style: str)#

Translates string descriptions to parallelism plans.

In model configurations, we use a neutral type (string) to specify parallel styles; here we translate them into torch.distributed tensor-parallel types.
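A hedged sketch; the accepted strings are assumed to follow the Hugging Face tp_plan vocabulary (e.g. "colwise", "rowwise"):

```python
from nemo_automodel.components.distributed.parallelizer import translate_to_torch_parallel_style

# "colwise" / "rowwise" are assumed style names; they should map to
# torch.distributed.tensor.parallel.ColwiseParallel / RowwiseParallel.
colwise = translate_to_torch_parallel_style("colwise")
rowwise = translate_to_torch_parallel_style("rowwise")
print(type(colwise).__name__, type(rowwise).__name__)
```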

nemo_automodel.components.distributed.parallelizer.validate_tp_mesh_for_nemotron_nas(model, tp_size)#

nemo_automodel.components.distributed.parallelizer.validate_tp_mesh(model, tp_mesh)#

Validate that the attention head and key-value head counts are divisible by the TP size.
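The constraint it enforces amounts to the following check (a simplified sketch; the real function reads the head counts from the model config and the TP size from the mesh):

```python
def check_heads_divisible(num_attention_heads: int, num_key_value_heads: int, tp_size: int) -> None:
    # Simplified stand-in for the divisibility constraint validate_tp_mesh enforces.
    assert num_attention_heads % tp_size == 0, (
        f"attention heads ({num_attention_heads}) not divisible by TP size ({tp_size})"
    )
    assert num_key_value_heads % tp_size == 0, (
        f"key-value heads ({num_key_value_heads}) not divisible by TP size ({tp_size})"
    )


check_heads_divisible(num_attention_heads=32, num_key_value_heads=8, tp_size=4)  # passes
```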

nemo_automodel.components.distributed.parallelizer._find_largest_module_list(
model: torch.nn.Module,
) Optional[torch.nn.ModuleList]#

Heuristic function to find the largest nn.ModuleList in a model.

This function recursively traverses the model to find all nn.ModuleList instances and returns the one with the most modules. This is useful as a fallback when the model architecture is unknown, since transformer layers are typically organized in ModuleLists.

Parameters:

model (nn.Module) – The model to search through.

Returns:

The largest ModuleList found, or None if no ModuleList exists.

Return type:

Optional[nn.ModuleList]
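A hedged sketch of the heuristic (the actual implementation may differ in traversal details or tie-breaking):

```python
from typing import Optional

import torch.nn as nn


def find_largest_module_list(model: nn.Module) -> Optional[nn.ModuleList]:
    """Return the nn.ModuleList with the most entries, or None if none exists."""
    largest: Optional[nn.ModuleList] = None
    for module in model.modules():
        if isinstance(module, nn.ModuleList) and (largest is None or len(module) > len(largest)):
            largest = module
    return largest
```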

nemo_automodel.components.distributed.parallelizer._extract_model_layers(
model: torch.nn.Module,
) List[torch.nn.Module]#

Extract layers from different model architectures for parallelization.

This function handles various model types including vision-language models, causal language models, and multimodal models. It collects both language model layers and vision model layers where applicable.

Parameters:

model (nn.Module) – The model to extract layers from.

Returns:

A list of all layers that should be parallelized.

Return type:

List[nn.Module]

nemo_automodel.components.distributed.parallelizer._get_parallel_plan(
model: torch.nn.Module,
sequence_parallel: bool = False,
tp_shard_plan: Optional[Union[Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
use_hf_tp_plan: bool = False,
) Dict[str, torch.distributed.tensor.parallel.ParallelStyle]#

Select the tensor-parallel plan for the given model.

Priority order:

  1. If tp_shard_plan is provided as a dict or import path (to a dict/function), use it.

  2. If use_hf_tp_plan is True, use the HF plan directly (an assertion error is raised when sequence_parallel=True).

  3. If the model type exists in PARALLELIZE_FUNCTIONS, use its optimized plan; on failure, fall back to the HF plan.

  4. Otherwise, use the default base plan.
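As an illustration of option 1, a custom plan can be supplied either inline or as a dotted import path; the module-path patterns below are illustrative and must match the target model's architecture, and the import-path module is hypothetical:

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Inline dict form (keys are illustrative module-path patterns).
custom_plan = {
    "model.layers.*.self_attn.q_proj": ColwiseParallel(),
    "model.layers.*.self_attn.o_proj": RowwiseParallel(),
    "model.layers.*.mlp.gate_proj": ColwiseParallel(),
    "model.layers.*.mlp.down_proj": RowwiseParallel(),
}

# String form: a dotted path to a dict, or to a function returning a dict
# (module name is hypothetical).
custom_plan_path = "my_project.parallel_plans.llama_tp_plan"
```

Either form is passed as tp_shard_plan to fsdp2_strategy_parallelize (see below).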

nemo_automodel.components.distributed.parallelizer.fsdp2_strategy_parallelize(
model,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
tp_shard_plan: Optional[Union[Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
)#

Apply parallelisms and activation checkpointing to the model.

Uses a strategy pattern to support different model parallelization approaches:

  • Automatic strategy selection based on model type

  • Polymorphic parallelization strategies for different model families

  • Custom parallel plan support (dict or string path)

  • Sequence parallel support

  • Activation checkpointing for linear layers

  • Model validation (attention heads divisible by TP size)

  • Better fallback logic

Parameters:
  • model – The model to be parallelized.

  • device_mesh (DeviceMesh) – The device mesh for distributed training.

  • mp_policy (Optional[MixedPrecisionPolicy]) – Mixed precision policy for model parallelism.

  • offload_policy (Optional[OffloadPolicy]) – The offload policy for FSDP.

  • sequence_parallel (bool) – Whether to use sequence parallelism. Defaults to False.

  • activation_checkpointing (bool) – Whether to use activation checkpointing. Defaults to False.

  • tp_shard_plan (Optional[Union[Dict[str, ParallelStyle], str]]) –

    Custom tensor parallel plan for the model. Can be:

    • A dictionary mapping module names to parallel styles

    • A string path to a dictionary or to a function that returns a dictionary.

    If provided, this takes precedence over automatic plan generation.

  • dp_replicate_mesh_name (str) – Key name for the data parallel replicate mesh in device_mesh. Used when data parallel replicate is enabled. Defaults to “dp_replicate”.

  • dp_shard_cp_mesh_name (str) – Key name for the data parallel shard + context parallel mesh in device_mesh. Used when data parallel shard is enabled. Defaults to “dp_shard_cp”.

  • tp_mesh_name (str) – Key name for the tensor parallel mesh in device_mesh. Defaults to “tp”.

Returns:

The parallelized model.

NOTE: The passed-in model should preferably be on the meta device. Otherwise, the model must fit in GPU or CPU memory.
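A minimal end-to-end sketch, assuming the script runs under torchrun on an 8-GPU node with 2-way FSDP sharding and 4-way tensor parallelism (the checkpoint name is illustrative; mesh dimension names match the defaults above):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from transformers import AutoConfig, AutoModelForCausalLM

from nemo_automodel.components.distributed.parallelizer import fsdp2_strategy_parallelize

device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_shard_cp", "tp"))

# Build the model on the meta device (preferred, see the note above);
# weights are materialized/loaded separately after sharding.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")  # illustrative checkpoint
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

model = fsdp2_strategy_parallelize(
    model,
    device_mesh=device_mesh,
    mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    activation_checkpointing=True,
)
```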

nemo_automodel.components.distributed.parallelizer.megatron_fsdp_strategy_parallelize(
model,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
optimizer=None,
megatron_fsdp_unit_modules: Optional[List[str]] = None,
tp_shard_plan: Optional[Dict[str, Union[torch.distributed.tensor.parallel.RowwiseParallel, torch.distributed.tensor.parallel.ColwiseParallel, torch.distributed.tensor.parallel.SequenceParallel]]] = None,
zero_dp_strategy: int = 3,
init_fsdp_with_meta_device: bool = False,
grad_reduce_in_fp32: bool = False,
preserve_fp32_weights: bool = False,
overlap_grad_reduce: bool = True,
overlap_param_gather: bool = True,
check_for_nan_in_grad: bool = True,
average_in_collective: bool = False,
disable_bucketing: bool = False,
calculate_per_token_loss: bool = False,
keep_fp8_transpose_cache: bool = False,
nccl_ub: bool = False,
fsdp_double_buffer: bool = False,
dp_shard_dim: str = 'dp',
tp_dim: str = 'tp',
)#

Apply tensor/data parallelism (MegatronFSDP) and optional activation-checkpointing to the model.

Parameters:
  • model – The model to be parallelized.

  • device_mesh (DeviceMesh) – The device mesh describing the physical devices used for distributed training.

  • megatron_fsdp_unit_modules (Optional[List[str]]) – Names of sub-modules that should become individual MegatronFSDP units. If None, the full model is wrapped as a single unit.

  • tp_shard_plan (Optional[Dict[str, Union[RowwiseParallel, ColwiseParallel, SequenceParallel]]]) – A tensor-parallel sharding plan. Keys are module names; values specify the parallel style to apply (e.g., RowwiseParallel, ColwiseParallel, SequenceParallel).

  • zero_dp_strategy (int) – The ZeRO data-parallel sharding strategy level to use. Defaults to 3.

  • init_fsdp_with_meta_device (bool) – If True, construct the model on a meta device first and materialize weights lazily to reduce memory fragmentation.

  • grad_reduce_in_fp32 (bool) – Reduce gradients in FP32 irrespective of the parameter precision to improve numerical stability.

  • preserve_fp32_weights (bool) – Keep a master FP32 copy of weights when training in reduced precision (e.g., FP16/BF16).

  • overlap_grad_reduce (bool) – If True, overlap gradient reduction with backward computation.

  • overlap_param_gather (bool) – If True, overlap parameter gathering with forward computation.

  • check_for_nan_in_grad (bool) – Whether to check gradients for NaNs/Infs before applying the optimizer step.

  • average_in_collective (bool) – Perform gradient averaging inside the collective operation instead of dividing afterward.

  • disable_bucketing (bool) – Disable gradient bucketing; gradients are reduced immediately as they are produced.

  • calculate_per_token_loss (bool) – Compute loss normalized by the number of tokens instead of the number of sequences.

  • keep_fp8_transpose_cache (bool) – Retain the FP8 transpose cache when using a custom MegatronFSDP wrapper.

  • nccl_ub (bool) – Enable NCCL user-buffer API (experimental) for reduced latency on some networks.

  • fsdp_double_buffer (bool) – Enable double buffering of parameters to overlap communication and computation in MegatronFSDP.

  • dp_shard_dim (str) – Key name for the data parallel mesh in device_mesh. Defaults to “dp”.

  • tp_dim (str) – Key name for the tensor parallel mesh in device_mesh. Defaults to “tp”.

NOTE: The passed-in model should preferably reside on the meta device. Otherwise, ensure the model fits into available GPU or CPU memory.

NOTE: The user must ensure that the provided tp_shard_plan is compatible with the model architecture.
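A hedged usage sketch (requires the MegatronFSDP dependency, i.e. HAVE_MEGATRON_FSDP must be True; the checkpoint and unit-module path are illustrative, and the return value is assumed to be the parallelized model, mirroring fsdp2_strategy_parallelize):

```python
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from nemo_automodel.components.distributed.parallelizer import megatron_fsdp_strategy_parallelize

# Mesh dimension names match the dp_shard_dim/tp_dim defaults.
device_mesh = init_device_mesh("cuda", (8, 1), mesh_dim_names=("dp", "tp"))

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # illustrative checkpoint

model = megatron_fsdp_strategy_parallelize(
    model,
    device_mesh=device_mesh,
    # Wrap each decoder layer as its own MegatronFSDP unit (class path is illustrative).
    megatron_fsdp_unit_modules=["transformers.models.llama.modeling_llama.LlamaDecoderLayer"],
    zero_dp_strategy=3,
)
```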

nemo_automodel.components.distributed.parallelizer.unshard_fsdp2_model(
model: torch.nn.Module,
) Generator[None, None, None]#

Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference.
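Given the generator return type, this is presumably used as a context manager; a brief sketch for logprob inference (`model` is an FSDP2-parallelized causal LM and `input_ids` a token batch, both assumed to exist already):

```python
import torch

from nemo_automodel.components.distributed.parallelizer import unshard_fsdp2_model

# Gather full parameters for the duration of the block, then reshard on exit.
with unshard_fsdp2_model(model), torch.no_grad():
    logits = model(input_ids).logits
    logprobs = torch.log_softmax(logits, dim=-1)
```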