bridge.utils.common_utils#

Module Contents#

Functions#

get_rank_safe

Get the distributed rank safely, even if torch.distributed is not initialized.

get_world_size_safe

Get the distributed world size safely, even if torch.distributed is not initialized.

get_last_rank

Get the last rank in the distributed group.

get_local_rank_preinit

Get the local rank from the environment variable, intended for use before full init.

print_rank_0

Print a message only on global rank 0.

is_last_rank

Check if the current rank is the last rank in the default process group.

print_rank_last

Print a message only on the last rank of the default process group.

unwrap_model

Recursively unwraps a model or list of models from common wrapper modules.

use_dist_ckpt

Check if the checkpoint format indicates a distributed checkpoint.

API#

bridge.utils.common_utils.get_rank_safe() int#

Get the distributed rank safely, even if torch.distributed is not initialized.

Returns:

The current process rank.

bridge.utils.common_utils.get_world_size_safe() int#

Get the distributed world size safely, even if torch.distributed is not initialized.

Returns:

The total number of processes in the distributed job.
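The two "safe" getters above share one pattern: use torch.distributed when it is available and initialized, and otherwise fall back to a sensible default. A minimal sketch of that pattern, assuming a fallback to the launcher-provided RANK and WORLD_SIZE environment variables (the function names and fallback choice are illustrative, not the library's actual implementation):

```python
import os


def get_rank_safe_sketch() -> int:
    """Sketch: return the distributed rank, falling back to the RANK env var."""
    try:
        import torch.distributed as dist  # torch is optional in this sketch
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
    except ImportError:
        pass
    # Assumption: torchrun-style launchers export RANK per process
    return int(os.environ.get("RANK", "0"))


def get_world_size_safe_sketch() -> int:
    """Sketch: return the world size, falling back to the WORLD_SIZE env var."""
    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size()
    except ImportError:
        pass
    return int(os.environ.get("WORLD_SIZE", "1"))
```

Because the torch path is guarded, both functions behave identically in a plain single-process script and inside a distributed job.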

bridge.utils.common_utils.get_last_rank() int#

Get the last rank in the distributed group.

Returns:

The rank of the last process in the default process group.

bridge.utils.common_utils.get_local_rank_preinit() int#

Get the local rank from the environment variable, intended for use before full init.

Returns:

The local rank of the current process.
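Before torch.distributed is initialized, the only reliable source of the local rank is the process environment. A sketch of that lookup, assuming the LOCAL_RANK variable that torchrun-style launchers export (the function name is illustrative):

```python
import os


def get_local_rank_preinit_sketch() -> int:
    """Sketch: read the local rank from the environment, pre-initialization."""
    # Assumption: the launcher sets LOCAL_RANK for each spawned process;
    # default to 0 so single-process runs work unchanged.
    return int(os.environ.get("LOCAL_RANK", "0"))
```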

bridge.utils.common_utils.print_rank_0(message: str) None#

Print a message only on global rank 0.

Parameters:

message – The message string to print.

bridge.utils.common_utils.is_last_rank() bool#

Check if the current rank is the last rank in the default process group.

Returns:

True if the current rank is the last one, False otherwise.

bridge.utils.common_utils.print_rank_last(message: str) None#

Print a message only on the last rank of the default process group.

Parameters:

message – The message string to print.
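The three rank-gated helpers above all reduce to one idea: compare the current rank against a target rank (first or last) and print only on a match. A sketch under the assumption that rank and world size come from env-var fallbacks, as in the "safe" getters (all names here are illustrative):

```python
import os


def _rank() -> int:
    # Assumption: env-var fallback mirroring the safe getters above
    return int(os.environ.get("RANK", "0"))


def _world_size() -> int:
    return int(os.environ.get("WORLD_SIZE", "1"))


def is_last_rank_sketch() -> bool:
    """True when this process is the last rank in the job."""
    return _rank() == _world_size() - 1


def print_rank_0_sketch(message: str) -> None:
    """Print only on global rank 0."""
    if _rank() == 0:
        print(message, flush=True)


def print_rank_last_sketch(message: str) -> None:
    """Print only on the last rank of the job."""
    if is_last_rank_sketch():
        print(message, flush=True)
```

Printing on the last rank rather than rank 0 is a common choice for pipeline-parallel setups, where the final pipeline stage holds the loss.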

bridge.utils.common_utils.unwrap_model(
model: Union[torch.nn.Module, list[torch.nn.Module]],
module_instances: tuple[Type[torch.nn.Module], ...] = ALL_MODULE_WRAPPER_CLASSNAMES,
) Union[torch.nn.Module, list[torch.nn.Module]]#

Recursively unwraps a model or list of models from common wrapper modules.

Parameters:
  • model – The model or list of models to unwrap.

  • module_instances – A tuple of wrapper module types to remove (e.g., DDP, Float16Module).

Returns:

The unwrapped model or list of models.

bridge.utils.common_utils.use_dist_ckpt(ckpt_format: str) bool#

Check if the checkpoint format indicates a distributed checkpoint.

Parameters:

ckpt_format – The checkpoint format string (e.g., “torch”, “torch_dist”).

Returns:

True if the format is not “torch”, False otherwise.
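Per the description above, plain "torch" is the only non-distributed format, so the check reduces to a single comparison. A sketch (the function name is illustrative):

```python
def use_dist_ckpt_sketch(ckpt_format: str) -> bool:
    """Sketch: any format other than plain "torch" is treated as distributed."""
    return ckpt_format != "torch"
```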