nemo_automodel.components.distributed.parallelizer_utils

Module Contents

Functions

| Name | Description |
| --- | --- |
| _fully_shard | - |
| _get_module_from_path | - |
| _group_params_by_dtype | - |
| _make_compute_dtype_fn | Build the per-parameter compute dtype resolver used to group FSDP units. |
| _mp_policy_with_param_dtype | - |
| fully_shard_by_dtype | Fully shard a module so every parameter computes in its required dtype. |
| iter_maximal_uniform_dtype_subtrees | Traverse module and yield maximal submodules whose entire subtree has a unified dtype. |

Data

UniformSubtreeItem

API

nemo_automodel.components.distributed.parallelizer_utils._fully_shard(
module: torch.nn.Module,
mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy]
) -> None
nemo_automodel.components.distributed.parallelizer_utils._get_module_from_path(
layer: torch.nn.Module,
path: str
) -> torch.nn.Module
nemo_automodel.components.distributed.parallelizer_utils._group_params_by_dtype(
layer: torch.nn.Module,
dtype_of: typing.Optional[typing.Callable[[torch.Tensor], torch.dtype]] = None
) -> typing.Dict[torch.dtype, typing.List[torch.nn.Parameter]]
nemo_automodel.components.distributed.parallelizer_utils._make_compute_dtype_fn(
module: torch.nn.Module,
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
fp32_compute_module_names: typing.Tuple[str, ...]
) -> typing.Callable[[torch.Tensor], torch.dtype]

Build the per-parameter compute dtype resolver used to group FSDP units.

The compute dtype of a floating tensor is resolved by precedence:

  1. Pinned fp32: the tensor’s name matches fp32_compute_module_names (from the model’s _keep_in_fp32_modules_strict). This is authoritative and works even for from-scratch or quantized models, where there is no checkpoint to read.
  2. HF-recorded dtype: tensor._hf_compute_dtype, the checkpoint’s original dtype recorded at load time (see _restore_loaded_model_dtype). This lets any checkpoint-loaded model keep its intrinsically fp32 params in fp32 compute automatically, even after storage was upcast for fp32 master weights.
  3. Fallback: when the tensor carries no compute hint, the result depends on whether the module’s floating-point storage is uniform:
    • uniform storage: mp_policy.param_dtype (the requested mixed-precision compute dtype, typically bf16). This is the fp32-master-weights case: the uniform fp32 storage is artificially widened and should compute in the policy dtype. Falls back to the storage dtype when no policy is given.
    • mixed storage: the tensor’s own storage dtype. A param whose storage differs from its peers is intrinsically that dtype (not a master weight), so it must compute in it. Applying the policy here would force differently stored params into one compute dtype and re-introduce the mixed original dtype that stock FSDP2 rejects (_init_mp_dtypes).

Non-floating tensors always keep their storage dtype.
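
The precedence above can be sketched in plain Python. This is an illustrative model only, not the library's implementation: ToyTensor, its fields, and the string dtype labels are hypothetical stand-ins for torch tensors and torch.dtype, and the uniform-storage check is assumed to run once over all of the module's floating tensors.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for a parameter or buffer (not a torch.Tensor)."""
    name: str
    dtype: str                               # storage dtype label, e.g. "fp32"
    is_floating: bool = True
    hf_compute_dtype: Optional[str] = None   # mirrors the _hf_compute_dtype hint


def make_compute_dtype_fn(
    tensors: Sequence[ToyTensor],
    policy_dtype: Optional[str],
    fp32_names: Sequence[str] = (),
) -> Callable[[ToyTensor], str]:
    # Storage is "uniform" when all floating tensors share one storage dtype.
    floating_dtypes = {t.dtype for t in tensors if t.is_floating}
    uniform_storage = len(floating_dtypes) <= 1

    def compute_dtype(t: ToyTensor) -> str:
        if not t.is_floating:
            return t.dtype                   # non-floating: keep storage dtype
        if any(sub in t.name for sub in fp32_names):
            return "fp32"                    # 1. pinned fp32 (substring match)
        if t.hf_compute_dtype is not None:
            return t.hf_compute_dtype        # 2. dtype recorded at load time
        if uniform_storage and policy_dtype is not None:
            return policy_dtype              # 3a. uniform storage: policy dtype
        return t.dtype                       # 3b. mixed storage or no policy

    return compute_dtype


tensors = [
    ToyTensor("layers.0.attn.weight", "fp32", hf_compute_dtype="bf16"),
    ToyTensor("layers.0._fp32_params.A_log", "fp32"),
    ToyTensor("layers.0.norm.weight", "fp32"),
    ToyTensor("position_ids", "int64", is_floating=False),
]
resolve = make_compute_dtype_fn(tensors, "bf16", ("_fp32_params",))
resolved = {t.name: resolve(t) for t in tensors}

# Mixed floating storage with no hints: each tensor keeps its own dtype.
mixed = [ToyTensor("a.weight", "bf16"), ToyTensor("b.weight", "fp32")]
resolve_mixed = make_compute_dtype_fn(mixed, "bf16")
mixed_resolved = {t.name: resolve_mixed(t) for t in mixed}
```

In this toy run the pinned tensor resolves to fp32, the two other fp32-stored floating tensors resolve to bf16 (one through its recorded hint, one through the policy fallback), and the integer buffer keeps int64.
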

nemo_automodel.components.distributed.parallelizer_utils._mp_policy_with_param_dtype(
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
param_dtype: torch.dtype
) -> typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy]
nemo_automodel.components.distributed.parallelizer_utils.fully_shard_by_dtype(
module: torch.nn.Module,
mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy],
fp32_compute_module_names: typing.Tuple[str, ...] = ()
) -> None

Fully shard a module so every parameter computes in its required dtype.

The goal is to compute everything in mp_policy.param_dtype (e.g. bf16) except parameters that must stay in fp32: their FSDP unit gets param_dtype=fp32 while the rest of the module computes in the policy dtype. A parameter “must stay fp32” if it is pinned via fp32_compute_module_names or if the Hugging Face checkpoint stored it in fp32 (see _make_compute_dtype_fn for the full precedence). This decouples compute dtype from storage dtype, so with fp32 master weights (uniform fp32 storage) the bulk of the module still computes in bf16.

Implementation: group the module’s parameters by their resolved compute dtype and shard so each FSDP unit is compute-dtype-uniform. The three cases below differ only in sharding granularity:

  • 1 compute dtype -> shard the whole module once.
  • 2 compute dtypes -> shard the minority-dtype subtrees on their own, then shard the parent with the majority dtype (keeps the bulk as one FSDP unit).
  • 3+ compute dtypes -> shard every maximal compute-dtype-uniform subtree on its own.
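
A rough sketch of how the three cases choose sharding granularity is given below. It is a toy planner, not the library's code: plan_sharding is a hypothetical helper, dtypes are string labels, the input is assumed to be the already-computed maximal compute-dtype-uniform subtrees, and "majority" is assumed to mean the dtype covering the most subtrees (the real function may weigh this differently, for example by parameter count).

```python
from collections import Counter
from typing import Dict, List, Tuple


def plan_sharding(subtree_dtypes: Dict[str, str]) -> List[Tuple[str, str]]:
    """Return (module path, compute dtype) FSDP units in wrapping order.

    The empty path "" denotes the root module. `subtree_dtypes` maps each
    maximal compute-dtype-uniform subtree path to its compute dtype.
    """
    counts = Counter(subtree_dtypes.values())

    if len(counts) == 1:
        # One compute dtype: shard the whole module once.
        (dtype,) = counts
        return [("", dtype)]

    if len(counts) == 2:
        # Two compute dtypes: shard minority subtrees first, then the parent
        # with the majority dtype so the bulk stays a single FSDP unit.
        majority, _ = counts.most_common(1)[0]
        plan = [(p, d) for p, d in subtree_dtypes.items() if d != majority]
        plan.append(("", majority))
        return plan

    # Three or more compute dtypes: every uniform subtree is its own unit.
    return list(subtree_dtypes.items())


one = plan_sharding({"mlp": "bf16", "attn": "bf16"})
two = plan_sharding({"mlp": "bf16", "attn": "bf16", "gate._fp32_params": "fp32"})
three = plan_sharding({"mlp": "bf16", "gate": "fp32", "head": "fp16"})
```

Here `two` lists the fp32 subtree before the root entry, reflecting that nested units are wrapped before their parent.
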

Parameters:

fp32_compute_module_names
Tuple[str, ...], defaults to ()

Parameter/buffer name substrings that must compute in fp32 (e.g. ("_fp32_params",) for Qwen3.5’s GatedDeltaNet fp32 holder). Sourced from the model’s _keep_in_fp32_modules_strict.

nemo_automodel.components.distributed.parallelizer_utils.iter_maximal_uniform_dtype_subtrees(
module: torch.nn.Module,
include_buffers: bool = True,
tensor_pred: typing.Optional[typing.Callable[[torch.Tensor], bool]] = None,
dtype_of: typing.Optional[typing.Callable[[torch.Tensor], torch.dtype]] = None,
return_paths: bool = False
) -> typing.Iterator[nemo_automodel.components.distributed.parallelizer_utils.UniformSubtreeItem]

Traverse module and yield maximal submodules whose entire subtree has a unified dtype.

  • include_buffers: include buffers in dtype unification checks.
  • tensor_pred: predicate to choose which tensors to consider (default: all). Example: tensor_pred=torch.is_floating_point (to consider only FP tensors)
  • dtype_of: maps a tensor to the dtype used for unification (default: its storage dtype t.dtype). Pass a custom function to group by compute dtype rather than storage dtype.
  • return_paths: if True, yields (qualified_name, module, dtype); else (module, dtype).

Notes:

  • If a module subtree has no tensors passing tensor_pred, it is ignored.
  • Maximality ensures no yielded module is a strict child of another yielded module.
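
The traversal can be illustrated with a small pure-Python model. This is a sketch under stated assumptions, not the library's implementation: ToyModule and iter_maximal_uniform are hypothetical names, dtypes are string labels, and tensor_pred, dtype_of, and include_buffers are left out for brevity. It mirrors the return_paths=True form but yields only (path, dtype).

```python
from dataclasses import dataclass, field
from typing import Dict, Iterator, Set, Tuple


@dataclass
class ToyModule:
    """Hypothetical stand-in for nn.Module: direct tensors plus named children."""
    tensors: Dict[str, str] = field(default_factory=dict)      # name -> dtype
    children: Dict[str, "ToyModule"] = field(default_factory=dict)


def subtree_dtypes(m: ToyModule) -> Set[str]:
    """Collect every dtype that appears anywhere in the subtree rooted at m."""
    found = set(m.tensors.values())
    for child in m.children.values():
        found |= subtree_dtypes(child)
    return found


def iter_maximal_uniform(m: ToyModule, path: str = "") -> Iterator[Tuple[str, str]]:
    dtypes = subtree_dtypes(m)
    if not dtypes:
        return                       # no tensors in this subtree: ignored
    if len(dtypes) == 1:
        yield path, next(iter(dtypes))
        return                       # maximal: never descend below a match
    for name, child in m.children.items():
        child_path = f"{path}.{name}" if path else name
        yield from iter_maximal_uniform(child, child_path)


model = ToyModule(children={
    "embed": ToyModule(tensors={"weight": "bf16"}),
    "block": ToyModule(children={
        "attn": ToyModule(tensors={"weight": "bf16"}),
        "gate": ToyModule(tensors={"A_log": "fp32"}),
    }),
    "act": ToyModule(),              # holds no tensors, so it is skipped
})
result = list(iter_maximal_uniform(model))
# result == [("embed", "bf16"), ("block.attn", "bf16"), ("block.gate", "fp32")]
```

Because "block" mixes bf16 and fp32, it is not yielded itself; its two uniform children are yielded instead. How the real function treats tensors held directly by a mixed-dtype module is not covered by this sketch.
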
nemo_automodel.components.distributed.parallelizer_utils.UniformSubtreeItem = Union[Tuple[nn.Module, torch.dtype], Tuple[str, nn.Module, torch.dtype]]