bridge.training.initialize#

Module Contents#

Functions#

initialize_megatron

Initialize Megatron core components and distributed setup.

torch_dist_init

Initialize torch.distributed and dependent components.

init_rerun_state

Initialize the rerun state machine for result validation or stats.

set_jit_fusion_options

Set PyTorch JIT layer fusion options and warmup JIT functions.

destroy_global_state

Destroy Megatron global states.

_initialize_tp_communicators

Initialize communicators with user buffers for high-performance tensor-model-parallel communication overlap.

_initialize_distributed

Initialize torch.distributed and core model parallel.

_set_random_seed

Set random seed for reproducibility.

_warmup_jit_function

Compile JIT functions before the main training steps.

API#

bridge.training.initialize.initialize_megatron(
cfg: megatron.bridge.training.config.ConfigContainer,
allow_no_cuda: bool = False,
skip_mpu_initialization: bool = False,
get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
) → Optional[Callable[[], None]]#

Initialize Megatron core components and distributed setup.

Sets up logging, initializes the distributed environment (torch.distributed), configures the microbatch calculator, and sets random seeds.

Parameters:
  • cfg – The main configuration container.

  • allow_no_cuda – If True, allows initialization without CUDA.

  • skip_mpu_initialization – If True, skips MPU initialization (for external managers).

  • get_embedding_ranks – Optional function to determine embedding layer ranks.

  • get_position_embedding_ranks – Optional function to determine position embedding ranks.

Returns:

An optional callable to finish MPU initialization if lazy_mpu_init is True, otherwise None.
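
A minimal usage sketch, assuming a fully populated ConfigContainer; the build_config helper below is hypothetical:

```python
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.initialize import initialize_megatron

# cfg is assumed to be a fully populated ConfigContainer, e.g. loaded
# from a recipe or constructed programmatically elsewhere.
cfg: ConfigContainer = build_config()  # build_config is a hypothetical helper

# Sets up logging, torch.distributed, the microbatch calculator, and seeds.
finish_mpu_init = initialize_megatron(cfg)

# When lazy_mpu_init is enabled in the config, a callable is returned and
# MPU setup must be finished explicitly by the caller.
if finish_mpu_init is not None:
    finish_mpu_init()
```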

bridge.training.initialize.torch_dist_init(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
dist_config: megatron.bridge.training.config.DistributedInitConfig,
rng_config: megatron.bridge.training.config.RNGConfig,
micro_batch_size: int,
num_distributed_optimizer_instances: int,
get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
skip_mpu_initialization: bool,
) → Optional[Callable[[], None]]#

Initialize torch.distributed and dependent components.

Handles the core distributed setup, including process group initialization, MPU (Model Parallel Unit) setup, random seed setting, and optional compilation/warmup steps.

Parameters:
  • model_config – Configuration for the specific model (GPTModelProvider or T5ModelProvider).

  • dist_config – Configuration for distributed initialization settings.

  • rng_config – Configuration for random number generation.

  • micro_batch_size – The micro batch size for JIT warmup.

  • num_distributed_optimizer_instances – Number of parallel optimizer instances.

  • get_embedding_ranks – Optional function to determine embedding layer ranks.

  • get_position_embedding_ranks – Optional function to determine position embedding ranks.

  • skip_mpu_initialization – If True, returns a function to finish MPU setup later.

Returns:

An optional callable to finish MPU initialization if skip_mpu_initialization or lazy_mpu_init is True, otherwise None.
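
A hedged sketch of calling torch_dist_init directly (initialize_megatron normally performs this step); the model provider and config objects are assumed to come from an existing ConfigContainer:

```python
from megatron.bridge.training.initialize import torch_dist_init

# model_provider, dist_config, and rng_config are assumed to be pre-built
# GPTModelProvider, DistributedInitConfig, and RNGConfig instances,
# e.g. fields taken from a ConfigContainer.
finish_mpu_init = torch_dist_init(
    model_config=model_provider,
    dist_config=dist_config,
    rng_config=rng_config,
    micro_batch_size=1,
    num_distributed_optimizer_instances=1,
    get_embedding_ranks=None,
    get_position_embedding_ranks=None,
    skip_mpu_initialization=False,
)

# A callable is returned only when MPU initialization is deferred
# (skip_mpu_initialization or lazy_mpu_init); otherwise None.
if finish_mpu_init is not None:
    finish_mpu_init()
```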

bridge.training.initialize.init_rerun_state(
rerun_state_machine_config: megatron.bridge.training.config.RerunStateMachineConfig,
) → None#

Initialize the rerun state machine for result validation or stats.

Sets up state saving and restoration functions, particularly for RNG trackers.

Parameters:

rerun_state_machine_config – Configuration for the rerun state machine.
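
A brief sketch, assuming RerunStateMachineConfig can be constructed with default field values:

```python
from megatron.bridge.training.config import RerunStateMachineConfig
from megatron.bridge.training.initialize import init_rerun_state

# Assumption: default construction is valid; adjust fields to enable
# result validation or stats collection as needed.
rerun_cfg = RerunStateMachineConfig()
init_rerun_state(rerun_cfg)
```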

bridge.training.initialize.set_jit_fusion_options(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
micro_batch_size: int,
) → None#

Set PyTorch JIT layer fusion options and warmup JIT functions.

Configures the JIT fuser (nvFuser or legacy) based on the PyTorch version and warms up common fused kernels like bias_gelu and bias_dropout_add.

Parameters:
  • model_config – Configuration for the specific model (GPTModelProvider or T5ModelProvider).

  • micro_batch_size – The micro batch size used for warmup tensor shapes.
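
A short sketch of the call, assuming model_provider is a configured GPTModelProvider or T5ModelProvider and the micro batch size matches the training configuration:

```python
from megatron.bridge.training.initialize import set_jit_fusion_options

# model_provider is assumed to be a configured GPTModelProvider or
# T5ModelProvider; micro_batch_size should match the training setup so
# the warmup tensor shapes match the real workload.
set_jit_fusion_options(model_provider, micro_batch_size=1)
```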

bridge.training.initialize.destroy_global_state() → None#

Destroy Megatron global states.

Cleans up resources used by the microbatch calculator, global memory buffer, model parallel groups, and the rerun state machine.
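
A sketch of typical teardown placement; run_training below is a hypothetical entry point:

```python
from megatron.bridge.training.initialize import destroy_global_state

try:
    run_training()  # hypothetical training entry point
finally:
    # Release the microbatch calculator, global memory buffer,
    # model parallel groups, and the rerun state machine.
    destroy_global_state()
```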

bridge.training.initialize._initialize_tp_communicators(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
micro_batch_size: int,
) → None#

Initialize communicators with user buffers for high-performance tensor-model-parallel communication overlap.

bridge.training.initialize._initialize_distributed(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
dist_config: megatron.bridge.training.config.DistributedInitConfig,
num_distributed_optimizer_instances: int,
get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
) → None#

Initialize torch.distributed and core model parallel.

bridge.training.initialize._set_random_seed(
seed_: int,
data_parallel_random_init: bool = False,
te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
) → None#

Set random seed for reproducibility.

bridge.training.initialize._warmup_jit_function(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
micro_batch_size: int,
) → None#

Compile JIT functions before the main training steps.