nemo_automodel.components.distributed.parallelizer
Module Contents
Classes
Functions
Data
_BAGEL_FULL_LAYER_CHECKPOINT_MODULE_LISTS
API
Bases: DefaultParallelizationStrategy
DeepSeek-V4 keeps a small set of reference-sensitive parameters in fp32.
Bases: ParallelizationStrategy
Default parallelization strategy used by most models.
Apply the default parallelization flow.
Placeholder when the installed transformers build has no Gemma4.
Bases: ParallelizationStrategy
Parallelization strategy for Hunyuan-style transformer modules used in HunyuanVideo.
Bases: ParallelizationStrategy
Specialized parallelization strategy for NemotronH models.
Apply NemotronH-specific parallelization.
Abstract base class for model parallelization strategies.
Apply parallelization strategy to the model.
Bases: DefaultParallelizationStrategy
Parallelization strategy for Qwen3.5 dense models with mixed-dtype GatedDeltaNet.
Qwen3.5 has linear_attn layers with float32 params (A_log, norm) alongside bfloat16 params. This strategy overrides the FSDP sharding step to use fully_shard_by_dtype per layer and sets the CP mesh on CPAwareGatedDeltaNet modules.
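The per-dtype sharding step starts by partitioning a layer's parameters by dtype, so that each group can be handed to its own sharding call. The sketch below shows only that grouping step. `group_params_by_dtype` is a hypothetical helper, not the library's `fully_shard_by_dtype`, and it assumes only that each parameter exposes a `.dtype` attribute, as `torch.nn.Parameter` does.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Tuple


def group_params_by_dtype(
    named_params: Iterable[Tuple[str, Any]],
) -> Dict[Any, List[str]]:
    """Group parameter names by dtype (hypothetical helper).

    Accepts any iterable of (name, param) pairs, e.g. module.named_parameters(),
    where each param has a ``.dtype`` attribute. Order within a group follows
    the input order.
    """
    groups: Dict[Any, List[str]] = defaultdict(list)
    for name, param in named_params:
        groups[param.dtype].append(name)
    return dict(groups)
```

With a real GatedDeltaNet layer, `group_params_by_dtype(layer.named_parameters())` would be expected to return one key per dtype, for example `torch.float32` for `A_log` and the norm weight and `torch.bfloat16` for the projections.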
Bases: ParallelizationStrategy
Parallelization strategy for Wan-style transformer modules used in Diffusers.
Applies TP to condition embedders, FFN projections in each block, and final projection, then applies FSDP sharding similarly to other strategies.
Apply native BAGEL-style activation checkpointing to whole logical layers.
Compile each decoder layer in-place after FSDP2 sharding.
Compiles at decoder-layer granularity (not sub-module) so that AOT autograd traces the joint fwd+bwd graph under the training loop’s enable_grad context. Sub-module compile (e.g. on mlp alone) would be traced during activation checkpointing’s first forward pass which runs under no_grad, producing a forward-only graph that drops LoRA and other trainable-parameter gradients.
Prerequisite: NO_REENTRANT checkpoint_wrapper must already be applied to self_attn and mlp before FSDP2 sharding (done in DefaultParallelizationStrategy). This function only handles the compile step.
Whole-block selective-AC wrappers (tagged with SELECTIVE_AC_WRAPPER_FLAG)
are compiled from the outside: the wrapper itself is compiled, so the selective
policy is traced and the partitioner honors its recompute tags. Other layer-level
CheckpointWrappers (e.g. the PP path) are unwrapped, and the decoder layer is
compiled directly.
nn.Module.compile() is used instead of torch.compile() to compile in-place without introducing an _orig_mod wrapper, which would add a key prefix and break checkpoint loading.
_patch_dtensor_spec_hash_for_symint() is called to allow torch.compile with dynamic shapes to coexist with DTensor’s lru_cache-based sharding propagation.
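The choice of what to compile for each decoder layer can be sketched as a small decision function. This is a minimal sketch under stated assumptions. `SELECTIVE_AC_FLAG_ATTR` is a hypothetical stand-in for however SELECTIVE_AC_WRAPPER_FLAG is actually attached. `_checkpoint_wrapped_module` is the attribute under which PyTorch's `CheckpointWrapper` stores its inner module. The sketch does not remove the wrapper from its parent container, and whether the real code does so is not stated above.

```python
from typing import Any, Iterable

# Hypothetical attribute name standing in for SELECTIVE_AC_WRAPPER_FLAG.
SELECTIVE_AC_FLAG_ATTR = "_is_selective_ac_wrapper"
# Attribute under which torch's CheckpointWrapper keeps the wrapped module.
WRAPPED_ATTR = "_checkpoint_wrapped_module"


def compile_target(layer: Any) -> Any:
    """Return the module that should be compiled for one decoder layer."""
    if getattr(layer, SELECTIVE_AC_FLAG_ATTR, False):
        # Compile the selective-AC wrapper itself so its policy is traced.
        return layer
    inner = getattr(layer, WRAPPED_ATTR, None)
    if inner is not None:
        # Other layer-level checkpoint wrappers: compile the inner layer.
        return inner
    return layer


def compile_layers(layers: Iterable[Any]) -> None:
    for layer in layers:
        # nn.Module.compile() compiles in place and adds no _orig_mod
        # wrapper, so state_dict keys keep their original names.
        compile_target(layer).compile()
```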
Return True when the TP plan column-wise shards any QKV attention projection.
When Q/K/V projections use ColwiseParallel with sharded output (the
default), each TP rank holds num_heads / tp_size heads and the model
config / layer attributes must be updated accordingly.
Plans that keep attention replicated (e.g. Phi-3 with RowwiseParallel
on fused QKV and Replicate output) should not trigger a head-count
update.
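A minimal sketch of this check over a plan written with HF-style string styles follows. It assumes, hypothetically, that Q/K/V projections end in `q_proj`, `k_proj`, `v_proj` or `qkv_proj`, and that `"colwise"` means column-wise sharding with sharded output while `"colwise_rep"` replicates the output. The real function inspects the translated parallel-style objects, and its exact rules may differ.

```python
from typing import Mapping

# Assumed projection names; real models may use others.
QKV_SUFFIXES = ("q_proj", "k_proj", "v_proj", "qkv_proj")


def plan_shards_qkv_heads(plan: Mapping[str, str]) -> bool:
    """True if any Q/K/V projection is column-sharded with sharded output."""
    for fqn, style in plan.items():
        leaf = fqn.rsplit(".", 1)[-1]
        if leaf in QKV_SUFFIXES and style == "colwise":
            return True
    return False
```

A Phi-3-style plan that applies `"rowwise"` to a fused `qkv_proj` returns False, so head counts are left untouched.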
Extract layers from different model architectures for parallelization.
This function handles various model types including vision-language models, causal language models, and multimodal models. It collects both language model layers and vision model layers where applicable.
Parameters:
The model to extract layers from.
Returns: List[nn.Module]
List[nn.Module]: A list of all layers that should be parallelized.
Heuristic function to find the largest layer container in a model.
This function recursively traverses the model to find all nn.ModuleList and pipeline-split nn.ModuleDict instances and returns the one with the most modules. This is useful as a fallback when the model architecture is unknown, since transformer layers are typically organized in ModuleLists. Pipeline splitting converts ModuleLists to ModuleDicts keyed by original layer index.
Parameters:
The model to search through.
Returns: Optional[Union[nn.ModuleList, nn.ModuleDict]]
Optional[Union[nn.ModuleList, nn.ModuleDict]]: The largest layer container found, or None.
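The heuristic can be sketched with `nn.Module.modules()`, which already walks the module tree recursively. This sketch requires PyTorch. The check that recognises a pipeline-split `nn.ModuleDict` by its all-digit keys is an assumption based on the "keyed by original layer index" description, not necessarily the library's exact test.

```python
from typing import Optional, Union

import torch.nn as nn


def _is_layer_container(mod: nn.Module) -> bool:
    if isinstance(mod, nn.ModuleList):
        return True
    # Assumption: a pipeline-split ModuleDict has integer-string keys.
    return (
        isinstance(mod, nn.ModuleDict)
        and len(mod) > 0
        and all(k.isdigit() for k in mod.keys())
    )


def find_largest_container(
    model: nn.Module,
) -> Optional[Union[nn.ModuleList, nn.ModuleDict]]:
    """Return the layer container with the most children, or None."""
    best = None
    for mod in model.modules():  # recursive traversal, includes model itself
        if _is_layer_container(mod) and (best is None or len(mod) > len(best)):
            best = mod
    return best
```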
Select the tensor-parallel plan for the given model.
Priority order:
1. If tp_shard_plan is provided as a dict or import path, use it.
2. If the model type exists in PARALLELIZE_FUNCTIONS, use its optimised plan; on failure, fall back to the HF plan.
3. Otherwise, prefer the model’s HF-native _tp_plan (via get_hf_tp_shard_plan).
4. Otherwise, fall back to the default base plan.
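The priority order above can be sketched as a chain of fallbacks. Everything here is illustrative. `select_tp_plan` and its arguments are hypothetical, the custom plan is taken as an already-resolved dict rather than an import path, and `hf_plan` is modeled as a callable that raises when the model has no usable `_tp_plan`. The custom-code `ValueError` case is left out.

```python
from typing import Callable, Dict, Optional

Plan = Dict[str, str]


def select_tp_plan(
    model_type: str,
    custom_plan: Optional[Plan],
    optimized_plans: Dict[str, Callable[[], Plan]],
    hf_plan: Callable[[], Plan],
    base_plan: Plan,
) -> Plan:
    # 1. An explicitly provided plan always wins.
    if custom_plan is not None:
        return custom_plan
    # 2. Model-specific optimised plan; on failure, fall through to the HF plan.
    if model_type in optimized_plans:
        try:
            return optimized_plans[model_type]()
        except Exception:
            pass
    # 3. HF-native _tp_plan, if the model provides a usable one.
    try:
        return hf_plan()
    except Exception:
        pass
    # 4. Default base plan.
    return base_plan
```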
When tp_size > 1, the model falls through to path 4, and the model
class was loaded from a custom-code source (HF’s
trust_remote_code=True path, where the dynamic class lives under
transformers_modules.*), this raises ValueError instead of
returning the default base plan. On recent PyTorch the default plan’s
placements do not populate shard_order and trip an internal assert in
torch.distributed.tensor._redistribute on the first weight
redistribute, which surfaces to the user as an opaque PyTorch internal
error. Custom-code architectures are the only known-broken case (see
https://github.com/NVIDIA-NeMo/Automodel/issues/2243); known HF
architectures that happen to fall through (e.g. Mixtral) are left on the
default plan with a warning, since they have been working in practice.
When the model did define a _tp_plan but get_hf_tp_shard_plan
raised while translating it (e.g. styles nemo does not recognize), the
translator’s error message is folded into the ValueError as a
diagnostic so the user can tell whether to add a _tp_plan from
scratch or fix the styles in the one they already have.
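The custom-code guard can be sketched by checking the class's `__module__` prefix. This is a minimal sketch. `guard_default_plan` is hypothetical, and the error and warning wording is illustrative rather than the library's. The `transformers_modules.` prefix and the folding of the translator's error into the `ValueError` follow the description above.

```python
import warnings
from typing import Dict, Optional


def guard_default_plan(
    model_cls: type,
    tp_size: int,
    base_plan: Dict[str, str],
    hf_plan_error: Optional[Exception] = None,
) -> Dict[str, str]:
    """Return base_plan, or raise for custom-code models when tp_size > 1."""
    is_custom_code = model_cls.__module__.startswith("transformers_modules.")
    if tp_size > 1 and is_custom_code:
        msg = (
            f"No usable TP plan for custom-code model {model_cls.__name__}; "
            "define a _tp_plan or pass tp_shard_plan explicitly."
        )
        if hf_plan_error is not None:
            # Fold the translator's error in as a diagnostic.
            msg += f" Translating the existing _tp_plan failed: {hf_plan_error}"
        raise ValueError(msg)
    if tp_size > 1:
        # Known HF architectures stay on the default plan, with a warning.
        warnings.warn(f"{model_cls.__name__} falls back to the default TP plan")
    return base_plan
```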
Check if transformers version is 5.x or higher.
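A version gate like this reduces to comparing the major component of a version string. The sketch below is hypothetical: `is_major_at_least` is not the library's function, and a real implementation might use `packaging.version` instead. It would be called as `is_major_at_least(transformers.__version__)`.

```python
def is_major_at_least(version: str, major: int = 5) -> bool:
    """True if a version string such as '5.0.0.dev0' has major >= `major`."""
    leading = version.split(".", 1)[0]
    return int(leading) >= major
```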
Return (container, blocks) for a NemotronH model’s decoder blocks.
Two distinct classes share the name NemotronHForCausalLM:
- the HF model keeps its blocks in model.backbone.layers (an nn.ModuleList), while
- the native Nemotron-V3 model (NemotronV3Model) keeps them in model.model.layers (an nn.ModuleDict keyed "0".."N-1").
container is the underlying ModuleList/ModuleDict (so callers can write rewrapped
blocks back into the model), and blocks is the ordered list of block modules.
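A minimal sketch of resolving both layouts follows. It assumes, hypothetically, that the two classes can be told apart by the presence of a `backbone` attribute. It also sorts the `ModuleDict` keys numerically so that `"10"` comes after `"9"`. `get_nemotron_blocks` is an illustrative name. It works on real modules and on any list-like or dict-like stand-ins.

```python
from typing import Any, List, Tuple


def get_nemotron_blocks(model: Any) -> Tuple[Any, List[Any]]:
    """Return (container, ordered blocks) for either NemotronH layout."""
    backbone = getattr(model, "backbone", None)
    if backbone is not None:
        # HF layout: model.backbone.layers is a ModuleList (list-like).
        container = backbone.layers
        return container, list(container)
    # Native Nemotron-V3 layout: model.model.layers is keyed "0".."N-1".
    container = model.model.layers
    blocks = [container[k] for k in sorted(container.keys(), key=int)]
    return container, blocks
```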
Fix a crash when torch.compile + DTensor are used together.
Problem: torch.compile traces with symbolic shapes (SymInt). DTensorSpec hashes its shape to cache sharding decisions, but SymInt is not hashable -> crash.
Fix: if hashing the shape fails, fall back to hashing only (mesh, placements). Cache hits are slightly reduced but correctness is unaffected.
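The fallback can be illustrated with a pure-Python stand-in for the spec. `SpecKey` is hypothetical, and the real patch targets DTensorSpec's hash instead. An unhashable class plays the role of a symbolic dimension. `__eq__` still compares all fields, so two equal specs always take the same hashing path and hash consistently.

```python
class SpecKey:
    """Stand-in for a sharding spec whose shape may hold unhashable dims."""

    def __init__(self, mesh, placements, shape):
        self.mesh = mesh
        self.placements = placements
        self.shape = shape

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SpecKey):
            return NotImplemented
        return (self.mesh, self.placements, self.shape) == (
            other.mesh, other.placements, other.shape)

    def __hash__(self) -> int:
        try:
            return hash((self.mesh, self.placements, self.shape))
        except TypeError:
            # Symbolic dims are unhashable: hash only (mesh, placements).
            return hash((self.mesh, self.placements))
```

Specs that differ only in symbolic shape now share a hash bucket. That can reduce cache hits, but equality still separates them, so correctness is kept.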
Return True if module owns parameters and none of them require grad.
Used to skip FSDP-wrapping a frozen submodule that never runs in the forward
(e.g. the audio tower on image/text-only data); see
apply_fsdp2_sharding_recursively.
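The predicate is small enough to sketch directly. This version is duck-typed and works with any object whose `parameters()` yields items with a `requires_grad` flag, including `nn.Module`. It assumes that "owns parameters" counts parameters of submodules too, which is what `nn.Module.parameters()` returns by default. `is_frozen` is an illustrative name.

```python
from typing import Any


def is_frozen(module: Any) -> bool:
    """True if `module` has parameters and none of them require grad."""
    params = list(module.parameters())
    return len(params) > 0 and not any(p.requires_grad for p in params)
```

Requiring at least one parameter means a parameter-free module (for example an activation) is not classed as frozen.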
After TP sharding, the Q/K/V outputs are split across ranks (each rank has num_heads/tp_size heads). Update the config and each attention layer’s num_heads / num_key_value_heads so the forward uses the local head count instead of the global one (avoids shape mismatches in .view()).
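The update amounts to integer division of each head count by `tp_size`. In this minimal sketch, `localize_head_counts` is hypothetical. The attribute names are assumptions: `num_attention_heads` is the usual HF config field, and `num_heads` and `num_key_value_heads` are the layer fields named above. Divisibility is assumed to have been validated beforehand.

```python
from typing import Any, Iterable


def localize_head_counts(config: Any, attn_layers: Iterable[Any], tp_size: int) -> None:
    """Replace global head counts with per-rank counts, in place."""
    for obj in (config, *attn_layers):
        for attr in ("num_heads", "num_attention_heads", "num_key_value_heads"):
            if hasattr(obj, attr):
                setattr(obj, attr, getattr(obj, attr) // tp_size)
```

For example, with 32 heads of size 128 and tp_size=4, each rank's q_proj emits 4096 / 4 = 1024 features. Those reshape to (..., 8, 128) only when the layer uses the local count of 8.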
Recursively apply FSDP2 sharding to modules, with optimizations for ModuleList.
This utility function traverses a model hierarchy and applies FSDP2 sharding to each module. For ModuleList instances (commonly used for transformer layers), it applies an optimization where the last layer doesn’t reshard after forward since FSDP2 will prefetch it immediately.
Handles both single-level and nested ModuleList/ModuleDict structures. If a ModuleList contains other ModuleLists, it will recurse into them instead of trying to wrap them (since ModuleList doesn’t have a forward method).
Note: This function modifies the module in-place by replacing modules with their FSDP2-subclassed versions.
Parameters:
The module to apply FSDP sharding to.
The device mesh for FSDP sharding.
Mixed precision policy for FSDP.
CPU offload policy for FSDP. Defaults to None.
Enable explicit forward/backward prefetch chains.
Backward prefetch depth.
Forward prefetch depth.
Optional override for each layer’s
fully_shard reshard behavior.
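The traversal rules can be sketched without PyTorch by modeling containers as nested Python lists and layers as names. The sketch computes, for each leaf layer, the `reshard_after_forward` value that would be passed to `fully_shard`. Nested containers are recursed into rather than wrapped, and the last layer of each container is not resharded. `plan_fsdp_wraps` is hypothetical, and it omits the prefetch and override options.

```python
from typing import List, Tuple, Union

Tree = Union[str, List["Tree"]]


def plan_fsdp_wraps(layers: List[Tree], prefix: str = "") -> List[Tuple[str, bool]]:
    """Return (path, reshard_after_forward) for every leaf layer."""
    plan: List[Tuple[str, bool]] = []
    last = len(layers) - 1
    for i, child in enumerate(layers):
        path = f"{prefix}.{i}" if prefix else str(i)
        if isinstance(child, list):
            # A container has no forward of its own: recurse, don't wrap.
            plan.extend(plan_fsdp_wraps(child, path))
        else:
            # The last layer stays unsharded after forward.
            plan.append((path, i != last))
    return plan
```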
Apply selective activation checkpointing to model end to end.
Standalone entry point (detects KV-sharing, disables use_cache, and
wraps transformer blocks) for paths where the FSDP2 parallelize flow is
skipped — notably single-GPU training.
Parameters:
The model to checkpoint.
Whether per-layer torch.compile will be applied.
Apply parallelisms and activation checkpointing to the model.
Enhanced version that uses a strategy pattern for different model parallelization approaches:
- Automatic strategy selection based on model type
- Polymorphic parallelization strategies for different model families
- Custom parallel plan support (dict or string path)
- Sequence parallel support
- Activation checkpointing for linear layers
- Model validation (attention heads divisible by TP size)
- Better fallback logic
NOTE: The passed-in model should preferably be on the meta device. Otherwise, the model must fit in GPU or CPU memory.
Parameters:
The model to be parallelized.
The device mesh for distributed training.
Mixed precision policy for model parallelism.
The offload policy for FSDP.
Whether to use sequence parallelism. Defaults to False.
Whether to use activation checkpointing. Defaults to False.
Custom tensor parallel plan for the model. Can be:
- A dictionary mapping module names to parallel styles
- A string path to a dictionary, or to a function that returns a dictionary
If provided, this takes precedence over automatic plan generation.
Key name for the data parallel replicate mesh in device_mesh. Used when data parallel replicate is enabled. Defaults to “dp_replicate”.
Key name for the data parallel shard + context parallel mesh in device_mesh. Used when data parallel shard is enabled. Defaults to “dp_shard_cp”.
Key name for the tensor parallel mesh in device_mesh. Defaults to “tp”.
Returns:
The parallelized model.
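Resolving the tp_shard_plan argument in either of its two forms can be sketched with `importlib`. This is an illustrative sketch, not the library's loader. `resolve_tp_plan` is hypothetical, and it assumes a dotted path names either a dict or a zero-argument callable that returns one.

```python
import importlib
from typing import Any, Dict, Union


def resolve_tp_plan(plan: Union[Dict[str, Any], str]) -> Dict[str, Any]:
    """Resolve a TP plan given as a dict or as a dotted import path."""
    if isinstance(plan, dict):
        return plan
    module_path, _, attr = plan.rpartition(".")
    obj = getattr(importlib.import_module(module_path), attr)
    if callable(obj):
        # The path points at a function that builds the plan.
        obj = obj()
    if not isinstance(obj, dict):
        raise TypeError(f"{plan!r} did not resolve to a dict")
    return obj
```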
Get the Hugging Face tensor parallel plan from the model.
This function:
- Retrieves TP strategies from model class, instance, and inner model levels.
- Handles special cases for embed_tokens and lm_head for speedup.
- Converts string-based parallel styles to DTensor parallelization strategies.
Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532
Parameters:
A Hugging Face model instance
Returns:
A dictionary mapping model component paths to their parallelization strategies
Raises:
AssertionError: If no TP plan is found
Get the appropriate parallelization strategy for the given model.
Import a class from a string path (e.g. ‘torch.optim.AdamW’).
Parameters:
Full path to class including module path and class name
Returns: Any
The imported class object
Helper function to import classes from string paths.
Parameters:
The list of string paths to the classes.
Returns:
List of imported classes.
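Both helpers reduce to splitting the path at its last dot and importing the module part. The function names below are illustrative, not the library's.

```python
import importlib
from typing import Any, List


def import_from_path(path: str) -> Any:
    """Import an object from a dotted path such as 'torch.optim.AdamW'."""
    module_path, _, name = path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected 'module.Name', got {path!r}")
    return getattr(importlib.import_module(module_path), name)


def import_all(paths: List[str]) -> List[Any]:
    """Import every object in `paths`, preserving order."""
    return [import_from_path(p) for p in paths]
```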
Apply tensor/data parallelism (MegatronFSDP) and optional activation-checkpointing to the model.
NOTE: The passed-in model should preferably reside on the meta device. Otherwise, ensure the model fits into available GPU or CPU memory.
NOTE: The user must ensure that the provided tp_shard_plan is compatible with the model architecture.
Parameters:
The model to be parallelized.
The device mesh describing the physical devices used for distributed training.
Names of sub-modules that should become individual MegatronFSDP units. If None, the full model is wrapped as a single unit.
A tensor-parallel sharding plan. Keys are module names; values specify the parallel style to apply (e.g., RowwiseParallel, ColwiseParallel, SequenceParallel).
The zero-DP strategy to use.
If True, construct the model on a meta device first and materialize weights lazily to reduce memory fragmentation.
Reduce gradients in FP32 irrespective of the parameter precision to improve numerical stability.
Keep a master FP32 copy of weights when training in reduced precision (e.g., FP16/BF16).
If True, overlap gradient reduction with backward computation.
If True, overlap parameter gathering with forward computation.
Whether to check gradients for NaNs/Infs before applying the optimizer step.
Perform gradient averaging inside the collective operation instead of dividing afterward.
Disable gradient bucketing; gradients are reduced immediately as they are produced.
Compute loss normalized by the number of tokens instead of the number of sequences.
Retain the FP8 transpose cache when using a custom MegatronFSDP wrapper.
Enable NCCL user-buffer API (experimental) for reduced latency on some networks.
Enable double buffering of parameters to overlap communication and computation in MegatronFSDP.
Key name for the data parallel mesh in device_mesh. Defaults to “dp”.
Key name for the tensor parallel mesh in device_mesh. Defaults to “tp”.
Decorator to register out-of-tree parallelism strategies.
Supports:
- @register_parallel_strategy(name="CustomModelName")
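A decorator-based registry with a default fallback can be sketched as follows. The names `register_strategy`, `get_strategy` and `DefaultStrategy` are hypothetical stand-ins. Keying the lookup on the model's class name is an assumption suggested by the `name="CustomModelName"` form, and the real lookup may differ.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Type


class ParallelizationStrategy(ABC):
    """Minimal stand-in for the abstract strategy base class."""

    @abstractmethod
    def parallelize(self, model: Any) -> Any: ...


class DefaultStrategy(ParallelizationStrategy):
    def parallelize(self, model: Any) -> Any:
        return model


_STRATEGIES: Dict[str, Type[ParallelizationStrategy]] = {}


def register_strategy(name: str) -> Callable[[type], type]:
    def decorator(cls: type) -> type:
        _STRATEGIES[name] = cls  # later registrations override earlier ones
        return cls
    return decorator


def get_strategy(model: Any) -> ParallelizationStrategy:
    # Look up by model class name; fall back to the default strategy.
    return _STRATEGIES.get(type(model).__name__, DefaultStrategy)()
```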
Translates string descriptions to parallelism plans.
In model configurations, we use a neutral type (string) to specify parallel styles, here we translate them into torch.distributed tensor-parallel types.
Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference.
Validate that the numbers of attention heads and key-value heads are divisible by the TP size.
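The check is plain modular arithmetic. This minimal sketch takes the two head counts directly rather than reading them from a model config, and the function name is illustrative.

```python
def validate_tp_heads(num_heads: int, num_kv_heads: int, tp_size: int) -> None:
    """Raise ValueError unless both head counts divide evenly by tp_size."""
    for label, count in (("attention heads", num_heads), ("key-value heads", num_kv_heads)):
        if count % tp_size != 0:
            raise ValueError(f"{count} {label} not divisible by tp_size={tp_size}")
```

For instance, 32 attention heads and 8 key-value heads pass with tp_size=4 but fail with tp_size=16, because 8 % 16 != 0.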
Validate that a Nemotron-NAS model can be tensor-parallel sharded.