nemo_automodel.components.distributed.parallelizer_utils
Module Contents
Functions
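iter_maximal_uniform_dtype_subtrees | Traverse module and yield maximal submodules whose entire subtree has a unified dtype. |
fully_shard_by_dtype | Fully shard a module, splitting mixed-dtype subtrees when needed. |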
Data
API
- nemo_automodel.components.distributed.parallelizer_utils.UniformSubtreeItem
None
- nemo_automodel.components.distributed.parallelizer_utils.iter_maximal_uniform_dtype_subtrees(
- module: torch.nn.Module,
- *,
- include_buffers: bool = True,
- tensor_pred: Optional[Callable[[torch.Tensor], bool]] = None,
- return_paths: bool = False,
Traverse module and yield maximal submodules whose entire subtree has a unified dtype.
include_buffers: include buffers in dtype unification checks.
tensor_pred: predicate to choose which tensors to consider (default: all). Example: tensor_pred=torch.is_floating_point (to consider only FP tensors).
return_paths: if True, yields (qualified_name, module, dtype); else (module, dtype).
Notes:
If a module subtree has no tensors passing tensor_pred, it is ignored.
Maximality ensures no yielded module is a strict child of another yielded module.
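The traversal rule described above can be illustrated without PyTorch. The following is a minimal sketch, not the library's implementation: Node is a hypothetical stand-in for torch.nn.Module, dtypes are plain strings, and iter_maximal_uniform is an illustrative name. It shows why a uniform subtree is yielded once at its highest point and why tensor-free subtrees are skipped.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Set, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for torch.nn.Module (assumption, not the real class)."""
    name: str
    dtypes: List[str] = field(default_factory=list)      # dtypes of tensors owned directly
    children: List["Node"] = field(default_factory=list)


def subtree_dtypes(node: Node) -> Set[str]:
    """Collect every dtype that appears anywhere in the subtree rooted at node."""
    found = set(node.dtypes)
    for child in node.children:
        found |= subtree_dtypes(child)
    return found


def iter_maximal_uniform(node: Node, prefix: str = "") -> Iterator[Tuple[str, str]]:
    """Yield (qualified_name, dtype) for each maximal single-dtype subtree."""
    path = f"{prefix}.{node.name}" if prefix else node.name
    found = subtree_dtypes(node)
    if len(found) == 1:
        # Uniform subtree: yield it and stop descending, so no descendant
        # of a yielded node is ever yielded (maximality).
        yield path, next(iter(found))
    elif len(found) > 1:
        # Mixed dtypes: look for uniform subtrees further down.
        for child in node.children:
            yield from iter_maximal_uniform(child, path)
    # len(found) == 0: the subtree has no tensors and is ignored.


# The root uses an empty name, mirroring how qualified names omit the root.
model = Node("", children=[
    Node("encoder", children=[
        Node("linear", dtypes=["bfloat16"]),
        Node("norm", dtypes=["bfloat16"]),
    ]),
    Node("head", children=[
        Node("proj", dtypes=["bfloat16"]),
        Node("scale", dtypes=["float32"]),
    ]),
    Node("dropout"),  # no tensors
])

result = list(iter_maximal_uniform(model))
```

Here encoder is yielded as a whole because both of its children are bfloat16, head is split into head.proj and head.scale because its subtree mixes dtypes, and dropout is skipped because it holds no tensors.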
- nemo_automodel.components.distributed.parallelizer_utils._group_params_by_dtype(
- layer: torch.nn.Module,
- nemo_automodel.components.distributed.parallelizer_utils._get_module_from_path(
- layer: torch.nn.Module,
- path: str,
- nemo_automodel.components.distributed.parallelizer_utils._fully_shard(
- module: torch.nn.Module,
- mesh: torch.distributed.device_mesh.DeviceMesh,
- mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
- offload_policy: Optional[torch.distributed.fsdp.OffloadPolicy],
- nemo_automodel.components.distributed.parallelizer_utils._mp_policy_with_param_dtype(
- mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
- param_dtype: torch.dtype,
- nemo_automodel.components.distributed.parallelizer_utils.fully_shard_by_dtype(
- module: torch.nn.Module,
- mesh: torch.distributed.device_mesh.DeviceMesh,
- mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
- offload_policy: Optional[torch.distributed.fsdp.OffloadPolicy],
Fully shard a module, splitting mixed-dtype subtrees when needed.
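One plausible reading of "splitting when needed" is that a layer is sharded as a single unit when all of its parameters share a dtype, and split only when they do not. The sketch below illustrates that decision on plain data. It is an assumption about the behavior, not the library's code: group_names_by_dtype and needs_split are hypothetical helpers, and the parameter names are made up for illustration.

```python
from collections import defaultdict
from typing import Dict, List


def group_names_by_dtype(param_dtypes: Dict[str, str]) -> Dict[str, List[str]]:
    """Group parameter names by dtype (hypothetical helper for illustration)."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for name, dtype in param_dtypes.items():
        groups[dtype].append(name)
    return dict(groups)


def needs_split(param_dtypes: Dict[str, str]) -> bool:
    """Return True when the parameters span more than one dtype."""
    return len(group_names_by_dtype(param_dtypes)) > 1


# Made-up parameter names mapped to dtype strings.
uniform_layer = {
    "attn.q_proj.weight": "bfloat16",
    "attn.k_proj.weight": "bfloat16",
}
mixed_layer = {
    "mlp.up.weight": "bfloat16",
    "router.weight": "float32",
    "mlp.down.weight": "bfloat16",
}

uniform_needs_split = needs_split(uniform_layer)   # single dtype: shard as one unit
mixed_needs_split = needs_split(mixed_layer)       # two dtypes: split before sharding
mixed_groups = group_names_by_dtype(mixed_layer)
```

In this sketch the uniform layer would be handed to the sharding call unchanged, while the mixed layer would first be broken into single-dtype pieces.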