nemo_automodel.components.distributed.parallelizer_utils

Module Contents

Functions

iter_maximal_uniform_dtype_subtrees

Traverse module and yield maximal submodules whose entire subtree has a unified dtype.

_group_params_by_dtype

_get_module_from_path

_fully_shard

_mp_policy_with_param_dtype

fully_shard_by_dtype

Fully shard a module, splitting mixed-dtype subtrees when needed.

Data

UniformSubtreeItem

API

nemo_automodel.components.distributed.parallelizer_utils.UniformSubtreeItem

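Type of the items yielded by iter_maximal_uniform_dtype_subtrees: a (qualified_name, module, dtype) tuple when return_paths=True, otherwise a (module, dtype) tuple.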

nemo_automodel.components.distributed.parallelizer_utils.iter_maximal_uniform_dtype_subtrees(
module: torch.nn.Module,
*,
include_buffers: bool = True,
tensor_pred: Optional[Callable[[torch.Tensor], bool]] = None,
return_paths: bool = False,
) -> Iterator[nemo_automodel.components.distributed.parallelizer_utils.UniformSubtreeItem]

Traverse module and yield the maximal submodules whose entire subtree (the submodule and all of its descendants) holds tensors of a single dtype.

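  • module: root module whose submodule tree is traversed.

  • include_buffers: if True (the default), buffers are checked together with parameters when deciding whether a subtree has a single dtype.

  • tensor_pred: predicate selecting which tensors are considered (default: None, meaning all tensors). For example, tensor_pred=torch.is_floating_point considers only floating-point tensors.

  • return_paths: if True, yield (qualified_name, module, dtype) tuples; otherwise yield (module, dtype) tuples.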

Notes:

  • If a module subtree has no tensors passing tensor_pred, it is ignored.

  • Maximality ensures that no yielded module is a strict descendant of another yielded module.
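The traversal described above can be pictured as two passes over the module tree: a post-order pass that records which dtypes occur in each subtree, and a pre-order pass that yields the first uniform subtree on every root-to-leaf path and does not descend further. The sketch below is illustrative only, not the library's implementation: `maximal_uniform_subtrees` and `_local_dtypes` are hypothetical names, and the sketch always yields `(qualified_name, module, dtype)`, as if `return_paths=True`.

```python
from typing import Callable, Dict, Iterator, Optional, Set, Tuple

import torch
from torch import nn

TensorPred = Optional[Callable[[torch.Tensor], bool]]


def _local_dtypes(module: nn.Module, include_buffers: bool, tensor_pred: TensorPred) -> Set[torch.dtype]:
    """Dtypes of the tensors owned directly by `module` (children excluded)."""
    tensors = list(module.parameters(recurse=False))
    if include_buffers:
        tensors += list(module.buffers(recurse=False))
    return {t.dtype for t in tensors if tensor_pred is None or tensor_pred(t)}


def maximal_uniform_subtrees(
    module: nn.Module,
    include_buffers: bool = True,
    tensor_pred: TensorPred = None,
) -> Iterator[Tuple[str, nn.Module, torch.dtype]]:
    # Pass 1 (post-order): the set of dtypes present anywhere in each subtree.
    subtree_dtypes: Dict[nn.Module, Set[torch.dtype]] = {}

    def collect(mod: nn.Module) -> Set[torch.dtype]:
        dtypes = _local_dtypes(mod, include_buffers, tensor_pred)
        for child in mod.children():
            dtypes |= collect(child)
        subtree_dtypes[mod] = dtypes
        return dtypes

    collect(module)

    # Pass 2 (pre-order): yield a subtree as soon as it is uniform and stop there,
    # so no yielded module is a descendant of another yielded module.
    def visit(path: str, mod: nn.Module) -> Iterator[Tuple[str, nn.Module, torch.dtype]]:
        dtypes = subtree_dtypes[mod]
        if len(dtypes) == 1:
            yield path, mod, next(iter(dtypes))
        elif len(dtypes) > 1:
            for name, child in mod.named_children():
                yield from visit(f"{path}.{name}" if path else name, child)
        # An empty set means no tensor passed tensor_pred: the subtree is skipped.

    yield from visit("", module)
```

For `nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2).to(torch.bfloat16))` the root is mixed, so the sketch descends and yields `"0"` with `torch.float32` and `"1"` with `torch.bfloat16`. A leaf that is itself mixed (for example, float32 weights plus an int64 buffer) yields nothing unless `tensor_pred` or `include_buffers=False` filters the odd tensor out.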

nemo_automodel.components.distributed.parallelizer_utils._group_params_by_dtype(
layer: torch.nn.Module,
) -> Dict[torch.dtype, List[torch.nn.Parameter]]

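_group_params_by_dtype has no docstring; its name and return type suggest it buckets a layer's parameters by dtype. A hypothetical equivalent, assuming it walks all parameters of `layer` recursively (the name `group_params_by_dtype` is illustrative, not the library's code):

```python
from collections import defaultdict
from typing import Dict, List

import torch
from torch import nn


def group_params_by_dtype(layer: nn.Module) -> Dict[torch.dtype, List[nn.Parameter]]:
    """Map each dtype to the parameters of `layer` (recursively) that have it."""
    groups: Dict[torch.dtype, List[nn.Parameter]] = defaultdict(list)
    for param in layer.parameters():
        groups[param.dtype].append(param)
    return dict(groups)
```

A result with more than one key signals that `layer` cannot be treated as a single uniform-dtype parameter group.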
nemo_automodel.components.distributed.parallelizer_utils._get_module_from_path(
layer: torch.nn.Module,
path: str,
) -> torch.nn.Module

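_get_module_from_path is undocumented; judging from its signature, it resolves a dotted qualified name (such as those yielded with return_paths=True) back to a submodule of `layer`. A sketch under that assumption, with the hypothetical name `get_module_from_path` and an empty path meaning `layer` itself:

```python
from torch import nn


def get_module_from_path(layer: nn.Module, path: str) -> nn.Module:
    """Resolve a dotted path such as "1.0" relative to `layer`; "" returns `layer`."""
    module = layer
    for name in filter(None, path.split(".")):
        # nn.Module.__getattr__ looks up registered submodules, including "0", "1", ...
        module = getattr(module, name)
    return module
```

PyTorch's built-in `nn.Module.get_submodule` performs the same lookup and raises `AttributeError` for an invalid path.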
nemo_automodel.components.distributed.parallelizer_utils._fully_shard(
module: torch.nn.Module,
mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
offload_policy: Optional[torch.distributed.fsdp.OffloadPolicy],
) -> None

nemo_automodel.components.distributed.parallelizer_utils._mp_policy_with_param_dtype(
mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
param_dtype: torch.dtype,
) -> Optional[torch.distributed.fsdp.MixedPrecisionPolicy]

nemo_automodel.components.distributed.parallelizer_utils.fully_shard_by_dtype(
module: torch.nn.Module,
mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy],
offload_policy: Optional[torch.distributed.fsdp.OffloadPolicy],
) -> None

Fully shard a module, splitting mixed-dtype subtrees when needed.
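The docstring leaves the splitting strategy open. One plausible reading, sketched here as an assumption rather than the library's behavior: if the module's parameters share one dtype, shard it once; otherwise shard each maximal uniform-dtype subtree on its own and then shard the root, following FSDP2's usual bottom-up order in which parameters already managed by a sharded submodule are excluded from the parent. To run without a process group, the hypothetical `shard_fn` callback stands in for `fully_shard` and receives the subtree's dtype (None for the root).

```python
from typing import Callable, Optional, Set

import torch
from torch import nn

# Hypothetical callback: (module to shard, its param dtype or None) -> None.
ShardFn = Callable[[nn.Module, Optional[torch.dtype]], None]


def _param_dtypes(module: nn.Module) -> Set[torch.dtype]:
    return {p.dtype for p in module.parameters()}


def shard_by_dtype_sketch(module: nn.Module, shard_fn: ShardFn) -> None:
    dtypes = _param_dtypes(module)
    if len(dtypes) <= 1:
        # Uniform (or parameter-free) module: a single shard call suffices.
        shard_fn(module, next(iter(dtypes), None))
        return

    def visit(mod: nn.Module) -> None:
        for child in mod.children():
            child_dtypes = _param_dtypes(child)
            if len(child_dtypes) == 1:
                # Maximal uniform subtree: shard it with its own param dtype.
                shard_fn(child, next(iter(child_dtypes)))
            elif len(child_dtypes) > 1:
                # Still mixed: split further.
                visit(child)

    visit(module)
    # Root last, so it only owns parameters not claimed by a sharded subtree.
    # Parameters held directly by a mixed module may still differ in dtype.
    shard_fn(module, None)
```

In a real run, `shard_fn` might wrap `fully_shard(m, mesh=mesh, mp_policy=...)` with a per-subtree policy whose `param_dtype` matches `d`.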