nemo_automodel.components.distributed.parallelizer_utils
API
Build the per-parameter compute dtype resolver used to group FSDP units.
The compute dtype of a floating-point tensor is resolved by precedence:
- Pinned fp32: the tensor's name matches fp32_compute_module_names (from the model's _keep_in_fp32_modules_strict). This is authoritative and works even from scratch or with quantized models, where there is no checkpoint to read.
- HF-recorded dtype: tensor._hf_compute_dtype, the checkpoint's original dtype recorded at load time (see _restore_loaded_model_dtype). This makes any checkpoint-loaded model keep its intrinsically-fp32 params in fp32 compute automatically, even after storage was upcast for fp32 master weights.
- Fallback: when the tensor carries no compute hint, the result depends on whether the module's floating-point storage is uniform:
  - Uniform storage: mp_policy.param_dtype (the requested mixed-precision compute dtype, typically bf16). This is the fp32-master-weights case: the uniform fp32 storage is artificially widened and should compute in the policy dtype. If no policy is given, the storage dtype is used.
  - Mixed storage: the tensor's own storage dtype. A param whose storage differs from its peers is intrinsically that dtype (not a master weight), so it must compute in it. Applying the policy here would force differently stored params into one compute dtype and reintroduce the mixed original dtype that stock FSDP2 rejects (_init_mp_dtypes).

Non-floating tensors always keep their storage dtype.
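The precedence above can be sketched in plain Python. This is a minimal, hypothetical illustration rather than the library's code: FakeTensor stands in for a real tensor, dtypes are plain strings so the sketch runs without PyTorch, and make_compute_dtype_fn, hf_compute_dtype and policy_param_dtype are illustrative names mirroring fp32_compute_module_names, tensor._hf_compute_dtype and mp_policy.param_dtype. Name matching is assumed to be a substring test, following the description of fp32_compute_module_names as name substrings.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence


@dataclass
class FakeTensor:
    """Hypothetical tensor stand-in; dtypes are strings so no PyTorch is needed."""
    name: str
    dtype: str                              # storage dtype, e.g. "float32"
    is_floating: bool = True
    hf_compute_dtype: Optional[str] = None  # plays the role of tensor._hf_compute_dtype


def make_compute_dtype_fn(
    tensors: Sequence[FakeTensor],
    fp32_compute_module_names: Sequence[str] = (),
    policy_param_dtype: Optional[str] = None,  # plays the role of mp_policy.param_dtype
) -> Callable[[FakeTensor], str]:
    # Uniformity is judged over the module's floating-point storage dtypes only.
    float_storage = {t.dtype for t in tensors if t.is_floating}
    uniform = len(float_storage) <= 1

    def compute_dtype(t: FakeTensor) -> str:
        if not t.is_floating:
            return t.dtype                        # non-floating: keep storage dtype
        if any(s in t.name for s in fp32_compute_module_names):
            return "float32"                      # 1. pinned fp32
        if t.hf_compute_dtype is not None:
            return t.hf_compute_dtype             # 2. dtype recorded from the checkpoint
        if uniform:
            return policy_param_dtype or t.dtype  # 3a. uniform storage: policy dtype
        return t.dtype                            # 3b. mixed storage: own storage dtype

    return compute_dtype


# fp32 master weights: all floating storage is fp32, compute defaults to bf16.
params = [
    FakeTensor("layers.0.attn.q_proj.weight", "float32"),                       # -> bfloat16
    FakeTensor("layers.0.linear_attn._fp32_params.A_log", "float32"),           # -> float32
    FakeTensor("layers.0.norm.weight", "float32", hf_compute_dtype="float32"),  # -> float32
    FakeTensor("layers.0.step", "int64", is_floating=False),                    # -> int64
]
resolve = make_compute_dtype_fn(params, ("_fp32_params",), "bfloat16")
for p in params:
    print(p.name, resolve(p))
```

The uniformity check is computed once per module, so the resolver itself is a cheap per-tensor lookup that can be reused for every parameter when grouping FSDP units.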
Fully shard a module so every parameter computes in its required dtype.
The intent is simple: compute everything in mp_policy.param_dtype (e.g. bf16)
except parameters that must stay in fp32 — their FSDP unit gets param_dtype=fp32
while the rest of the module computes in the policy dtype. A parameter “must stay
fp32” if it is pinned via fp32_compute_module_names or HF stored it in fp32 (see
_make_compute_dtype_fn for the full precedence). This decouples compute dtype
from storage dtype, so fp32 master weights (uniform fp32 storage) still compute in
bf16 for the bulk.
Implementation: group the module’s parameters by their resolved compute dtype and shard so each FSDP unit is compute-dtype-uniform. The three cases below differ only in sharding granularity:
- 1 compute dtype -> shard the whole module once.
- 2 compute dtypes -> shard the minority-dtype subtrees on their own, then shard the parent with the majority dtype (keeps the bulk as one FSDP unit).
- 3+ compute dtypes -> shard every maximal compute-dtype-uniform subtree on its own.
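The three cases can be sketched as a planner that only decides which subtrees become FSDP units; it does not call fully_shard. Assumptions, all hypothetical: Node is a module stand-in whose params already map names to resolved compute dtypes, the majority dtype is picked by parameter count (the real code may weigh parameters differently), and parameters attached directly to a mixed-dtype module are not handled. In PyTorch, each (path, dtype) pair would plausibly correspond to a fully_shard call on that submodule with a MixedPrecisionPolicy whose param_dtype is the listed dtype, with the root ("") sharded last.

```python
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class Node:
    """Hypothetical module stand-in: params map name -> resolved compute dtype."""
    params: Dict[str, str] = field(default_factory=dict)
    children: Dict[str, "Node"] = field(default_factory=dict)


def count_dtypes(node: Node, acc: Counter) -> Counter:
    # Count parameters per compute dtype over the whole subtree.
    acc.update(node.params.values())
    for child in node.children.values():
        count_dtypes(child, acc)
    return acc


def maximal_uniform(node: Node, path: str = "") -> List[Tuple[str, str]]:
    # Return the largest subtrees whose parameters share one compute dtype.
    dtypes = set(count_dtypes(node, Counter()))
    if not dtypes:
        return []
    if len(dtypes) == 1:
        return [(path, dtypes.pop())]
    units: List[Tuple[str, str]] = []
    for name, child in node.children.items():
        units += maximal_uniform(child, f"{path}.{name}" if path else name)
    return units


def plan_fsdp_units(root: Node) -> List[Tuple[str, str]]:
    """Return (module_path, param_dtype) pairs in the order they would be sharded."""
    counts = count_dtypes(root, Counter())
    if len(counts) == 1:                      # 1 dtype: shard the whole module once
        return [("", next(iter(counts)))]
    if len(counts) == 2:                      # 2 dtypes: minority subtrees, then parent
        (majority, _), (minority, _) = counts.most_common()
        minority_units = [u for u in maximal_uniform(root) if u[1] == minority]
        return minority_units + [("", majority)]
    return maximal_uniform(root)              # 3+ dtypes: every maximal uniform subtree


model = Node(children={
    "embed": Node({"weight": "bfloat16"}),
    "layers": Node(children={"0": Node(children={
        "attn": Node({"q_proj.weight": "bfloat16"}),
        "gdn_fp32": Node({"A_log": "float32"}),
    })}),
    "lm_head": Node({"weight": "bfloat16"}),
})
print(plan_fsdp_units(model))
# [('layers.0.gdn_fp32', 'float32'), ('', 'bfloat16')]
```

Ordering matters in the two-dtype case: in FSDP2, fully_shard on a parent manages only the parameters not already claimed by a sharded child, so sharding the minority subtrees first leaves the bulk of the model as a single unit in the policy dtype.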
Parameters:
Parameter/buffer name substrings that must compute in
fp32 (e.g. ("_fp32_params",) for Qwen3.5’s GatedDeltaNet fp32 holder).
Sourced from the model’s _keep_in_fp32_modules_strict.
Traverse module and yield maximal submodules whose entire subtree has a unified dtype.
- include_buffers: include buffers in dtype unification checks.
- tensor_pred: predicate to choose which tensors to consider (default: all). Example: tensor_pred=torch.is_floating_point (to consider only FP tensors)
- dtype_of: maps a tensor to the dtype used for unification (default: its storage dtype, t.dtype). Pass a custom function to group by compute dtype rather than storage dtype.
- return_paths: if True, yields (qualified_name, module, dtype); else (module, dtype).
Notes:
- If a module subtree has no tensors passing tensor_pred, it is ignored.
- Maximality ensures no yielded module is a strict descendant of another yielded module.
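A minimal sketch of this traversal, assuming hypothetical Mod and T stand-ins instead of torch.nn.Module and real tensors, with string dtypes so it runs without PyTorch. iter_maximal_uniform is a hypothetical name, and its defaults (for example include_buffers=False) are illustrative, not taken from the library. Dtype sets are computed bottom-up once; the walk then yields a module as soon as its whole subtree is uniform and stops descending, which is what guarantees maximality.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterator, Optional, Set


@dataclass
class T:
    """Hypothetical tensor stand-in with a string dtype."""
    dtype: str
    is_floating: bool = True


@dataclass
class Mod:
    """Hypothetical module stand-in with direct params, buffers and children."""
    params: Dict[str, T] = field(default_factory=dict)
    buffers: Dict[str, T] = field(default_factory=dict)
    children: Dict[str, "Mod"] = field(default_factory=dict)


def iter_maximal_uniform(
    module: Mod,
    include_buffers: bool = False,
    tensor_pred: Optional[Callable[[T], bool]] = None,
    dtype_of: Optional[Callable[[T], str]] = None,
    return_paths: bool = False,
) -> Iterator[tuple]:
    pred = tensor_pred or (lambda t: True)
    key = dtype_of or (lambda t: t.dtype)     # default: storage dtype
    memo: Dict[int, Set[str]] = {}

    def collect(mod: Mod) -> Set[str]:
        # Post-order pass: dtype set of each subtree, computed once.
        tensors = list(mod.params.values())
        if include_buffers:
            tensors += list(mod.buffers.values())
        dtypes = {key(t) for t in tensors if pred(t)}
        for child in mod.children.values():
            dtypes |= collect(child)
        memo[id(mod)] = dtypes
        return dtypes

    def walk(mod: Mod, path: str) -> Iterator[tuple]:
        dtypes = memo[id(mod)]
        if not dtypes:
            return                            # no tensor passes tensor_pred: ignored
        if len(dtypes) == 1:
            dtype = next(iter(dtypes))
            yield (path, mod, dtype) if return_paths else (mod, dtype)
            return                            # maximal: do not descend further
        for name, child in mod.children.items():
            yield from walk(child, f"{path}.{name}" if path else name)

    collect(module)
    yield from walk(module, "")


def floating(t: T) -> bool:
    return t.is_floating


model = Mod(children={
    "embed": Mod(params={"weight": T("bfloat16")}),
    "norm": Mod(params={"weight": T("float32")}, buffers={"step": T("int64", False)}),
    "head": Mod(params={"weight": T("bfloat16")}),
})
print([(p, d) for p, _, d in iter_maximal_uniform(
    model, include_buffers=True, tensor_pred=floating, return_paths=True)])
# [('embed', 'bfloat16'), ('norm', 'float32'), ('head', 'bfloat16')]
```

Filtering with tensor_pred matters here: without it, the int64 buffer makes norm mixed, and because norm has no children, nothing under it is yielded.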