nemo_automodel.components.distributed.parallelizer


Module Contents

Classes

| Name | Description |
| --- | --- |
| DeepseekV4ParallelizationStrategy | DeepSeek-V4 keeps a small set of reference-sensitive parameters in fp32. |
| DefaultParallelizationStrategy | Default parallelization strategy used by most models. |
| Gemma4ForConditionalGeneration | Placeholder when the installed transformers build has no Gemma4. |
| HunyuanParallelizationStrategy | Parallelization strategy for Hunyuan-style transformer modules used in HunyuanVideo. |
| NemotronHParallelizationStrategy | Specialized parallelization strategy for NemotronH models. |
| ParallelizationStrategy | Abstract base class for model parallelization strategies. |
| Qwen3_5ParallelizationStrategy | Parallelization strategy for Qwen3.5 dense models with mixed-dtype GatedDeltaNet. |
| WanParallelizationStrategy | Parallelization strategy for Wan-style transformer modules used in Diffusers. |

Functions

| Name | Description |
| --- | --- |
| _apply_bagel_full_layer_activation_checkpointing | Apply native BAGEL-style activation checkpointing to whole logical layers. |
| _apply_per_layer_compile | Compile each decoder layer in-place after FSDP2 sharding. |
| _attention_is_head_sharded | Return True when the TP plan column-wise shards any QKV attention projection. |
| _extract_model_layers | Extract layers from different model architectures for parallelization. |
| _find_largest_module_list | Heuristic function to find the largest layer container in a model. |
| _get_module_by_fqn | - |
| _get_parallel_plan | Select the tensor-parallel plan for the given model. |
| _is_checkpoint_wrapped | - |
| _is_transformers_v5_or_higher | Check if transformers version is 5.x or higher. |
| _nemotronh_decoder_blocks | Return (container, blocks) for a NemotronH model’s decoder blocks. |
| _patch_dtensor_spec_hash_for_symint | Fix a crash when torch.compile + DTensor are used together. |
| _subtree_all_frozen | Return True if module owns parameters and none of them require grad. |
| _update_attention_head_counts_for_tp | After TP sharding, the Q/K/V outputs are split across ranks (each rank has num_heads/tp_size heads). |
| apply_fsdp2_sharding_recursively | Recursively apply FSDP2 sharding to modules, with optimizations for ModuleList. |
| apply_selective_activation_checkpointing | Apply selective activation checkpointing to model end to end. |
| fsdp2_strategy_parallelize | Apply parallelisms and activation checkpointing to the model. |
| get_hf_tp_shard_plan | Get the Hugging Face tensor parallel plan from the model. |
| get_parallelization_strategy | Get the appropriate parallelization strategy for the given model. |
| import_class_from_path | Import a class from a string path (e.g. ‘torch.optim.AdamW’). |
| import_classes_from_paths | Helper function to import classes from string paths. |
| megatron_fsdp_strategy_parallelize | Apply tensor/data parallelism (MegatronFSDP) and optional activation-checkpointing to the model. |
| register_parallel_strategy | Decorator to register out-of-tree parallelism strategies. |
| translate_to_torch_parallel_style | Translates string descriptions to parallelism plans. |
| unshard_fsdp2_model | Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference. |
| validate_tp_mesh | Validate that attention heads and key value heads are divisible by TP size. |
| validate_tp_mesh_for_nemotron_nas | Validate that a Nemotron-NAS model can be tensor-parallel sharded. |

Data

HAVE_MEGATRON_FSDP

PARALLELIZATION_STRATEGIES

_BAGEL_FULL_LAYER_CHECKPOINT_MODULE_LISTS

_DEFAULT_STRATEGY

logger

API

class nemo_automodel.components.distributed.parallelizer.DeepseekV4ParallelizationStrategy()

Bases: DefaultParallelizationStrategy

DeepSeek-V4 keeps a small set of reference-sensitive parameters in fp32.

nemo_automodel.components.distributed.parallelizer.DeepseekV4ParallelizationStrategy.parallelize(
model,
device_mesh,
dp_shard_cp_mesh_name = 'dp_shard_cp',
kwargs = {}
)

class nemo_automodel.components.distributed.parallelizer.DefaultParallelizationStrategy()

Bases: ParallelizationStrategy

Default parallelization strategy used by most models.

nemo_automodel.components.distributed.parallelizer.DefaultParallelizationStrategy.parallelize(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
tp_shard_plan: typing.Optional[typing.Union[typing.Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
enable_async_tensor_parallel: bool = False,
enable_compile: bool = False,
enable_fsdp2_prefetch: bool = True,
fsdp2_backward_prefetch_depth: int = 2,
fsdp2_forward_prefetch_depth: int = 1,
reshard_after_forward: typing.Optional[bool] = None,
fully_shard_fn = None
) -> torch.nn.Module

Apply the default parallelization flow.

class nemo_automodel.components.distributed.parallelizer.Gemma4ForConditionalGeneration()

Placeholder when the installed transformers build has no Gemma4.

class nemo_automodel.components.distributed.parallelizer.HunyuanParallelizationStrategy()

Bases: ParallelizationStrategy

Parallelization strategy for Hunyuan-style transformer modules used in HunyuanVideo.

nemo_automodel.components.distributed.parallelizer.HunyuanParallelizationStrategy.parallelize(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = True,
tp_shard_plan: typing.Optional[typing.Union[typing.Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
kwargs = {}
) -> torch.nn.Module

class nemo_automodel.components.distributed.parallelizer.NemotronHParallelizationStrategy()

Bases: ParallelizationStrategy

Specialized parallelization strategy for NemotronH models.

nemo_automodel.components.distributed.parallelizer.NemotronHParallelizationStrategy.parallelize(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
tp_shard_plan: typing.Optional[typing.Union[typing.Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
kwargs = {}
) -> torch.nn.Module

Apply NemotronH-specific parallelization.

class nemo_automodel.components.distributed.parallelizer.ParallelizationStrategy()
Abstract

Abstract base class for model parallelization strategies.

nemo_automodel.components.distributed.parallelizer.ParallelizationStrategy.parallelize(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
tp_shard_plan: typing.Optional[typing.Union[typing.Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
kwargs = {}
) -> torch.nn.Module
abstract

Apply parallelization strategy to the model.

class nemo_automodel.components.distributed.parallelizer.Qwen3_5ParallelizationStrategy()

Bases: DefaultParallelizationStrategy

Parallelization strategy for Qwen3.5 dense models with mixed-dtype GatedDeltaNet.

Qwen3.5 has linear_attn layers with float32 params (A_log, norm) alongside bfloat16 params. Overrides the FSDP sharding step to use fully_shard_by_dtype per layer, and sets the CP mesh on CPAwareGatedDeltaNet modules.

nemo_automodel.components.distributed.parallelizer.Qwen3_5ParallelizationStrategy.parallelize(
model,
device_mesh,
dp_shard_cp_mesh_name = 'dp_shard_cp',
kwargs = {}
)

class nemo_automodel.components.distributed.parallelizer.WanParallelizationStrategy()

Bases: ParallelizationStrategy

Parallelization strategy for Wan-style transformer modules used in Diffusers.

Applies TP to condition embedders, FFN projections in each block, and final projection, then applies FSDP sharding similarly to other strategies.

nemo_automodel.components.distributed.parallelizer.WanParallelizationStrategy.parallelize(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
tp_shard_plan: typing.Optional[typing.Union[typing.Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
kwargs = {}
) -> torch.nn.Module

nemo_automodel.components.distributed.parallelizer._apply_bagel_full_layer_activation_checkpointing(
model: torch.nn.Module
) -> bool

Apply native BAGEL-style activation checkpointing to whole logical layers.

nemo_automodel.components.distributed.parallelizer._apply_per_layer_compile(
model: torch.nn.Module
) -> None

Compile each decoder layer in-place after FSDP2 sharding.

Compiles at decoder-layer granularity (not sub-module) so that AOT autograd traces the joint fwd+bwd graph under the training loop’s enable_grad context. Sub-module compile (e.g. on mlp alone) would be traced during activation checkpointing’s first forward pass which runs under no_grad, producing a forward-only graph that drops LoRA and other trainable-parameter gradients.

Prerequisite: NO_REENTRANT checkpoint_wrapper must already be applied to self_attn and mlp before FSDP2 sharding (done in DefaultParallelizationStrategy). This function only handles the compile step.

Whole-block selective-AC wrappers (tagged with SELECTIVE_AC_WRAPPER_FLAG) are compiled OUTER — the wrapper itself is compiled so the selective policy is traced and the partitioner honors its recompute tags. Other layer-level CheckpointWrappers (e.g. the PP path) are unwrapped and the decoder layer is compiled directly.

nn.Module.compile() is used instead of torch.compile() to compile in-place without introducing an _orig_mod wrapper, which would add a key prefix and break checkpoint loading.

_patch_dtensor_spec_hash_for_symint() is called to allow torch.compile with dynamic shapes to coexist with DTensor’s lru_cache-based sharding propagation.

nemo_automodel.components.distributed.parallelizer._attention_is_head_sharded(
model_parallel_plan: dict
) -> bool

Return True when the TP plan column-wise shards any QKV attention projection.

When Q/K/V projections use ColwiseParallel with sharded output (the default), each TP rank holds num_heads / tp_size heads and the model config / layer attributes must be updated accordingly.

Plans that keep attention replicated (e.g. Phi-3 with RowwiseParallel on fused QKV and Replicate output) should not trigger a head-count update.

nemo_automodel.components.distributed.parallelizer._extract_model_layers(
model: torch.nn.Module
) -> typing.List[torch.nn.Module]

Extract layers from different model architectures for parallelization.

This function handles various model types including vision-language models, causal language models, and multimodal models. It collects both language model layers and vision model layers where applicable.

Parameters:

model
nn.Module

The model to extract layers from.

Returns: List[nn.Module]

List[nn.Module]: A list of all layers that should be parallelized.

nemo_automodel.components.distributed.parallelizer._find_largest_module_list(
model: torch.nn.Module
) -> typing.Optional[typing.Union[torch.nn.ModuleList, torch.nn.ModuleDict]]

Heuristic function to find the largest layer container in a model.

This function recursively traverses the model to find all nn.ModuleList and pipeline-split nn.ModuleDict instances and returns the one with the most modules. This is useful as a fallback when the model architecture is unknown, since transformer layers are typically organized in ModuleLists. Pipeline splitting converts ModuleLists to ModuleDicts keyed by original layer index.

Parameters:

model
nn.Module

The model to search through.

Returns: Optional[Union[nn.ModuleList, nn.ModuleDict]]

Optional[Union[nn.ModuleList, nn.ModuleDict]]: The largest layer container found, or None.
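The heuristic can be illustrated with a short sketch. This is an assumption-laden illustration, not the library's code: `largest_layer_container` and `Toy` are hypothetical names, and the sketch counts every nn.ModuleDict rather than only pipeline-split ones.

```python
from typing import Optional, Union

import torch.nn as nn


def largest_layer_container(
    model: nn.Module,
) -> Optional[Union[nn.ModuleList, nn.ModuleDict]]:
    """Return the ModuleList/ModuleDict with the most entries, or None."""
    best = None
    # model.modules() walks the whole hierarchy, including nested containers.
    for module in model.modules():
        if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
            if best is None or len(module) > len(best):
                best = module
    return best


class Toy(nn.Module):
    """Hypothetical model with a small and a large container."""

    def __init__(self) -> None:
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(2, 2)])
        self.layers = nn.ModuleList([nn.Linear(2, 2) for _ in range(3)])


toy = Toy()
found = largest_layer_container(toy)
```

Here `found` is `toy.layers`, because it holds three modules against one in `toy.heads`. A model with no container at all yields None.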

nemo_automodel.components.distributed.parallelizer._get_module_by_fqn(
module: torch.nn.Module,
fqn: str
) -> typing.Optional[torch.nn.Module]

nemo_automodel.components.distributed.parallelizer._get_parallel_plan(
model: torch.nn.Module,
sequence_parallel: bool = False,
tp_shard_plan: typing.Optional[typing.Union[typing.Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
tp_size: int = 1
) -> typing.Dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Select the tensor-parallel plan for the given model.

Priority order:

  1. If tp_shard_plan is provided as a dict or import path, use it.
  2. If the model type exists in PARALLELIZE_FUNCTIONS, use its optimised plan; on failure, fall back to HF plan.
  3. Otherwise, prefer the model’s HF-native _tp_plan (via get_hf_tp_shard_plan).
  4. Otherwise, fall back to the default base plan.

When tp_size > 1, the model falls through to step 4, and the model class was loaded from a custom-code source (HF’s trust_remote_code=True path, where the dynamic class lives under transformers_modules.*), this function raises ValueError instead of returning the default base plan. On recent PyTorch versions, the default plan’s placements do not populate shard_order, which trips an internal assert in torch.distributed.tensor._redistribute on the first weight redistribute and surfaces to the user as an opaque PyTorch internal error. Custom-code architectures are the only known-broken case (see https://github.com/NVIDIA-NeMo/Automodel/issues/2243); known HF architectures that happen to fall through (e.g. Mixtral) stay on the default plan with a warning, since they have worked in practice.

When the model did define a _tp_plan but get_hf_tp_shard_plan raised while translating it (e.g. styles nemo does not recognize), the translator’s error message is folded into the ValueError as a diagnostic so the user can tell whether to add a _tp_plan from scratch or fix the styles in the one they already have.
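The four-step priority order can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: `select_tp_plan`, `OPTIMIZED_PLANS`, and `DEFAULT_PLAN` are hypothetical names, plans are represented as plain string dictionaries, only the dict form of tp_shard_plan is handled, and the custom-code ValueError path is left out.

```python
from typing import Callable, Dict, Optional

# Stand-in type: real plans map module names to ParallelStyle objects.
Plan = Dict[str, str]

# Hypothetical table of hand-tuned plans, keyed by model type.
OPTIMIZED_PLANS: Dict[str, Callable[[], Plan]] = {
    "TunedModel": lambda: {"layers.*.mlp.up_proj": "colwise"},
}
# Hypothetical default base plan.
DEFAULT_PLAN: Plan = {"layers.*.mlp.down_proj": "rowwise"}


def select_tp_plan(
    model_type: str,
    tp_shard_plan: Optional[Plan] = None,
    hf_plan: Optional[Plan] = None,
) -> Plan:
    """Pick a tensor-parallel plan following the four-step priority order."""
    # 1. An explicitly supplied plan always wins.
    if tp_shard_plan is not None:
        return tp_shard_plan
    # 2. A registered optimized plan comes next.
    if model_type in OPTIMIZED_PLANS:
        try:
            return OPTIMIZED_PLANS[model_type]()
        except Exception:
            pass  # on failure, fall through to the HF plan
    # 3. Otherwise prefer the model's HF-native plan.
    if hf_plan is not None:
        return hf_plan
    # 4. Last resort: the default base plan.
    return DEFAULT_PLAN
```

Keeping the explicit plan first means a user-supplied plan is never silently overridden by automatic selection.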

nemo_automodel.components.distributed.parallelizer._is_checkpoint_wrapped(
module: torch.nn.Module
) -> bool

nemo_automodel.components.distributed.parallelizer._is_transformers_v5_or_higher() -> bool

Check if transformers version is 5.x or higher.

nemo_automodel.components.distributed.parallelizer._nemotronh_decoder_blocks(
model: torch.nn.Module
) -> tuple[torch.nn.Module, list[torch.nn.Module]]

Return (container, blocks) for a NemotronH model’s decoder blocks.

Two distinct classes share the name NemotronHForCausalLM:

  • the HF model keeps its blocks in model.backbone.layers (an nn.ModuleList), while
  • the native Nemotron-V3 model (NemotronV3Model) keeps them in model.model.layers (an nn.ModuleDict keyed "0".."N-1").

container is the underlying ModuleList/ModuleDict (so callers can write rewrapped blocks back into the model), and blocks is the ordered list of block modules.

nemo_automodel.components.distributed.parallelizer._patch_dtensor_spec_hash_for_symint() -> None

Fix a crash when torch.compile + DTensor are used together.

Problem: torch.compile traces with symbolic shapes (SymInt). DTensorSpec hashes its shape to cache sharding decisions, but SymInt is not hashable -> crash.

Fix: if hashing the shape fails, fall back to hashing only (mesh, placements). Cache hits are slightly reduced but correctness is unaffected.
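The fallback idea can be shown without PyTorch. The sketch below is a simplified illustration, not the actual patch: `SymbolicDim` and `SpecSketch` are hypothetical stand-ins for SymInt and DTensorSpec.

```python
class SymbolicDim:
    """Stand-in for an unhashable symbolic dimension (plays the role of SymInt)."""

    __hash__ = None  # hash() on an instance raises TypeError


class SpecSketch:
    """Hypothetical, simplified spec holding a mesh, placements, and a shape."""

    def __init__(self, mesh, placements, shape):
        self.mesh = mesh
        self.placements = placements
        self.shape = shape

    def __hash__(self):
        try:
            # Normal path: include the shape in the hash.
            return hash((self.mesh, self.placements, self.shape))
        except TypeError:
            # Fallback path: the shape holds an unhashable symbolic dim,
            # so hash only the mesh and the placements.
            return hash((self.mesh, self.placements))


static_spec = SpecSketch("mesh0", ("Shard(0)",), (4, 8))
dynamic_spec = SpecSketch("mesh0", ("Shard(0)",), (SymbolicDim(), 8))

static_hash = hash(static_spec)    # uses mesh, placements, and shape
dynamic_hash = hash(dynamic_spec)  # uses mesh and placements only
```

Specs with symbolic shapes on the same mesh and placements share one hash value in this sketch, which is the sense in which the cache becomes coarser while hashing no longer raises.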

nemo_automodel.components.distributed.parallelizer._subtree_all_frozen(
module: torch.nn.Module
) -> bool

Return True if module owns parameters and none of them require grad.

Used to skip FSDP-wrapping a frozen submodule that never runs in the forward (e.g. the audio tower on image/text-only data); see apply_fsdp2_sharding_recursively.
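A check with this contract can be sketched as follows. `subtree_all_frozen` is a hypothetical re-creation from the one-line description, not the library's source, and it treats "owns parameters" as "has at least one parameter anywhere in the subtree".

```python
import torch.nn as nn


def subtree_all_frozen(module: nn.Module) -> bool:
    """True if the subtree has parameters and none of them require grad."""
    params = list(module.parameters())
    return len(params) > 0 and all(not p.requires_grad for p in params)


frozen_tower = nn.Linear(4, 4)
frozen_tower.requires_grad_(False)  # freeze every parameter in place

trainable_head = nn.Linear(4, 4)    # parameters require grad by default
no_params = nn.ReLU()               # owns no parameters at all

results = (
    subtree_all_frozen(frozen_tower),
    subtree_all_frozen(trainable_head),
    subtree_all_frozen(no_params),
)
```

Only the frozen module passes. The parameter-free module returns False because the description requires that the module actually own parameters.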

nemo_automodel.components.distributed.parallelizer._update_attention_head_counts_for_tp(
model: torch.nn.Module,
tp_size: int
) -> None

After TP sharding, the Q/K/V outputs are split across ranks (each rank has num_heads/tp_size heads). Update the config and each attention layer’s num_heads / num_key_value_heads so the forward uses the local head count instead of the global one (avoids shape mismatches in .view()).
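The arithmetic behind the shape mismatch is easy to check by hand. The sizes below are illustrative and not taken from any particular model, and `local_head_count` is a hypothetical helper.

```python
def local_head_count(global_heads: int, tp_size: int) -> int:
    """Heads held by one TP rank when Q/K/V outputs are sharded column-wise."""
    if global_heads % tp_size != 0:
        raise ValueError(
            f"{global_heads} heads cannot be split evenly across {tp_size} ranks"
        )
    return global_heads // tp_size


# Illustrative sizes only.
num_heads, num_kv_heads, head_dim, tp_size = 32, 8, 128, 4

# Each rank's Q projection emits only its slice of the output features.
local_q_features = (num_heads * head_dim) // tp_size

local_heads = local_head_count(num_heads, tp_size)
local_kv_heads = local_head_count(num_kv_heads, tp_size)

# Reshaping into (heads, head_dim) only balances with the local head count.
fits_local = local_q_features == local_heads * head_dim
fits_global = local_q_features == num_heads * head_dim
```

With these numbers each rank holds 8 query heads and 2 key/value heads. A reshape that still assumes 32 heads would expect 4096 features where only 1024 exist, which is why the per-layer attributes need the local count.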

nemo_automodel.components.distributed.parallelizer.apply_fsdp2_sharding_recursively(
module: torch.nn.Module,
mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy] = None,
enable_fsdp2_prefetch: bool = True,
fsdp2_backward_prefetch_depth: int = 2,
fsdp2_forward_prefetch_depth: int = 1,
reshard_after_forward: typing.Optional[bool] = None,
fully_shard_fn = None
) -> None

Recursively apply FSDP2 sharding to modules, with optimizations for ModuleList.

This utility function traverses a model hierarchy and applies FSDP2 sharding to each module. For ModuleList instances (commonly used for transformer layers), it applies an optimization where the last layer doesn’t reshard after forward since FSDP2 will prefetch it immediately.

Handles both single-level and nested ModuleList/ModuleDict structures. If a ModuleList contains other ModuleLists, it will recurse into them instead of trying to wrap them (since ModuleList doesn’t have a forward method).

Note: This function modifies the module in-place by replacing modules with their FSDP2-subclassed versions.
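The last-layer optimization amounts to choosing a per-layer reshard flag. The sketch below shows one way to compute those flags; `reshard_flags` is a hypothetical helper, and the way the override is applied uniformly to every layer is an assumption rather than documented behavior.

```python
from typing import List, Optional


def reshard_flags(num_layers: int, override: Optional[bool] = None) -> List[bool]:
    """Per-layer reshard-after-forward values for one ModuleList."""
    if override is not None:
        # Assumed behavior: an explicit override applies to every layer.
        return [override] * num_layers
    # Default: every layer reshards after forward except the last one,
    # which would be gathered again immediately.
    return [index < num_layers - 1 for index in range(num_layers)]


default_flags = reshard_flags(4)
forced_flags = reshard_flags(3, override=False)
```

For four layers the default flags are True, True, True, False: only the final layer keeps its parameters gathered after the forward pass.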

Parameters:

module
nn.Module

The module to apply FSDP sharding to.

mesh
DeviceMesh

The device mesh for FSDP sharding.

mp_policy
Optional[MixedPrecisionPolicy]

Mixed precision policy for FSDP.

offload_policy
Optional[OffloadPolicy], defaults to None

CPU offload policy for FSDP. Defaults to None.

enable_fsdp2_prefetch
bool, defaults to True

Enable explicit forward/backward prefetch chains.

fsdp2_backward_prefetch_depth
int, defaults to 2

Backward prefetch depth.

fsdp2_forward_prefetch_depth
int, defaults to 1

Forward prefetch depth.

reshard_after_forward
Optional[bool], defaults to None

Optional override for each layer’s fully_shard reshard behavior.

nemo_automodel.components.distributed.parallelizer.apply_selective_activation_checkpointing(
model: torch.nn.Module,
enable_compile: bool = False
) -> None

Apply selective activation checkpointing to model end to end.

Standalone entry point (detects KV-sharing, disables use_cache, and wraps transformer blocks) for paths where the FSDP2 parallelize flow is skipped — notably single-GPU training.

Parameters:

model
nn.Module

The model to checkpoint.

enable_compile
bool, defaults to False

Whether per-layer torch.compile will be applied.

nemo_automodel.components.distributed.parallelizer.fsdp2_strategy_parallelize(
model,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy] = None,
offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
tp_shard_plan: typing.Optional[typing.Union[typing.Dict[str, torch.distributed.tensor.parallel.ParallelStyle], str]] = None,
dp_replicate_mesh_name: str = 'dp_replicate',
dp_shard_cp_mesh_name: str = 'dp_shard_cp',
tp_mesh_name: str = 'tp',
enable_async_tensor_parallel: bool = False,
enable_compile: bool = False,
enable_fsdp2_prefetch: bool = True,
fsdp2_backward_prefetch_depth: int = 2,
fsdp2_forward_prefetch_depth: int = 1,
reshard_after_forward: typing.Optional[bool] = None
)

Apply parallelisms and activation checkpointing to the model.

Enhanced version that uses a strategy pattern for different model parallelization approaches:

  • Automatic strategy selection based on model type
  • Polymorphic parallelization strategies for different model families
  • Custom parallel plan support (dict or string path)
  • Sequence parallel support
  • Activation checkpointing for linear layers
  • Model validation (attention heads divisible by TP size)
  • Better fallback logic

NOTE: The passed-in model should preferably be on the meta device. Otherwise, the model must fit in GPU or CPU memory.

Parameters:

model

The model to be parallelized.

device_mesh
DeviceMesh

The device mesh for distributed training.

mp_policy
Optional[MixedPrecisionPolicy], defaults to None

Mixed precision policy for model parallelism.

offload_policy
Optional[OffloadPolicy], defaults to None

The offload policy for FSDP.

sequence_parallel
bool, defaults to False

Whether to use sequence parallelism. Defaults to False.

activation_checkpointing
bool, defaults to False

Whether to use activation checkpointing. Defaults to False.

tp_shard_plan
Optional[Union[Dict[str, ParallelStyle], str]], defaults to None

Custom tensor parallel plan for the model. Can be:

  • A dictionary mapping module names to parallel styles
  • A string path to a dictionary or function that returns a dictionary

If provided, this takes precedence over automatic plan generation.

dp_replicate_mesh_name
str, defaults to 'dp_replicate'

Key name for the data parallel replicate mesh in device_mesh. Used when data parallel replicate is enabled. Defaults to “dp_replicate”.

dp_shard_cp_mesh_name
str, defaults to 'dp_shard_cp'

Key name for the data parallel shard + context parallel mesh in device_mesh. Used when data parallel shard is enabled. Defaults to “dp_shard_cp”.

tp_mesh_name
str, defaults to 'tp'

Key name for the tensor parallel mesh in device_mesh. Defaults to “tp”.

Returns:

The parallelized model.

nemo_automodel.components.distributed.parallelizer.get_hf_tp_shard_plan(
model
)

Get the Hugging Face tensor parallel plan from the model.

This function:

  • Retrieves TP strategies from model class, instance, and inner model levels.
  • Handles special cases for embed_tokens and lm_head for speedup.
  • Converts string-based parallel styles to DTensor parallelization strategies.

Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532

Parameters:

model

A Hugging Face model instance

Returns:

A dictionary mapping model component paths to their parallelization strategies

Raises:

  • AssertionError: If no TP plan is found

nemo_automodel.components.distributed.parallelizer.get_parallelization_strategy(
model: torch.nn.Module
) -> nemo_automodel.components.distributed.parallelizer.ParallelizationStrategy

Get the appropriate parallelization strategy for the given model.

nemo_automodel.components.distributed.parallelizer.import_class_from_path(
name: str
) -> typing.Any

Import a class from a string path (e.g. ‘torch.optim.AdamW’).

Parameters:

name

Full path to class including module path and class name

Returns: Any

The imported class object
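Resolving a dotted path to a class can be sketched with the standard library. `import_by_path` is a hypothetical re-creation of the idea, not the library's function, and its error handling is an assumption.

```python
import importlib
from typing import Any


def import_by_path(path: str) -> Any:
    """Import an attribute from a dotted path such as 'package.module.Name'."""
    # Split 'collections.OrderedDict' into module path and attribute name.
    module_path, _, attr_name = path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected a dotted path, got {path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


# A standard-library class keeps the example free of extra dependencies.
ordered_dict_cls = import_by_path("collections.OrderedDict")
instance = ordered_dict_cls(a=1)
```

The same call shape applies to a path like 'torch.optim.AdamW' when that package is installed.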

nemo_automodel.components.distributed.parallelizer.import_classes_from_paths(
class_paths: typing.List[str]
)

Helper function to import classes from string paths.

Parameters:

class_paths
List[str]

The list of string paths to the classes.

Returns:

List of imported classes.

nemo_automodel.components.distributed.parallelizer.megatron_fsdp_strategy_parallelize(
model,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
optimizer = None,
megatron_fsdp_unit_modules: typing.Optional[typing.List[str]] = None,
tp_shard_plan: typing.Optional[typing.Dict[str, typing.Union[torch.distributed.tensor.parallel.RowwiseParallel, torch.distributed.tensor.parallel.ColwiseParallel, torch.distributed.tensor.parallel.SequenceParallel]]] = None,
zero_dp_strategy: int = 3,
init_fsdp_with_meta_device: bool = False,
grad_reduce_in_fp32: bool = False,
preserve_fp32_weights: bool = False,
overlap_grad_reduce: bool = True,
overlap_param_gather: bool = True,
check_for_nan_in_grad: bool = True,
average_in_collective: bool = False,
disable_bucketing: bool = False,
calculate_per_token_loss: bool = False,
keep_fp8_transpose_cache: bool = False,
nccl_ub: bool = False,
fsdp_double_buffer: bool = False,
dp_shard_dim: str = 'dp',
tp_dim: str = 'tp'
)

Apply tensor/data parallelism (MegatronFSDP) and optional activation-checkpointing to the model.

NOTE: The passed-in model should preferably reside on the meta device. Otherwise, ensure the model fits into available GPU or CPU memory.

NOTE: The user must ensure that the provided tp_shard_plan is compatible with the model architecture.

Parameters:

model

The model to be parallelized.

device_mesh
DeviceMesh

The device mesh describing the physical devices used for distributed training.

megatron_fsdp_unit_modules
Optional[List[str]], defaults to None

Names of sub-modules that should become individual MegatronFSDP units. If None, the full model is wrapped as a single unit.

tp_shard_plan
Optional[Dict[str, Union[RowwiseParallel, ColwiseParallel, SequenceParallel]]], defaults to None

A tensor-parallel sharding plan. Keys are module names; values specify the parallel style to apply (e.g., RowwiseParallel, ColwiseParallel, SequenceParallel).

zero_dp_strategy
int, defaults to 3

The zero-DP strategy to use.

init_fsdp_with_meta_device
bool, defaults to False

If True, construct the model on a meta device first and materialize weights lazily to reduce memory fragmentation.

grad_reduce_in_fp32
bool, defaults to False

Reduce gradients in FP32 irrespective of the parameter precision to improve numerical stability.

preserve_fp32_weights
bool, defaults to False

Keep a master FP32 copy of weights when training in reduced precision (e.g., FP16/BF16).

overlap_grad_reduce
bool, defaults to True

If True, overlap gradient reduction with backward computation.

overlap_param_gather
bool, defaults to True

If True, overlap parameter gathering with forward computation.

check_for_nan_in_grad
bool, defaults to True

Whether to check gradients for NaNs/Infs before applying the optimizer step.

average_in_collective
bool, defaults to False

Perform gradient averaging inside the collective operation instead of dividing afterward.

disable_bucketing
bool, defaults to False

Disable gradient bucketing; gradients are reduced immediately as they are produced.

calculate_per_token_loss
bool, defaults to False

Compute loss normalized by the number of tokens instead of the number of sequences.

keep_fp8_transpose_cache
bool, defaults to False

Retain the FP8 transpose cache when using a custom MegatronFSDP wrapper.

nccl_ub
bool, defaults to False

Enable NCCL user-buffer API (experimental) for reduced latency on some networks.

fsdp_double_buffer
bool, defaults to False

Enable double buffering of parameters to overlap communication and computation in MegatronFSDP.

dp_shard_dim
str, defaults to 'dp'

Key name for the data parallel mesh in device_mesh. Defaults to “dp”.

tp_dim
str, defaults to 'tp'

Key name for the tensor parallel mesh in device_mesh. Defaults to “tp”.

nemo_automodel.components.distributed.parallelizer.register_parallel_strategy(
arg = None,
name: typing.Optional[str] = None
)

Decorator to register out-of-tree parallelism strategies.

Supports:

  • @register_parallel_strategy(name="CustomModelName")

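The registration pattern can be sketched with a plain dictionary. Everything here is hypothetical and only mirrors the decorator shape shown in the bullet: `STRATEGY_REGISTRY`, `register_strategy`, and `CustomStrategy` are illustrative names, and storing an instance rather than the class is an assumption.

```python
from typing import Callable, Dict, Optional

# Hypothetical stand-in for a strategy table keyed by model class name.
STRATEGY_REGISTRY: Dict[str, object] = {}


def register_strategy(name: Optional[str] = None) -> Callable[[type], type]:
    """Return a class decorator that stores an instance under `name`."""

    def decorator(cls: type) -> type:
        # Fall back to the class name when no explicit name is given.
        STRATEGY_REGISTRY[name or cls.__name__] = cls()
        return cls  # hand the class back unchanged

    return decorator


@register_strategy(name="CustomModelName")
class CustomStrategy:
    """Hypothetical out-of-tree strategy that leaves the model untouched."""

    def parallelize(self, model, device_mesh, **kwargs):
        return model


registered = STRATEGY_REGISTRY["CustomModelName"]
```

After the decorated class is defined, a lookup by model name returns the strategy object, which is the point of registering out-of-tree strategies.
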
nemo_automodel.components.distributed.parallelizer.translate_to_torch_parallel_style(
style: str
)

Translates string descriptions to parallelism plans.

In model configurations, we use a neutral type (string) to specify parallel styles; here we translate them into torch.distributed tensor-parallel types.

nemo_automodel.components.distributed.parallelizer.unshard_fsdp2_model(
model: torch.nn.Module
) -> typing.Generator[None, None, None]

Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference.
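The Generator return type suggests a context manager. The sketch below shows that pattern with duck-typed stand-ins so it runs without a distributed setup. `FakeShardedModule` and `unshard_then_reshard` are hypothetical, and the try/finally ordering is an assumption about how such a helper would usually be written.

```python
from contextlib import contextmanager
from typing import Generator, Iterable, List


class FakeShardedModule:
    """Hypothetical stand-in exposing unshard() and reshard() methods."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def unshard(self) -> None:
        self.log.append(f"unshard {self.name}")

    def reshard(self) -> None:
        self.log.append(f"reshard {self.name}")


@contextmanager
def unshard_then_reshard(
    modules: Iterable[FakeShardedModule],
) -> Generator[None, None, None]:
    """Gather full parameters for the body, then release them afterwards."""
    module_list = list(modules)
    for module in module_list:
        module.unshard()
    try:
        yield
    finally:
        # Runs even if the body raises, so parameters are always resharded.
        for module in module_list:
            module.reshard()


events: List[str] = []
layers = [FakeShardedModule("layer0", events), FakeShardedModule("layer1", events)]
with unshard_then_reshard(layers):
    events.append("forward")
```

The recorded order is unshard for both layers, then the forward step, then reshard for both layers, which matches the "unshard and then reshard" description.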

nemo_automodel.components.distributed.parallelizer.validate_tp_mesh(
model,
tp_mesh
)

Validate that attention heads and key value heads are divisible by TP size.
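A divisibility check of this kind can be sketched as follows. `check_heads_divisible` is hypothetical: it takes a config object and an integer TP size rather than a model and a mesh, the attribute names follow the common Hugging Face config convention, and raising ValueError is an assumption.

```python
from types import SimpleNamespace


def check_heads_divisible(config, tp_size: int) -> None:
    """Raise if either head count cannot be split evenly across TP ranks."""
    for attr in ("num_attention_heads", "num_key_value_heads"):
        count = getattr(config, attr)
        if count % tp_size != 0:
            raise ValueError(
                f"{attr}={count} is not divisible by TP size {tp_size}"
            )


# Illustrative config values only.
config = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)

check_heads_divisible(config, tp_size=8)  # both counts divide evenly

try:
    check_heads_divisible(config, tp_size=16)  # 8 KV heads cannot split 16 ways
    rejected = False
except ValueError:
    rejected = True
```

The key/value head count is usually the tighter constraint for grouped-query attention models, since it is smaller than the query head count.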

nemo_automodel.components.distributed.parallelizer.validate_tp_mesh_for_nemotron_nas(
model,
tp_size
)

Validate that a Nemotron-NAS model can be tensor-parallel sharded.

nemo_automodel.components.distributed.parallelizer.HAVE_MEGATRON_FSDP = True
nemo_automodel.components.distributed.parallelizer.PARALLELIZATION_STRATEGIES: Dict[str, ParallelizationStrategy] = {'NemotronHForCausalLM': NemotronHParallelizationStrategy(), 'DeepseekV4ForCausa...
nemo_automodel.components.distributed.parallelizer._BAGEL_FULL_LAYER_CHECKPOINT_MODULE_LISTS = ('model.language_model.model.layers', 'model.vit_model.vision_model.encoder.laye...
nemo_automodel.components.distributed.parallelizer._DEFAULT_STRATEGY = DefaultParallelizationStrategy()
nemo_automodel.components.distributed.parallelizer.logger = logging.getLogger(__name__)