bridge.training.initialize#

Module Contents#

Functions#

initialize_megatron

Initialize Megatron core components and distributed setup.

torch_dist_init

Initialize torch.distributed and dependent components.

init_rerun_state

Initialize the rerun state machine for result validation or stats.

set_jit_fusion_options

Set PyTorch JIT layer fusion options and warm up JIT functions.

destroy_global_state

Destroy Megatron global states.

_initialize_tp_communicators

Initialize communicators with user buffers for high-performance tensor-model-parallel communication overlap.

_initialize_distributed

Initialize torch.distributed and core model parallel.

_set_random_seed

Set random seed for reproducibility.

_warmup_jit_function

Compile JIT functions before the main training steps.

force_nccl_backend_init

Force NCCL backend initialization for in-process restart compatibility.

API#

bridge.training.initialize.initialize_megatron(
cfg: megatron.bridge.training.config.ConfigContainer,
allow_no_cuda: bool = False,
skip_mpu_initialization: bool = False,
get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
restart_store: Optional[torch.distributed.Store] = None,
) → Optional[Callable[[], None]]#

Initialize Megatron core components and distributed setup.

Sets up logging, initializes the distributed environment (torch.distributed), configures the microbatch calculator, and sets random seeds.

Parameters:
  • cfg – The main configuration container.

  • allow_no_cuda – If True, allows initialization without CUDA.

  • skip_mpu_initialization – If True, skips MPU initialization (for external managers).

  • get_embedding_ranks – Optional function to determine embedding layer ranks.

  • get_position_embedding_ranks – Optional function to determine position embedding ranks.

  • restart_store – Optional store for in-process restart.

Returns:

An optional callable to finish MPU initialization if lazy_mpu_init is True, otherwise None.
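
A minimal usage sketch, assuming a fully populated ConfigContainer is available; build_config is a hypothetical helper standing in for however the configuration is assembled in your training script:

```python
from megatron.bridge.training.initialize import initialize_megatron

# Hypothetical helper: build the ConfigContainer from a recipe, YAML, etc.
cfg = build_config()

finish_mpu_init = initialize_megatron(cfg, allow_no_cuda=False)

# A callable is only returned when lazy MPU initialization was requested;
# invoke it once the external process-group manager is ready.
if finish_mpu_init is not None:
    finish_mpu_init()
```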

bridge.training.initialize.torch_dist_init(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
dist_config: megatron.bridge.training.config.DistributedInitConfig,
rng_config: megatron.bridge.training.config.RNGConfig,
micro_batch_size: int,
num_distributed_optimizer_instances: int,
get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
skip_mpu_initialization: bool,
restart_store: Optional[torch.distributed.Store] = None,
use_inprocess_restart: bool = False,
) → Optional[Callable[[], None]]#

Initialize torch.distributed and dependent components.

Handles the core distributed setup, including process group initialization, MPU (Model Parallel Unit) setup, random seed setting, and optional compilation/warmup steps.

Parameters:
  • model_config – Configuration for the specific model (GPTModelProvider or T5ModelProvider).

  • dist_config – Configuration for distributed initialization settings.

  • rng_config – Configuration for random number generation.

  • micro_batch_size – The micro batch size for JIT warmup.

  • num_distributed_optimizer_instances – Number of parallel optimizer instances.

  • get_embedding_ranks – Optional function to determine embedding layer ranks.

  • get_position_embedding_ranks – Optional function to determine position embedding ranks.

  • skip_mpu_initialization – If True, returns a function to finish MPU setup later.

Returns:

An optional callable to finish MPU initialization if skip_mpu_initialization or lazy_mpu_init is True, otherwise None.
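
The sketch below shows how the pieces fit together when calling this function directly. The attribute names on cfg (model, dist, rng) are assumptions for illustration; in practice initialize_megatron extracts these sub-configs from the ConfigContainer for you:

```python
from megatron.bridge.training.initialize import torch_dist_init

# Assumed attribute names on the ConfigContainer; adjust to match your config.
maybe_finish = torch_dist_init(
    model_config=cfg.model,
    dist_config=cfg.dist,
    rng_config=cfg.rng,
    micro_batch_size=1,
    num_distributed_optimizer_instances=1,
    get_embedding_ranks=None,
    get_position_embedding_ranks=None,
    skip_mpu_initialization=True,
)

# With skip_mpu_initialization=True a callable is returned; run it once the
# external manager has finished its own setup.
if maybe_finish is not None:
    maybe_finish()
```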

bridge.training.initialize.init_rerun_state(
rerun_state_machine_config: megatron.bridge.training.config.RerunStateMachineConfig,
) → None#

Initialize the rerun state machine for result validation or stats.

Sets up state saving and restoration functions, particularly for RNG trackers.

Parameters:

rerun_state_machine_config – Configuration for the rerun state machine.
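
A short sketch; default-constructing the config is an assumption made here for brevity, since the value normally comes from the ConfigContainer:

```python
from megatron.bridge.training.config import RerunStateMachineConfig
from megatron.bridge.training.initialize import init_rerun_state

# Default construction is assumed for illustration only.
rerun_cfg = RerunStateMachineConfig()
init_rerun_state(rerun_cfg)
```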

bridge.training.initialize.set_jit_fusion_options(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
micro_batch_size: int,
) → None#

Set PyTorch JIT layer fusion options and warm up JIT functions.

Configures the JIT fuser (nvFuser or legacy) based on the PyTorch version and warms up common fused kernels like bias_gelu and bias_dropout_add.

Parameters:
  • model_config – Configuration for the specific model (GPTModelProvider or T5ModelProvider).

  • micro_batch_size – The micro batch size used for warmup tensor shapes.
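
A usage sketch; cfg.model is an assumed attribute name for the model provider held by your configuration, and micro_batch_size should match the value used for training:

```python
from megatron.bridge.training.initialize import set_jit_fusion_options

# cfg.model is assumed to be the GPTModelProvider / T5ModelProvider instance.
set_jit_fusion_options(model_config=cfg.model, micro_batch_size=2)
```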

bridge.training.initialize.destroy_global_state() → None#

Destroy Megatron global states.

Cleans up resources used by microbatch calculator, global memory buffer, model parallel groups, and the rerun state machine.
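
A typical place to call this is in a finally block around the training loop, as sketched below (run_training is a hypothetical entry point):

```python
from megatron.bridge.training.initialize import destroy_global_state

try:
    run_training(cfg)  # hypothetical training entry point
finally:
    # Release the microbatch calculator, global memory buffer, model-parallel
    # groups, and rerun state machine before exiting or re-initializing.
    destroy_global_state()
```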

bridge.training.initialize._initialize_tp_communicators(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
micro_batch_size: int,
) → None#

Initialize communicators with user buffers for high-performance tensor-model-parallel communication overlap.

bridge.training.initialize._initialize_distributed(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
dist_config: megatron.bridge.training.config.DistributedInitConfig,
num_distributed_optimizer_instances: int,
get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
restart_store: Optional[torch.distributed.Store] = None,
use_inprocess_restart: bool = False,
) → None#

Initialize torch.distributed and core model parallel.

bridge.training.initialize._set_random_seed(
seed_: int,
data_parallel_random_init: bool = False,
te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
) → None#

Set random seed for reproducibility.

bridge.training.initialize._warmup_jit_function(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
micro_batch_size: int,
) → None#

Compile JIT functions before the main training steps.

bridge.training.initialize.force_nccl_backend_init(device_id: torch.device) → None#

Force NCCL backend initialization for in-process restart compatibility.

The nvidia-resiliency-ext in-process restart uses destroy_process_group to terminate the NCCL backend, which does not terminate NCCL kernels if the NCCL backend wasn’t fully initialized before additional distributed subgroups are created.

This function forces full initialization of the NCCL backend by performing a simple all_reduce operation.

Parameters:

device_id – CUDA device ID to use for the dummy tensor operation
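
A minimal sketch, assuming torch.distributed has already been initialized with the NCCL backend and the current CUDA device has been set:

```python
import torch

from megatron.bridge.training.initialize import force_nccl_backend_init

# Run the dummy all_reduce on the local CUDA device so the NCCL backend is
# fully initialized before any additional distributed subgroups are created.
device = torch.device("cuda", torch.cuda.current_device())
force_nccl_backend_init(device_id=device)
```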