bridge.training.setup#
Module Contents#
Classes#
SetupOutput – Represents the output of the main setup function.
Functions#
setup – Initialize the training/evaluation environment using an existing GlobalState.
_update_model_config_funcs – Update model config sync funcs based on initialized model.
_create_peft_pre_wrap_hook – Create a pre-wrap hook that handles PEFT logic.
_apply_peft_transformation – Apply PEFT transformation to the base model.
_validate_and_set_vocab_size – Validate and determine the correct vocab size for the model.
API#
- class bridge.training.setup.SetupOutput#
Bases: typing.NamedTuple
Represents the output of the main setup function.
Contains all the initialized components necessary for training or evaluation.
- Attributes:
state – The global state object holding configuration and runtime information.
model – The initialized Megatron model.
optimizer – The initialized optimizer.
scheduler – The initialized learning rate scheduler.
train_data_iterator – The data iterator for the training dataset, if applicable.
valid_data_iterator – The data iterator for the validation dataset, if applicable.
test_data_iterator – The data iterator for the testing dataset, if applicable.
checkpointing_context – A dictionary holding context for checkpointing operations, especially for non-persistent local checkpointing.
- state: megatron.bridge.training.state.GlobalState#
None
- model: megatron.core.transformer.MegatronModule#
None
- optimizer: megatron.core.optimizer.MegatronOptimizer#
None
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler#
None
- train_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#
None
- valid_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#
None
- test_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#
None
- checkpointing_context: dict[str, Any]#
None
- bridge.training.setup.setup(
- state: megatron.bridge.training.state.GlobalState,
- train_valid_test_datasets_provider: Callable[..., tuple[Optional[Any], Optional[Any], Optional[Any]]],
- get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
- get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
- restart_store: Optional[torch.distributed.Store] = None,
- )#
Initialize the training/evaluation environment using an existing GlobalState.
Performs all runtime setup using the provided state and its attached config (state.cfg). This includes:
- enabling Megatron-Core experimental features
- initializing async checkpoint workers (if enabled)
- logging setup
- torch.distributed and model-parallel initialization (via initialize_megatron)
- tokenizer/model/optimizer/scheduler construction
- optional checkpoint load
- dataloader setup
- Parameters:
state – The GlobalState instance to populate and use throughout setup.
train_valid_test_datasets_provider – Callable returning the train/valid/test datasets or iterators.
get_embedding_ranks – Optional function to determine embedding layer ranks for model-parallel init.
get_position_embedding_ranks – Optional function to determine positional embedding ranks.
restart_store – Optional torch.distributed Store used when in-process restart is enabled.
- Returns:
SetupOutput containing the populated state, model, optimizer, scheduler, dataloaders, and ckpt context.
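A minimal usage sketch, assuming state is a GlobalState whose state.cfg is already populated and my_datasets_provider is a hypothetical callable that builds the train/valid/test datasets:

    from megatron.bridge.training.setup import setup

    # Assumptions: `state` is a GlobalState with state.cfg already populated,
    # and `my_datasets_provider` (hypothetical) returns the train/valid/test datasets.
    out = setup(
        state=state,
        train_valid_test_datasets_provider=my_datasets_provider,
    )

    # SetupOutput is a NamedTuple, so fields can be read by attribute
    # or unpacked positionally.
    model, optimizer, scheduler = out.model, out.optimizer, out.scheduler
    train_iter = out.train_data_iterator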
- bridge.training.setup._update_model_config_funcs(
- model: megatron.core.transformer.MegatronModule,
- model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
- optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
- *,
- align_grad_reduce: bool = True,
- )#
Update the model config's synchronization functions based on the initialized model.
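These sync funcs follow the usual Megatron-Core pattern of pointing the model config at the DDP wrapper's synchronization hooks. The sketch below illustrates that general pattern under the stated assumptions; it is not the exact body of this helper:

    # Rough sketch of the common Megatron-Core wiring suggested by this helper's
    # name and parameters; the exact logic in this module may differ.
    # Assumes `model` is a list of megatron.core.distributed.DistributedDataParallel
    # chunks and `config` is the ModelParallelConfig shared by those chunks.
    def wire_sync_funcs(config, model, ddp_config, optimizer=None, align_grad_reduce=True):
        if optimizer is not None:
            config.grad_scale_func = optimizer.scale_loss
        if ddp_config.overlap_grad_reduce:
            # Defer gradient all-reduce; the training loop triggers it explicitly.
            config.no_sync_func = [chunk.no_sync for chunk in model]
            if len(model) == 1:
                config.no_sync_func = config.no_sync_func[0]
            if align_grad_reduce:
                config.grad_sync_func = [chunk.start_grad_sync for chunk in model]
                if len(model) == 1:
                    config.grad_sync_func = config.grad_sync_func[0]
        if ddp_config.overlap_param_gather:
            config.param_sync_func = [chunk.start_param_sync for chunk in model]
            if len(model) == 1:
                config.param_sync_func = config.param_sync_func[0]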
- bridge.training.setup._create_peft_pre_wrap_hook(
- cfg: megatron.bridge.training.config.ConfigContainer,
- state: megatron.bridge.training.state.GlobalState,
- )#
Create a pre-wrap hook that handles PEFT logic.
This hook is executed before the model is wrapped with DDP/FSDP and handles:
- Loading pretrained checkpoints for PEFT
- Applying PEFT transformation to the model
- Parameters:
cfg – Configuration container
state – Global state object containing timers and other state
- Returns:
A callable hook that can be registered with the model provider
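Conceptually, a pre-wrap hook is a closure over cfg and state that receives the list of unwrapped model chunks and returns them, possibly transformed, before DDP/FSDP wrapping. The following is a hypothetical sketch of that shape; cfg.peft and load_pretrained_for_peft are placeholder names, not confirmed APIs:

    # Hypothetical shape of a PEFT pre-wrap hook; the names marked below are
    # placeholders, not actual Megatron Bridge APIs.
    def create_peft_pre_wrap_hook_sketch(cfg, state):
        def pre_wrap_hook(base_model):
            # base_model: list of unwrapped MegatronModule chunks.
            if cfg.peft is not None:                              # placeholder attribute
                load_pretrained_for_peft(cfg, state, base_model)  # placeholder helper
                base_model = _apply_peft_transformation(cfg.peft, base_model)
            return base_model
        return pre_wrap_hook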
- bridge.training.setup._apply_peft_transformation(
- peft,
- base_model: list[megatron.core.transformer.MegatronModule],
- )#
Apply PEFT transformation to the base model.
- Parameters:
peft – PEFT configuration/object
base_model – Base model before PEFT transformation
- Returns:
Model with PEFT transformation applied
- bridge.training.setup._validate_and_set_vocab_size(
- model_vocab_size: Optional[int],
- tokenizer_vocab_size: int,
- )#
Validate and determine the correct vocab size for the model.
- Parameters:
model_vocab_size – Vocab size set in model config (can be None)
tokenizer_vocab_size – Unpadded tokenizer vocab size
- Returns:
The validated unpadded vocab size and padding flag:
- vocab_size: The validated unpadded vocab size to use for the model
- should_pad_vocab: True if vocab should be padded, False otherwise
- Return type:
tuple[int, bool]
- Raises:
ValueError – If model vocab size is invalid
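The documented contract suggests logic along the following lines. This is an illustrative sketch of one plausible implementation, not the function's actual code:

    # Plausible sketch of the documented contract (an assumption, not the
    # actual implementation of _validate_and_set_vocab_size).
    def validate_and_set_vocab_size_sketch(model_vocab_size, tokenizer_vocab_size):
        if model_vocab_size is None:
            # No explicit model vocab size: fall back to the tokenizer's size
            # and signal that it may be padded (e.g. for divisibility).
            return tokenizer_vocab_size, True
        if model_vocab_size < tokenizer_vocab_size:
            raise ValueError(
                f"Configured model vocab size ({model_vocab_size}) is smaller "
                f"than the tokenizer vocab size ({tokenizer_vocab_size})."
            )
        # An explicitly configured vocab size is used as-is, without padding.
        return model_vocab_size, False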