bridge.utils.common_utils

Module Contents

Functions

get_rank_safe

Get the distributed rank safely, even if torch.distributed is not initialized.

get_world_size_safe

Get the distributed world size safely, even if torch.distributed is not initialized.

get_last_rank

Get the last rank in the distributed group.

get_local_rank_preinit

Get the local rank from environment variables, intended for use before torch.distributed is initialized.

get_master_addr_safe

Get the master address for distributed initialization.

get_master_port_safe

Get the master port for distributed initialization.

print_rank_0

Print a message only on global rank 0.

warn_rank_0

Emit a warning only on rank 0.

is_last_rank

Check if the current rank is the last rank in the default process group.

print_rank_last

Print a message only on the last rank of the default process group.

hook_hf_module_setattr_for_tp_grad_sync

Mark parameters for tensor-parallel (TP) gradient sync and hook __setattr__ on a module and its children.

extract_expert_number_from_param

Extract the expert number from a parameter name.

resolve_path

Resolve a path to an absolute path.

API

bridge.utils.common_utils.get_rank_safe() → int

Get the distributed rank safely, even if torch.distributed is not initialized.

Fallback order:

  1. torch.distributed.get_rank() (if initialized)

  2. RANK environment variable (torchrun/torchelastic)

  3. SLURM_PROCID environment variable (SLURM)

  4. Default: 0 (with warning)

Returns:

The current process rank.

bridge.utils.common_utils.get_world_size_safe() → int

Get the distributed world size safely, even if torch.distributed is not initialized.

Fallback order:

  1. torch.distributed.get_world_size() (if initialized)

  2. WORLD_SIZE environment variable (torchrun/torchelastic)

  3. SLURM_NTASKS environment variable (SLURM)

  4. Default: 1 (with warning)

Returns:

The total number of processes in the distributed job.
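As a rough illustration of the fallback order shared by get_rank_safe and get_world_size_safe, the sketch below re-implements both in plain Python. It is not the library source: the helpers _dist_ready and _first_env_int are hypothetical, and reporting the "with warning" case through warnings.warn is an assumption (the real code may log instead). PyTorch is imported optionally so the sketch also runs where it is not installed.

```python
import os
import warnings
from typing import Optional

try:
    import torch.distributed as dist
except ImportError:  # let the sketch run without PyTorch
    dist = None


def _dist_ready() -> bool:
    # True only when a default process group actually exists.
    return dist is not None and dist.is_available() and dist.is_initialized()


def _first_env_int(*names: str) -> Optional[int]:
    # Hypothetical helper: first variable in `names` that is set, parsed as int.
    for name in names:
        value = os.environ.get(name)
        if value is not None:
            return int(value)
    return None


def get_rank_safe() -> int:
    # Sketch of the documented order: process group, RANK, SLURM_PROCID, then 0.
    if _dist_ready():
        return dist.get_rank()
    rank = _first_env_int("RANK", "SLURM_PROCID")
    if rank is None:
        warnings.warn("Rank could not be determined; defaulting to 0.")
        return 0
    return rank


def get_world_size_safe() -> int:
    # Sketch of the documented order: process group, WORLD_SIZE, SLURM_NTASKS, then 1.
    if _dist_ready():
        return dist.get_world_size()
    world_size = _first_env_int("WORLD_SIZE", "SLURM_NTASKS")
    if world_size is None:
        warnings.warn("World size could not be determined; defaulting to 1.")
        return 1
    return world_size
```

Because the torchrun variables are checked first, RANK and WORLD_SIZE win over the SLURM variables when torchrun is launched inside a SLURM allocation.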

bridge.utils.common_utils.get_last_rank() → int

Get the last rank in the distributed group.

bridge.utils.common_utils.get_local_rank_preinit() → int

Get the local rank from environment variables, intended for use before torch.distributed is initialized.

Fallback order:

  1. LOCAL_RANK environment variable (torchrun/torchelastic)

  2. SLURM_LOCALID environment variable (SLURM)

  3. Default: 0 (with warning)

Returns:

The local rank of the current process.
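A minimal sketch of get_local_rank_preinit, assuming the environment lookup documented above. Unlike the rank helpers it never consults torch.distributed, which is what makes it usable before the process group exists, for example to pick a GPU. The warnings.warn call stands in for the documented warning and is an assumption.

```python
import os
import warnings


def get_local_rank_preinit() -> int:
    # Sketch only: LOCAL_RANK (torchrun), then SLURM_LOCALID (SLURM), then 0.
    for name in ("LOCAL_RANK", "SLURM_LOCALID"):
        value = os.environ.get(name)
        if value is not None:
            return int(value)
    warnings.warn("Local rank could not be determined; defaulting to 0.")
    return 0


# Typical pre-init use (needs PyTorch with CUDA, so shown only as a comment):
#   torch.cuda.set_device(get_local_rank_preinit())
#   torch.distributed.init_process_group(backend="nccl")
```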

bridge.utils.common_utils.get_master_addr_safe() → str

Get the master address for distributed initialization.

Fallback order:

  1. MASTER_ADDR environment variable (torchrun/torchelastic)

  2. Parsed from the SLURM_NODELIST environment variable (SLURM)

  3. Default: localhost (with warning)

Returns:

The master node address.
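The documentation does not say how SLURM_NODELIST is parsed. The sketch below assumes the first host in the list becomes the master and expands only simple compact forms such as node[001-004,007]; first_host_from_nodelist is a hypothetical helper, not part of this module. On a real cluster, scontrol show hostnames "$SLURM_NODELIST" is the robust way to expand a node list.

```python
import os
import re
import warnings


def first_host_from_nodelist(nodelist: str) -> str:
    """Hypothetical helper: first hostname in a SLURM compact node list (simplified)."""
    # Cut at the first comma that is not inside a [...] range group.
    depth, end = 0, len(nodelist)
    for i, ch in enumerate(nodelist):
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        elif ch == "," and depth == 0:
            end = i
            break
    entry = nodelist[:end]
    # Expand "prefix[001-004,007]suffix" to its first member "prefix001suffix".
    match = re.match(r"^([^\[]+)\[([^\],-]+)[^\]]*\](.*)$", entry)
    if match:
        return match.group(1) + match.group(2) + match.group(3)
    return entry


def get_master_addr_safe() -> str:
    # Sketch of the documented order: MASTER_ADDR, then SLURM_NODELIST, then localhost.
    if "MASTER_ADDR" in os.environ:
        return os.environ["MASTER_ADDR"]
    nodelist = os.environ.get("SLURM_NODELIST")
    if nodelist:
        return first_host_from_nodelist(nodelist)
    warnings.warn("MASTER_ADDR not set; defaulting to localhost.")
    return "localhost"
```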

bridge.utils.common_utils.get_master_port_safe() → int

Get the master port for distributed initialization.

Fallback order:

  1. MASTER_PORT environment variable (torchrun/torchelastic)

  2. Port derived from the SLURM_JOB_ID environment variable (SLURM)

  3. Default: 29500 (with warning)

Returns:

The master port.
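The exact formula behind the job-derived port is not documented here. The sketch below uses a hypothetical mapping, 20000 + (SLURM_JOB_ID mod 10000), only to illustrate the idea: every rank of a job reads the same job ID and therefore agrees on the port without communicating, while concurrent jobs sharing a node are unlikely to collide. Treat _PORT_BASE and _PORT_SPAN as placeholders.

```python
import os
import warnings

# Placeholder constants for the job-derived port; the real formula is not documented.
_PORT_BASE = 20000
_PORT_SPAN = 10000


def get_master_port_safe() -> int:
    # Sketch of the documented order: MASTER_PORT, then a SLURM_JOB_ID-derived port, then 29500.
    if "MASTER_PORT" in os.environ:
        return int(os.environ["MASTER_PORT"])
    job_id = os.environ.get("SLURM_JOB_ID", "")
    if job_id.isdigit():
        # All ranks of one job read the same SLURM_JOB_ID, so they agree on the port;
        # concurrent jobs on the same node usually map to different ports.
        return _PORT_BASE + int(job_id) % _PORT_SPAN
    warnings.warn("MASTER_PORT not set; defaulting to 29500.")
    return 29500
```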

bridge.utils.common_utils.print_rank_0(message: str) → None

Print a message only on global rank 0.

Parameters:

message – The message string to print.

bridge.utils.common_utils.warn_rank_0(message)

Emit a warning only on rank 0.
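A minimal sketch of the rank-gated output helpers print_rank_0 and warn_rank_0. How the library determines the rank for these is not documented; this version assumes the live process group when one exists and otherwise the RANK environment variable (the hypothetical helper _global_rank), and it assumes warn_rank_0 goes through Python's warnings module.

```python
import os
import warnings

try:
    import torch.distributed as dist
except ImportError:  # let the sketch run without PyTorch
    dist = None


def _global_rank() -> int:
    # Assumption: live process group first, otherwise the RANK variable, otherwise 0.
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return int(os.environ.get("RANK", "0"))


def print_rank_0(message: str) -> None:
    # Only global rank 0 prints; all other ranks stay silent.
    if _global_rank() == 0:
        print(message, flush=True)


def warn_rank_0(message) -> None:
    # Same gating, routed through `warnings` so normal warning filters still apply.
    if _global_rank() == 0:
        warnings.warn(message)
```

Gating on one rank keeps a job with N processes from emitting every message N times.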

bridge.utils.common_utils.is_last_rank() → bool

Check if the current rank is the last rank in the default process group.

Returns:

True if the current rank is the last one, False otherwise.

bridge.utils.common_utils.print_rank_last(message: str) → None

Print a message only on the last rank of the default process group.

Parameters:

message – The message string to print.
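The last-rank helpers (get_last_rank, is_last_rank, print_rank_last) reduce to comparing the current rank with world_size - 1, since ranks are numbered 0 through world_size - 1. The sketch assumes the default process group when it is initialized and a torchrun-style RANK/WORLD_SIZE fallback otherwise; _rank_and_world_size is a hypothetical helper, and the documentation does not say whether get_last_rank itself falls back to environment variables.

```python
import os
from typing import Tuple

try:
    import torch.distributed as dist
except ImportError:  # let the sketch run without PyTorch
    dist = None


def _rank_and_world_size() -> Tuple[int, int]:
    # Assumption: default process group first, torchrun-style variables otherwise.
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return int(os.environ.get("RANK", "0")), int(os.environ.get("WORLD_SIZE", "1"))


def get_last_rank() -> int:
    # Ranks run from 0 to world_size - 1.
    return _rank_and_world_size()[1] - 1


def is_last_rank() -> bool:
    rank, world_size = _rank_and_world_size()
    return rank == world_size - 1


def print_rank_last(message: str) -> None:
    # Mirror of print_rank_0, gated on the highest rank instead of rank 0.
    if is_last_rank():
        print(message, flush=True)
```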

bridge.utils.common_utils.hook_hf_module_setattr_for_tp_grad_sync(module: torch.nn.Module) → torch.nn.Module

Mark parameters for tensor-parallel (TP) gradient sync and hook __setattr__ on a module and its children.

This ensures that all existing parameters under the provided module have the attribute average_gradients_across_tp_domain=True and that any future submodules assigned onto this module (or any of its current children) will also have their parameters marked automatically.

Parameters:

module – The root module (typically a Hugging Face module instance).

Returns:

The same module instance for convenience.

bridge.utils.common_utils.extract_expert_number_from_param(param_name: str) → int

Extract the expert number from a parameter name.

Parameters:

param_name – The parameter name to extract the expert number from.

Returns:

The expert number.
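The parameter-name format is not specified here. The sketch below assumes two naming conventions, an index segment after experts (as in model.layers.0.mlp.experts.3.gate_proj.weight) and a trailing index on grouped weights (as in decoder.layers.1.mlp.experts.linear_fc1.weight7), and raises ValueError otherwise. Both patterns and the error behaviour are assumptions, not the library's documented contract.

```python
import re

# Assumed naming conventions (illustrative only):
#   "...experts.<N>..."          e.g. "model.layers.0.mlp.experts.3.gate_proj.weight"
#   "...weight<N>" / "...bias<N>" e.g. "decoder.layers.1.mlp.experts.linear_fc1.weight7"
_PATTERNS = (
    re.compile(r"experts\.(\d+)\."),
    re.compile(r"(?:weight|bias)(\d+)$"),
)


def extract_expert_number_from_param(param_name: str) -> int:
    # Try each assumed convention in turn and return the first index found.
    for pattern in _PATTERNS:
        match = pattern.search(param_name)
        if match:
            return int(match.group(1))
    raise ValueError(f"No expert number found in parameter name: {param_name!r}")
```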

bridge.utils.common_utils.resolve_path(path: str) → pathlib.Path

Resolve a path to an absolute path.
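A one-function sketch, assuming resolve_path expands a leading ~ and then normalizes the result with pathlib.Path.resolve(); whether the library performs the ~ expansion is an assumption.

```python
from pathlib import Path


def resolve_path(path: str) -> Path:
    # Expand "~", then make the path absolute and collapse ".." and symlinks.
    # strict=False (the default) lets this work for paths that do not exist yet.
    return Path(path).expanduser().resolve()
```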