bridge.utils.common_utils
Module Contents
Functions
Function | Description
get_rank_safe | Get the distributed rank safely, even if torch.distributed is not initialized.
get_world_size_safe | Get the distributed world size safely, even if torch.distributed is not initialized.
get_last_rank | Get the last rank in the distributed group.
get_local_rank_preinit | Get the local rank from the environment variable, intended for use before full init.
print_rank_0 | Print a message only on global rank 0.
is_last_rank | Check if the current rank is the last rank in the default process group.
print_rank_last | Print a message only on the last rank of the default process group.
unwrap_model | Recursively unwraps a model or list of models from common wrapper modules.
use_dist_ckpt | Check if the checkpoint format indicates a distributed checkpoint.
API
- bridge.utils.common_utils.get_rank_safe() -> int
Get the distributed rank safely, even if torch.distributed is not initialized.
- Returns:
The current process rank.
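A minimal sketch of this "safe" pattern. It assumes the fallback reads the RANK environment variable that launchers such as torchrun export, defaulting to 0; the fallback the module actually uses is not documented on this page.

```python
import os

try:
    import torch.distributed as dist
except ImportError:  # torch not installed: only the environment fallback is usable
    dist = None


def get_rank_safe() -> int:
    """Sketch: global rank from torch.distributed if initialized, else from RANK."""
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    # Assumed fallback: RANK env var, defaulting to 0 for single-process runs.
    return int(os.environ.get("RANK", "0"))


os.environ["RANK"] = "2"  # simulate a launcher-provided rank before init
print(get_rank_safe())
```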
- bridge.utils.common_utils.get_world_size_safe() -> int
Get the distributed world size safely, even if torch.distributed is not initialized.
- Returns:
The total number of processes in the distributed job.
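A similar sketch for the world size, assuming a fallback to the WORLD_SIZE environment variable with a default of 1. The per_rank_batch_size helper is hypothetical and only illustrates a typical consumer of the value.

```python
import os

try:
    import torch.distributed as dist
except ImportError:  # torch not installed: only the environment fallback is usable
    dist = None


def get_world_size_safe() -> int:
    """Sketch: world size from torch.distributed if initialized, else from WORLD_SIZE."""
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    # Assumed fallback: WORLD_SIZE env var, defaulting to 1 (a single process).
    return int(os.environ.get("WORLD_SIZE", "1"))


def per_rank_batch_size(global_batch_size: int) -> int:
    """Hypothetical helper: split a global batch evenly across all processes."""
    world_size = get_world_size_safe()
    if global_batch_size % world_size != 0:
        raise ValueError(f"{global_batch_size} is not divisible by {world_size}")
    return global_batch_size // world_size
```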
- bridge.utils.common_utils.get_last_rank() -> int
Get the last rank in the distributed group.
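Ranks in a group are numbered 0 through world_size - 1, so the last rank is world_size - 1. The sketch below assumes that definition; the WORLD_SIZE environment-variable fallback is an assumption added only so the example runs without an initialized process group.

```python
import os

try:
    import torch.distributed as dist
except ImportError:  # torch not installed: only the environment fallback is usable
    dist = None


def get_last_rank() -> int:
    """Sketch: ranks are numbered 0..world_size-1, so the last one is world_size - 1."""
    if dist is not None and dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
    else:
        # Assumed fallback so the sketch also runs without an initialized group.
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return world_size - 1
```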
- bridge.utils.common_utils.get_local_rank_preinit() -> int
Get the local rank from the environment variable, intended for use before the distributed environment is fully initialized.
- Returns:
The local rank of the current process.
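This page does not name the environment variable. The sketch assumes LOCAL_RANK, which torchrun sets for each process on a node, and defaults to 0. The device_for_local_rank helper is hypothetical and shows a typical pre-initialization use: giving each process on a node its own GPU.

```python
import os


def get_local_rank_preinit() -> int:
    """Sketch: read the per-node rank from LOCAL_RANK (assumed name), default 0."""
    return int(os.environ.get("LOCAL_RANK", "0"))


def device_for_local_rank() -> str:
    """Hypothetical helper: one GPU per local process, chosen before distributed init."""
    return f"cuda:{get_local_rank_preinit()}"
```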
- bridge.utils.common_utils.print_rank_0(message: str) -> None
Print a message only on global rank 0.
- Parameters:
message – The message string to print.
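A minimal sketch, assuming the rank comes from a get_rank_safe-style lookup (torch.distributed when initialized, otherwise the RANK environment variable, which is an assumption). Guarding print this way keeps a job with N processes from emitting every log line N times.

```python
import os

try:
    import torch.distributed as dist
except ImportError:  # torch not installed: only the environment fallback is usable
    dist = None


def _rank() -> int:
    # Assumed "safe" lookup: torch.distributed if initialized, else RANK (default 0).
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return int(os.environ.get("RANK", "0"))


def print_rank_0(message: str) -> None:
    """Sketch: print only on global rank 0 so logs are not duplicated per process."""
    if _rank() == 0:
        print(message, flush=True)
```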
- bridge.utils.common_utils.is_last_rank() -> bool
Check if the current rank is the last rank in the default process group.
- Returns:
True if the current rank is the last one, False otherwise.
- bridge.utils.common_utils.print_rank_last(message: str) -> None
Print a message only on the last rank of the default process group.
- Parameters:
message – The message string to print.
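is_last_rank and print_rank_last can be sketched together. The sketch assumes the last rank of the default group is world_size - 1 and, as an assumption for standalone use, falls back to the RANK and WORLD_SIZE environment variables when no process group is initialized.

```python
import os

try:
    import torch.distributed as dist
except ImportError:  # torch not installed: only the environment fallback is usable
    dist = None


def _rank_and_world_size() -> tuple[int, int]:
    # Assumed fallback to RANK/WORLD_SIZE when the default group is not initialized.
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return int(os.environ.get("RANK", "0")), int(os.environ.get("WORLD_SIZE", "1"))


def is_last_rank() -> bool:
    """Sketch: the current rank is last when it equals world_size - 1."""
    rank, world_size = _rank_and_world_size()
    return rank == world_size - 1


def print_rank_last(message: str) -> None:
    """Sketch: print only on the last rank of the default group."""
    if is_last_rank():
        print(message, flush=True)
```

In this sketch a single-process run (rank 0, world size 1) counts as the last rank, so print_rank_last prints there.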
- bridge.utils.common_utils.unwrap_model(
    model: Union[torch.nn.Module, list[torch.nn.Module]],
    module_instances: tuple[Type[torch.nn.Module], ...] = ALL_MODULE_WRAPPER_CLASSNAMES,
  )
Recursively unwraps a model or list of models from common wrapper modules.
- Parameters:
model – The model or list of models to unwrap.
module_instances – A tuple of wrapper module types to remove (e.g., DDP, Float16Module).
- Returns:
The unwrapped model or list of models.
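Wrappers such as torch.nn.parallel.DistributedDataParallel keep the wrapped model in a .module attribute. The sketch below assumes every type in module_instances follows that convention and peels wrappers off recursively. Model, WrapperA, and WrapperB are hypothetical stand-ins, and the default tuple replaces ALL_MODULE_WRAPPER_CLASSNAMES, so the example runs without torch.

```python
from typing import Any, Union


class Model:
    """Hypothetical inner model."""


class WrapperA:
    """Hypothetical wrapper that stores the wrapped object in .module, as DDP does."""

    def __init__(self, module: Any) -> None:
        self.module = module


class WrapperB(WrapperA):
    """Hypothetical second wrapper type (for example, a mixed-precision wrapper)."""


def unwrap_model(
    model: Union[Any, list[Any]],
    module_instances: tuple[type, ...] = (WrapperA, WrapperB),
) -> Union[Any, list[Any]]:
    """Sketch: strip wrapper layers via .module; lists are unwrapped element-wise."""
    if isinstance(model, list):
        return [unwrap_model(m, module_instances) for m in model]
    if isinstance(model, module_instances):
        # Recurse in case wrappers are nested, e.g. WrapperA(WrapperB(model)).
        return unwrap_model(model.module, module_instances)
    return model
```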
- bridge.utils.common_utils.use_dist_ckpt(ckpt_format: str) -> bool
Check if the checkpoint format indicates a distributed checkpoint.
- Parameters:
ckpt_format – The checkpoint format string (e.g., “torch”, “torch_dist”).
- Returns:
True if the format is not “torch”, False otherwise.
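The documented rule amounts to a single comparison. The sketch restates it; note that under this rule any string other than exactly "torch", including a misspelling or a different capitalization, is treated as a distributed checkpoint format.

```python
def use_dist_ckpt(ckpt_format: str) -> bool:
    """Sketch of the documented rule: every format except "torch" is distributed."""
    return ckpt_format != "torch"


for fmt in ("torch", "torch_dist"):
    print(fmt, use_dist_ckpt(fmt))
```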