bridge.training.setup#
Module Contents#
Classes#

SetupOutput | Represents the output of the main setup function.

Functions#

setup | Initializes the training/evaluation environment.
_update_model_config_funcs | Update model config sync funcs based on initialized model.
_create_peft_pre_wrap_hook | Create a pre-wrap hook that handles PEFT logic.
_apply_peft_transformation | Apply PEFT transformation to the base model.
API#
- class bridge.training.setup.SetupOutput#
Bases: typing.NamedTuple
Represents the output of the main setup function.
Contains all the initialized components necessary for training or evaluation.
- Attributes:
state – The global state object holding configuration and runtime information.
model – The initialized Megatron model.
optimizer – The initialized optimizer.
scheduler – The initialized learning rate scheduler.
train_data_iterator – The data iterator for the training dataset, if applicable.
valid_data_iterator – The data iterator for the validation dataset, if applicable.
test_data_iterator – The data iterator for the testing dataset, if applicable.
checkpointing_context – A dictionary holding context for checkpointing operations, especially for non-persistent local checkpointing.
- state: megatron.bridge.training.state.GlobalState#
None
- model: megatron.core.transformer.MegatronModule#
None
- optimizer: megatron.core.optimizer.MegatronOptimizer#
None
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler#
None
- train_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#
None
- valid_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#
None
- test_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#
None
- checkpointing_context: dict[str, Any]#
None
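Because SetupOutput is a typing.NamedTuple, its fields can be accessed by name or unpacked positionally in declaration order. A minimal consumption sketch, assuming cfg and my_datasets_provider are defined elsewhere (see the setup function below):

```python
from megatron.bridge.training.setup import setup

# Hypothetical call; `cfg` and `my_datasets_provider` are assumed to be
# defined elsewhere (see the `setup` function documented below).
out = setup(cfg, my_datasets_provider)

model, optimizer = out.model, out.optimizer  # access by field name

# Positional unpacking follows the field declaration order:
(state, model, optimizer, scheduler,
 train_iter, valid_iter, test_iter, ckpt_ctx) = out
```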
- bridge.training.setup.setup(
- cfg: megatron.bridge.training.config.ConfigContainer,
- train_valid_test_datasets_provider: Callable[..., tuple[Optional[Any], Optional[Any], Optional[Any]]],
- get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
- get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
- ) -> bridge.training.setup.SetupOutput#
Initializes the training/evaluation environment.
Sets up logging, initializes Megatron core components (distributed, timers), builds the tokenizer, creates the model, optimizer, and scheduler, loads checkpoints if specified, and prepares data iterators.
- Parameters:
cfg – The main configuration container holding all sub-configurations (model, training, optimizer, etc.).
train_valid_test_datasets_provider – A callable that takes the configuration (and possibly a tokenizer) and returns a tuple of the training, validation, and test datasets, any of which may be None.
get_embedding_ranks – Optional callable to determine ranks for embedding layers, used during Megatron initialization.
get_position_embedding_ranks – Optional callable to determine ranks for position embedding layers, used during Megatron initialization.
- Returns:
A SetupOutput named tuple containing the initialized state, model, optimizer, scheduler, data iterators, and checkpointing context.
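A hedged end-to-end sketch of driving setup. The provider below is a stand-in that only honors the documented contract (it returns a train/valid/test tuple, any element possibly None); build_my_datasets and build_config are hypothetical helpers:

```python
from megatron.bridge.training.setup import setup

def datasets_provider(*args, **kwargs):
    # Stand-in provider: must return a (train, valid, test) tuple;
    # any element may be None if that split is unused.
    return build_my_datasets(*args, **kwargs)  # hypothetical helper

cfg = build_config()  # hypothetical ConfigContainer construction
setup_output = setup(cfg, datasets_provider)
model = setup_output.model
train_iter = setup_output.train_data_iterator
```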
- bridge.training.setup._update_model_config_funcs(
- model: megatron.core.transformer.MegatronModule,
- model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
- optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
- *,
- align_grad_reduce: bool = True,
- )#
Update the model config's synchronization functions based on the initialized model.
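The function body is not reproduced here; as a rough sketch of the kind of wiring this implies, megatron.core's DistributedDataParallel exposes no_sync() and start_grad_sync(), which a model config can reference so the training loop can defer gradient all-reduce during accumulation. Everything below is an assumption, not the verified implementation:

```python
from megatron.core.distributed import DistributedDataParallel

# Assumed sketch, not the verified body of _update_model_config_funcs.
if isinstance(model, DistributedDataParallel):
    # Suppress gradient all-reduce for intermediate micro-batches.
    model_config.no_sync_func = model.no_sync
    if align_grad_reduce:
        # Launch gradient reduction eagerly at the accumulation boundary.
        model_config.grad_sync_func = model.start_grad_sync
```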
- bridge.training.setup._create_peft_pre_wrap_hook(
- cfg: megatron.bridge.training.config.ConfigContainer,
- state: megatron.bridge.training.state.GlobalState,
- )#
Create a pre-wrap hook that handles PEFT logic.
This hook is executed before the model is wrapped with DDP/FSDP and handles:
- Loading pretrained checkpoints for PEFT
- Applying the PEFT transformation to the model
- Parameters:
cfg – Configuration container
state – Global state object containing timers and other state
- Returns:
A callable hook that can be registered with the model provider
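The returned hook's precise signature is internal; below is a hedged sketch of the shape it plausibly takes, where the cfg.peft attribute and the checkpoint-loading step are assumptions, and the hook closes over cfg from the factory:

```python
def pre_wrap_hook(base_model):
    # base_model: list of unwrapped MegatronModule chunks, received
    # before DDP/FSDP wrapping.
    # 1) Optionally load pretrained weights first, so PEFT adapters are
    #    added on top of the restored base model (call elided here).
    # 2) Apply the PEFT transformation.
    if getattr(cfg, "peft", None) is not None:  # assumed config attribute
        base_model = _apply_peft_transformation(cfg.peft, base_model)
    return base_model
```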
- bridge.training.setup._apply_peft_transformation(
- peft,
- base_model: list[megatron.core.transformer.MegatronModule],
- )#
Apply PEFT transformation to the base model.
- Parameters:
peft – PEFT configuration/object
base_model – Base model before PEFT transformation
- Returns:
Model with PEFT transformation applied
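A minimal sketch, assuming the PEFT object is callable on a module and returns the transformed module (the real interface may differ):

```python
def _apply_peft_transformation(peft, base_model):
    # Apply the transform to each model chunk independently. A PEFT
    # transform typically freezes the base weights and injects small
    # trainable adapter modules into the chunk.
    return [peft(chunk) for chunk in base_model]
```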