nemo_automodel.components.distributed.parallelizer_utils#

Module Contents#

Functions#

iter_maximal_uniform_dtype_subtrees

Traverse a module and yield the maximal submodules whose entire subtree shares a single dtype.

_group_params_by_dtype

_get_module_from_path

_fully_shard

fully_shard_by_dtype

Data#

API#

nemo_automodel.components.distributed.parallelizer_utils.UniformSubtreeItem#

None

nemo_automodel.components.distributed.parallelizer_utils.iter_maximal_uniform_dtype_subtrees(
module: torch.nn.Module,
*,
include_buffers: bool = True,
tensor_pred: Optional[Callable[[torch.Tensor], bool]] = None,
return_paths: bool = False,
) -> Iterator[nemo_automodel.components.distributed.parallelizer_utils.UniformSubtreeItem]#

Traverse a module and yield the maximal submodules whose entire subtree shares a single dtype.

  • include_buffers: include buffers in dtype unification checks.

  • tensor_pred: predicate selecting which tensors to consider (default: all). Example: tensor_pred=torch.is_floating_point to consider only floating-point tensors.

  • return_paths: if True, yields (qualified_name, module, dtype); else (module, dtype).

Notes:

  • If a module subtree has no tensors passing tensor_pred, it is ignored.

  • Maximality ensures no yielded module is a strict child of another yielded module.
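The traversal and maximality rule described above can be sketched without torch: recurse from the root, and whenever an entire subtree resolves to one dtype, yield it and stop descending. The Node class, helper names, and dtype strings below are illustrative assumptions, not the library's implementation.

```python
# Torch-free sketch of the maximal-uniform-dtype traversal.
# Node, subtree_dtypes, and iter_maximal_uniform are hypothetical
# stand-ins for torch.nn.Module and the library's internals.
from dataclasses import dataclass, field
from typing import Iterator, Tuple

@dataclass
class Node:
    """Stand-in for torch.nn.Module: local tensor dtypes plus named children."""
    name: str
    dtypes: list = field(default_factory=list)   # dtypes of this node's own params/buffers
    children: list = field(default_factory=list)

def subtree_dtypes(node: Node) -> set:
    """Collect every dtype that appears anywhere in the subtree."""
    out = set(node.dtypes)
    for child in node.children:
        out |= subtree_dtypes(child)
    return out

def iter_maximal_uniform(node: Node, path: str = "") -> Iterator[Tuple[str, Node, str]]:
    """Yield (qualified_name, node, dtype) for maximal uniform-dtype subtrees."""
    seen = subtree_dtypes(node)
    if len(seen) == 1:
        # Entire subtree is one dtype: yield it and do NOT descend.
        # Stopping here is what makes the yielded subtrees maximal --
        # no yielded module is a strict child of another yielded module.
        yield path, node, next(iter(seen))
        return
    # Mixed (or tensor-free) subtree: recurse into children instead.
    for child in node.children:
        child_path = f"{path}.{child.name}" if path else child.name
        yield from iter_maximal_uniform(child, child_path)

root = Node("root", children=[
    Node("embed", dtypes=["fp32"]),
    Node("block", children=[
        Node("attn", dtypes=["bf16"]),
        Node("mlp", dtypes=["bf16"]),
    ]),
])
for qual_name, _node, dtype in iter_maximal_uniform(root):
    print(qual_name, dtype)
```

Here `block` is yielded as a whole (its subtree is uniformly bf16), while `attn` and `mlp` are not yielded separately; `embed` forms its own fp32 subtree.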

nemo_automodel.components.distributed.parallelizer_utils._group_params_by_dtype(layer)#
nemo_automodel.components.distributed.parallelizer_utils._get_module_from_path(layer, path)#
nemo_automodel.components.distributed.parallelizer_utils._fully_shard(module, mesh, mp_policy, offload_policy)#
nemo_automodel.components.distributed.parallelizer_utils.fully_shard_by_dtype(module, mesh, mp_policy, offload_policy)#