bridge.training.initialize
Module Contents
Functions
| Function | Description |
|---|---|
| initialize_megatron | Initialize Megatron core components and distributed setup. |
| torch_dist_init | Initialize torch.distributed and dependent components. |
| init_rerun_state | Initialize the rerun state machine for result validation or stats. |
| set_jit_fusion_options | Set PyTorch JIT layer fusion options and warmup JIT functions. |
| destroy_global_state | Destroy Megatron global states. |
| _initialize_tp_communicators | Initialize communicators with user buffers for high-performance tensor-model-parallel communication overlap. |
| _initialize_distributed | Initialize torch.distributed and core model parallel. |
| _set_random_seed | Set random seed for reproducibility. |
| _warmup_jit_function | Compile JIT functions before the main training steps. |
API
- bridge.training.initialize.initialize_megatron(
- cfg: megatron.bridge.training.config.ConfigContainer,
- allow_no_cuda: bool = False,
- skip_mpu_initialization: bool = False,
- get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
- get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
- )
Initialize Megatron core components and distributed setup.
Sets up logging, initializes distributed environment (torch.distributed), configures microbatch calculator, and sets random seeds.
- Parameters:
cfg – The main configuration container.
allow_no_cuda – If True, allows initialization without CUDA.
skip_mpu_initialization – If True, skips MPU initialization (for external managers).
get_embedding_ranks – Optional function to determine embedding layer ranks.
get_position_embedding_ranks – Optional function to determine position embedding ranks.
- Returns:
An optional callable to finish MPU initialization if lazy_mpu_init is True, otherwise None.
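A minimal usage sketch follows; the wrapper function and its placement are illustrative, and only the initialize_megatron call itself comes from this module:

```python
from typing import Callable, Optional

from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.initialize import initialize_megatron


def setup(cfg: ConfigContainer) -> None:
    # Sets up logging, torch.distributed, the microbatch calculator, and
    # random seeds from the populated config container.
    finish_mpu_init: Optional[Callable] = initialize_megatron(cfg)
    if finish_mpu_init is not None:
        # Returned only when lazy MPU initialization is configured; call it
        # once any external process-group manager has finished its setup.
        finish_mpu_init()
```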
- bridge.training.initialize.torch_dist_init(
- model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
- dist_config: megatron.bridge.training.config.DistributedInitConfig,
- rng_config: megatron.bridge.training.config.RNGConfig,
- micro_batch_size: int,
- num_distributed_optimizer_instances: int,
- get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
- get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
- skip_mpu_initialization: bool,
- )
Initialize torch.distributed and dependent components.
Handles the core distributed setup, including process group initialization, MPU (Model Parallel Unit) setup, random seed setting, and optional compilation/warmup steps.
- Parameters:
model_config – Configuration for the specific model (GPTModelProvider or T5ModelProvider).
dist_config – Configuration for distributed initialization settings.
rng_config – Configuration for random number generation.
micro_batch_size – The micro batch size for JIT warmup.
num_distributed_optimizer_instances – Number of parallel optimizer instances.
get_embedding_ranks – Optional function to determine embedding layer ranks.
get_position_embedding_ranks – Optional function to determine position embedding ranks.
skip_mpu_initialization – If True, returns a function to finish MPU setup later.
- Returns:
An optional callable to finish MPU initialization if skip_mpu_initialization or lazy_mpu_init is True, otherwise None.
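A sketch of deferring MPU setup to an external manager; the cfg.model, cfg.dist, and cfg.rng attribute names are assumptions standing in for the corresponding sub-configs of an already-built ConfigContainer:

```python
from typing import Callable, Optional

from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.initialize import torch_dist_init


def setup_distributed(cfg: ConfigContainer) -> None:
    # The attribute names below are illustrative; pass whichever model,
    # distributed-init, and RNG sub-configs your ConfigContainer carries.
    finish_mpu_init: Optional[Callable] = torch_dist_init(
        model_config=cfg.model,
        dist_config=cfg.dist,
        rng_config=cfg.rng,
        micro_batch_size=4,
        num_distributed_optimizer_instances=1,
        get_embedding_ranks=None,
        get_position_embedding_ranks=None,
        skip_mpu_initialization=True,  # defer MPU setup to an external manager
    )
    if finish_mpu_init is not None:
        finish_mpu_init()  # complete MPU setup once that manager is ready
```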
- bridge.training.initialize.init_rerun_state(
- rerun_state_machine_config: megatron.bridge.training.config.RerunStateMachineConfig,
- )
Initialize the rerun state machine for result validation or stats.
Sets up state saving and restoration functions, particularly for RNG trackers.
- Parameters:
rerun_state_machine_config – Configuration for the rerun state machine.
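A sketch of where this call typically fits, right after distributed initialization and before the training loop; the rerun_state_machine field name on the container is an assumption:

```python
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.initialize import init_rerun_state, initialize_megatron


def setup_with_reruns(cfg: ConfigContainer) -> None:
    initialize_megatron(cfg)
    # Registers RNG-tracker save/restore hooks for result validation or stats.
    # The attribute name below is illustrative; pass whichever
    # RerunStateMachineConfig your ConfigContainer carries.
    init_rerun_state(cfg.rerun_state_machine)
```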
- bridge.training.initialize.set_jit_fusion_options(
- model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
- micro_batch_size: int,
- )
Set PyTorch JIT layer fusion options and warmup JIT functions.
Configures the JIT fuser (nvFuser or legacy) based on the PyTorch version and warms up common fused kernels like bias_gelu and bias_dropout_add.
- Parameters:
model_config – Configuration for the specific model (GPTModelProvider or T5ModelProvider).
micro_batch_size – The micro batch size used for warmup tensor shapes.
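A sketch of warming up the fused kernels before the first step, assuming a GPTModelProvider is in use (a T5ModelProvider works the same way):

```python
from megatron.bridge.models import GPTModelProvider
from megatron.bridge.training.initialize import set_jit_fusion_options


def warmup_fused_kernels(model_cfg: GPTModelProvider, micro_batch_size: int) -> None:
    # Selects the JIT fuser for the installed PyTorch version and pre-compiles
    # the bias_gelu and bias_dropout_add kernels at the warmup tensor shapes,
    # so the first training step does not pay compilation latency.
    set_jit_fusion_options(model_cfg, micro_batch_size)
```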
- bridge.training.initialize.destroy_global_state() -> None
Destroy Megatron global states.
Cleans up resources used by microbatch calculator, global memory buffer, model parallel groups, and the rerun state machine.
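A sketch of pairing initialization with cleanup so global state is torn down even if training raises:

```python
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.initialize import destroy_global_state, initialize_megatron


def run(cfg: ConfigContainer) -> None:
    initialize_megatron(cfg)
    try:
        pass  # training loop goes here
    finally:
        # Tears down the microbatch calculator, global memory buffer,
        # model-parallel groups, and the rerun state machine.
        destroy_global_state()
```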
- bridge.training.initialize._initialize_tp_communicators(
- model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
- micro_batch_size: int,
- )
Initialize communicators with user buffers for high-performance tensor-model-parallel communication overlap.
- bridge.training.initialize._initialize_distributed(
- model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
- dist_config: megatron.bridge.training.config.DistributedInitConfig,
- num_distributed_optimizer_instances: int,
- get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
- get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]],
- )
Initialize torch.distributed and core model parallel.
- bridge.training.initialize._set_random_seed(
- seed_: int,
- data_parallel_random_init: bool = False,
- te_rng_tracker: bool = False,
- inference_rng_tracker: bool = False,
- )
Set random seed for reproducibility.
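This helper is module-private and is normally invoked for you during torch_dist_init; a direct call would look like the sketch below:

```python
from megatron.bridge.training.initialize import _set_random_seed

# Set the random seed for reproducibility (the seed value is illustrative).
_set_random_seed(seed_=1234, data_parallel_random_init=False)
```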
- bridge.training.initialize._warmup_jit_function(
- model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
- micro_batch_size: int,
- )
Compile JIT functions before the main training steps.