nemo_automodel.components.distributed.parallelizer_utils#
Module Contents#
Functions#
- `iter_maximal_uniform_dtype_subtrees` – Traverse `module` and yield maximal submodules whose entire subtree has a unified dtype.
Data#
API#
- nemo_automodel.components.distributed.parallelizer_utils.UniformSubtreeItem#
None
- nemo_automodel.components.distributed.parallelizer_utils.iter_maximal_uniform_dtype_subtrees(module: torch.nn.Module, *, include_buffers: bool = True, tensor_pred: Optional[Callable[[torch.Tensor], bool]] = None, return_paths: bool = False)#
Traverse `module` and yield maximal submodules whose entire subtree has a unified dtype.

Parameters:
- include_buffers: include buffers in the dtype unification checks.
- tensor_pred: predicate to choose which tensors to consider (default: all). Example: `tensor_pred=torch.is_floating_point` to consider only floating-point tensors.
- return_paths: if True, yields `(qualified_name, module, dtype)`; otherwise yields `(module, dtype)`.

Notes:
- If a module subtree has no tensors passing `tensor_pred`, it is ignored.
- Maximality ensures no yielded module is a strict child of another yielded module.
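A minimal usage sketch based on the signature and notes above; the toy model, traversal output, and printed names are illustrative assumptions, not part of the module:

```python
import torch
import torch.nn as nn

from nemo_automodel.components.distributed.parallelizer_utils import (
    iter_maximal_uniform_dtype_subtrees,
)

# Toy model mixing parameter dtypes: the root is not uniform, so the two
# child Linear modules are expected to be yielded as maximal uniform subtrees.
model = nn.Sequential(
    nn.Linear(8, 8),                      # float32 parameters
    nn.Linear(8, 8).to(torch.bfloat16),   # bfloat16 parameters
)

# Consider only floating-point tensors and request qualified names in the output.
for name, submodule, dtype in iter_maximal_uniform_dtype_subtrees(
    model,
    include_buffers=True,
    tensor_pred=torch.is_floating_point,
    return_paths=True,
):
    print(f"{name or '<root>'}: {type(submodule).__name__} -> {dtype}")
```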
- nemo_automodel.components.distributed.parallelizer_utils._group_params_by_dtype(layer)#
- nemo_automodel.components.distributed.parallelizer_utils._get_module_from_path(layer, path)#
- nemo_automodel.components.distributed.parallelizer_utils._fully_shard(module, mesh, mp_policy, offload_policy)#
- nemo_automodel.components.distributed.parallelizer_utils.fully_shard_by_dtype(module, mesh, mp_policy, offload_policy)#
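A hedged usage sketch for `fully_shard_by_dtype`, assuming `mesh`, `mp_policy`, and `offload_policy` follow the FSDP2 (`torch.distributed.fsdp`) conventions that the parameter names suggest; the model builder is hypothetical and the script is meant to run under `torchrun`:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
# MixedPrecisionPolicy / CPUOffloadPolicy are exported from torch.distributed.fsdp
# in recent PyTorch releases (FSDP2).
from torch.distributed.fsdp import MixedPrecisionPolicy, CPUOffloadPolicy

from nemo_automodel.components.distributed.parallelizer_utils import fully_shard_by_dtype

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
offload_policy = CPUOffloadPolicy()  # CPU-offload parameters; policy choice is an assumption

model = build_model()  # hypothetical: a model whose subtrees may hold different dtypes
fully_shard_by_dtype(model, mesh, mp_policy, offload_policy)
```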