# nemo_automodel.components.distributed.parallelizer_utils

## Module Contents

### Functions

| Name                                                                                                                                   | Description                                                                              |
| -------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
| [`_fully_shard`](#nemo_automodel-components-distributed-parallelizer_utils-_fully_shard)                                               | -                                                                                        |
| [`_get_module_from_path`](#nemo_automodel-components-distributed-parallelizer_utils-_get_module_from_path)                             | -                                                                                        |
| [`_group_params_by_dtype`](#nemo_automodel-components-distributed-parallelizer_utils-_group_params_by_dtype)                           | -                                                                                        |
| [`_make_compute_dtype_fn`](#nemo_automodel-components-distributed-parallelizer_utils-_make_compute_dtype_fn)                           | Build the per-parameter *compute* dtype resolver used to group FSDP units.               |
| [`_mp_policy_with_param_dtype`](#nemo_automodel-components-distributed-parallelizer_utils-_mp_policy_with_param_dtype)                 | -                                                                                        |
| [`fully_shard_by_dtype`](#nemo_automodel-components-distributed-parallelizer_utils-fully_shard_by_dtype)                               | Fully shard a module so every parameter computes in its required dtype.                  |
| [`iter_maximal_uniform_dtype_subtrees`](#nemo_automodel-components-distributed-parallelizer_utils-iter_maximal_uniform_dtype_subtrees) | Traverse `module` and yield maximal submodules whose entire subtree has a unified dtype. |

### Data

[`UniformSubtreeItem`](#nemo_automodel-components-distributed-parallelizer_utils-UniformSubtreeItem)

### API

```python
nemo_automodel.components.distributed.parallelizer_utils._fully_shard(
    module: torch.nn.Module,
    mesh: torch.distributed.device_mesh.DeviceMesh,
    mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
    offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy]
) -> None
```

```python
nemo_automodel.components.distributed.parallelizer_utils._get_module_from_path(
    layer: torch.nn.Module,
    path: str
) -> torch.nn.Module
```

```python
nemo_automodel.components.distributed.parallelizer_utils._group_params_by_dtype(
    layer: torch.nn.Module,
    dtype_of: typing.Optional[typing.Callable[[torch.Tensor], torch.dtype]] = None
) -> typing.Dict[torch.dtype, typing.List[torch.nn.Parameter]]
```

```python
nemo_automodel.components.distributed.parallelizer_utils._make_compute_dtype_fn(
    module: torch.nn.Module,
    mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
    fp32_compute_module_names: typing.Tuple[str, ...]
) -> typing.Callable[[torch.Tensor], torch.dtype]
```

Build the per-parameter *compute* dtype resolver used to group FSDP units.

The compute dtype of a floating tensor is resolved by precedence:

1. Pinned fp32 -- the tensor's name matches `fp32_compute_module_names`
   (from the model's `_keep_in_fp32_modules_strict`). This rule is authoritative
   and applies even to from-scratch or quantized models, where there is no
   checkpoint dtype to read.
2. HF-recorded dtype -- `tensor._hf_compute_dtype`, the checkpoint's original
   dtype recorded at load time (see `_restore_loaded_model_dtype`). This makes
   any checkpoint-loaded model keep its intrinsically-fp32 params in fp32 compute
   automatically, even after storage was upcast for fp32 master weights.
3. Fallback -- when the tensor carries no compute hint, the result depends on
   whether the module's floating-point *storage* is uniform:
   * uniform storage -- `mp_policy.param_dtype` (the requested mixed-precision
     compute dtype, typically bf16). This is the fp32-master-weights case: the
     uniform-fp32 storage is artificially widened and should compute in the
     policy dtype. Falls back to the storage dtype when no policy is given.
   * mixed storage -- the tensor's own storage dtype. A param whose storage
     differs from its peers is intrinsically that dtype (not a master weight),
     so it must compute in it. Applying the policy here would force differently
     stored params into one compute dtype and re-introduce the mixed *original*
     dtype that stock FSDP2 rejects (`_init_mp_dtypes`).

Non-floating tensors always keep their storage dtype.
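The precedence above can be sketched in plain Python. This is a toy model, not the library's implementation: `ToyTensor`, `make_compute_dtype_fn`, and the string dtypes are hypothetical stand-ins for `torch.Tensor`, the real resolver, and `torch.dtype`, and name matching is assumed to be a substring test.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Tuple


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for a named parameter or buffer."""
    name: str
    storage_dtype: str
    hf_compute_dtype: Optional[str] = None  # models `tensor._hf_compute_dtype`
    is_floating: bool = True


def make_compute_dtype_fn(
    tensors: Sequence[ToyTensor],
    policy_param_dtype: Optional[str],
    fp32_names: Tuple[str, ...],
) -> Callable[[ToyTensor], str]:
    # Storage is "uniform" when all floating tensors share one storage dtype.
    floating_storage = {t.storage_dtype for t in tensors if t.is_floating}
    uniform = len(floating_storage) <= 1

    def resolve(t: ToyTensor) -> str:
        if not t.is_floating:
            return t.storage_dtype              # non-floating: keep storage dtype
        if any(s in t.name for s in fp32_names):
            return "float32"                    # 1. pinned fp32
        if t.hf_compute_dtype is not None:
            return t.hf_compute_dtype           # 2. HF-recorded dtype
        if uniform and policy_param_dtype is not None:
            return policy_param_dtype           # 3a. uniform storage: policy dtype
        return t.storage_dtype                  # 3b. mixed storage or no policy

    return resolve


# fp32-master-weights case: every floating tensor is stored in float32.
master = [
    ToyTensor("layers.0.attn.weight", "float32", hf_compute_dtype="bfloat16"),
    ToyTensor("layers.0.norm.weight", "float32", hf_compute_dtype="float32"),
    ToyTensor("layers.0._fp32_params.A_log", "float32"),
    ToyTensor("layers.0.mlp.weight", "float32"),
    ToyTensor("layers.0.position_ids", "int64", is_floating=False),
]
resolve = make_compute_dtype_fn(master, "bfloat16", ("_fp32_params",))
master_result = {t.name: resolve(t) for t in master}

# Mixed-storage case: with no hints, each tensor computes in its own storage dtype.
mixed = [ToyTensor("a.weight", "bfloat16"), ToyTensor("b.weight", "float32")]
resolve_mixed = make_compute_dtype_fn(mixed, "bfloat16", ())
mixed_result = {t.name: resolve_mixed(t) for t in mixed}
```

In the master-weights example, only the pinned tensor and the tensor recorded as fp32 resolve to `float32`; the rest of the floating tensors resolve to the policy dtype.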

```python
nemo_automodel.components.distributed.parallelizer_utils._mp_policy_with_param_dtype(
    mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
    param_dtype: torch.dtype
) -> typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy]
```

```python
nemo_automodel.components.distributed.parallelizer_utils.fully_shard_by_dtype(
    module: torch.nn.Module,
    mesh: torch.distributed.device_mesh.DeviceMesh,
    mp_policy: typing.Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
    offload_policy: typing.Optional[torch.distributed.fsdp.OffloadPolicy],
    fp32_compute_module_names: typing.Tuple[str, ...] = ()
) -> None
```

Fully shard a module so every parameter computes in its required dtype.

The intent is simple: compute everything in `mp_policy.param_dtype` (e.g. bf16)
except parameters that must stay in fp32 -- their FSDP unit gets `param_dtype=fp32`
while the rest of the module computes in the policy dtype. A parameter "must stay
fp32" if it is pinned via `fp32_compute_module_names` or HF stored it in fp32 (see
`_make_compute_dtype_fn` for the full precedence). This decouples *compute* dtype
from *storage* dtype, so fp32 master weights (uniform fp32 storage) still compute in
bf16 for the bulk.

Implementation: group the module's parameters by their resolved compute dtype and
shard so each FSDP unit is compute-dtype-uniform. The three cases below differ only
in sharding granularity:

* 1 compute dtype  -> shard the whole module once.
* 2 compute dtypes -> shard the minority-dtype subtrees on their own, then shard
  the parent with the majority dtype (keeps the bulk as one FSDP unit).
* 3+ compute dtypes -> shard every maximal compute-dtype-uniform subtree on its own.
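The case split above can be illustrated with a small planning helper. This is only a sketch: `plan_sharding` is a hypothetical name, dtypes are modeled as strings, and ranking "majority" by total element count is an assumption (the text above does not say how majority is measured).

```python
from typing import Dict, Optional, Tuple


def plan_sharding(numel_by_dtype: Dict[str, int]) -> Tuple[str, Optional[str]]:
    """Return (strategy, dtype of the parent FSDP unit, if there is one).

    `numel_by_dtype` maps each resolved compute dtype to the number of
    parameter elements that resolve to it (assumed measure of "majority").
    """
    if not numel_by_dtype:
        raise ValueError("module has no parameters to shard")

    # Largest group first, so dtypes[0] is the majority compute dtype.
    dtypes = sorted(numel_by_dtype, key=numel_by_dtype.get, reverse=True)

    if len(dtypes) == 1:
        # One compute dtype: a single FSDP unit for the whole module.
        return "shard_whole_module", dtypes[0]
    if len(dtypes) == 2:
        # Two compute dtypes: minority subtrees get their own units,
        # then the parent is sharded with the majority dtype.
        return "shard_minority_subtrees_then_parent", dtypes[0]
    # Three or more: every maximal compute-dtype-uniform subtree is its own unit.
    return "shard_each_uniform_subtree", None


one = plan_sharding({"bfloat16": 1_000})
two = plan_sharding({"float32": 10, "bfloat16": 1_000})
many = plan_sharding({"float32": 10, "bfloat16": 1_000, "float16": 5})
```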

**Parameters:**

* `fp32_compute_module_names`: Parameter/buffer name substrings that must compute in
  fp32 (e.g. `("_fp32_params",)` for Qwen3.5's GatedDeltaNet fp32 holder).
  Sourced from the model's `_keep_in_fp32_modules_strict`.

```python
nemo_automodel.components.distributed.parallelizer_utils.iter_maximal_uniform_dtype_subtrees(
    module: torch.nn.Module,
    include_buffers: bool = True,
    tensor_pred: typing.Optional[typing.Callable[[torch.Tensor], bool]] = None,
    dtype_of: typing.Optional[typing.Callable[[torch.Tensor], torch.dtype]] = None,
    return_paths: bool = False
) -> typing.Iterator[nemo_automodel.components.distributed.parallelizer_utils.UniformSubtreeItem]
```

Traverse `module` and yield maximal submodules whose entire subtree has a unified dtype.

* `include_buffers`: include buffers in dtype unification checks.
* `tensor_pred`: predicate to choose which tensors to consider (default: all).
  Example: `tensor_pred=torch.is_floating_point` (to consider only floating-point tensors).
* `dtype_of`: maps a tensor to the dtype used for unification (default: its storage
  dtype `t.dtype`). Pass a custom function to group by *compute* dtype
  rather than storage dtype.
* `return_paths`: if True, yields `(qualified_name, module, dtype)`; else `(module, dtype)`.

Notes:

* If a module subtree has no tensors passing `tensor_pred`, it is ignored.
* Maximality ensures no yielded module is a strict child of another yielded module.
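The traversal and its maximality property can be sketched on a toy module tree. This is an illustration under stated assumptions, not the library's code: `ToyModule`, `subtree_dtypes`, and `iter_maximal_uniform` are hypothetical names, dtypes are plain strings, and the sketch always yields paths (as with `return_paths=True`). Tensors owned directly by a mixed-dtype module are not yielded here; how the real function treats them is not stated above.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterator, Set, Tuple


@dataclass
class ToyModule:
    """Hypothetical stand-in for `torch.nn.Module`."""
    tensors: Dict[str, str] = field(default_factory=dict)        # name -> dtype
    children: Dict[str, "ToyModule"] = field(default_factory=dict)


def subtree_dtypes(module: ToyModule) -> Set[str]:
    """Collect every dtype found in `module` and all of its descendants."""
    dtypes = set(module.tensors.values())
    for child in module.children.values():
        dtypes |= subtree_dtypes(child)
    return dtypes


def iter_maximal_uniform(
    module: ToyModule, prefix: str = ""
) -> Iterator[Tuple[str, ToyModule, str]]:
    dtypes = subtree_dtypes(module)
    if not dtypes:
        return                                   # no tensors: subtree is ignored
    if len(dtypes) == 1:
        # Uniform subtree: yield it and stop descending, which is what
        # guarantees that no yielded module is a child of another.
        yield prefix, module, next(iter(dtypes))
        return
    for name, child in module.children.items():  # mixed: look for uniform children
        child_path = f"{prefix}.{name}" if prefix else name
        yield from iter_maximal_uniform(child, child_path)


model = ToyModule(children={
    "attn": ToyModule(
        tensors={"weight": "bfloat16"},
        children={"proj": ToyModule(tensors={"weight": "bfloat16"})},
    ),
    "norm": ToyModule(tensors={"weight": "float32"}),
    "act": ToyModule(),                          # parameter-free, so it is skipped
})
found = [(path, dtype) for path, _, dtype in iter_maximal_uniform(model)]
```

Here `attn` is yielded as a whole, so `attn.proj` is not reported separately, and the tensor-free `act` module does not appear.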

```python
nemo_automodel.components.distributed.parallelizer_utils.UniformSubtreeItem = Union[Tuple[nn.Module, torch.dtype], Tuple[str, nn.Module, torch.dtype]]
```