bridge.training.setup#
Module Contents#
Classes#
SetupOutput | Represents the output of the main setup function.
Functions#
setup | Initialize the training/evaluation environment using an existing GlobalState.
_update_model_config_funcs | Update model config synchronization functions based on the initialized model.
_create_peft_pre_wrap_hook | Create a pre-wrap hook that handles PEFT logic.
_apply_peft_transformation | Apply PEFT transformation to the base model.
_validate_and_set_vocab_size | Validate and determine the correct vocab size for the model.
maybe_log_and_save_config | Save configuration to disk and log non-default values on rank 0.
API#
- class bridge.training.setup.SetupOutput#
Bases: typing.NamedTuple
Represents the output of the main setup function.
Contains all the initialized components necessary for training or evaluation.
- Attributes:
state – The global state object holding configuration and runtime information.
model – The initialized Megatron model.
optimizer – The initialized optimizer.
scheduler – The initialized learning rate scheduler.
train_data_iterator – The data iterator for the training dataset, if applicable.
valid_data_iterator – The data iterator for the validation dataset, if applicable.
test_data_iterator – The data iterator for the testing dataset, if applicable.
checkpointing_context – A dictionary holding context for checkpointing operations, especially for non-persistent local checkpointing.
pg_collection – The process group collection initialized for this run.
- state: megatron.bridge.training.state.GlobalState#
None
- model: megatron.core.transformer.MegatronModule#
None
- optimizer: megatron.core.optimizer.MegatronOptimizer#
None
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler#
None
- train_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#
None
- valid_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#
None
- test_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#
None
- checkpointing_context: dict[str, Any]#
None
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection#
None
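A minimal sketch of consuming the returned tuple. Here `output` is assumed to be the value returned by `setup()` (documented below); the field names come directly from the class above.

```python
# `output` is assumed to be the SetupOutput returned by bridge.training.setup.setup.
model = output.model                     # initialized Megatron model
optimizer = output.optimizer             # initialized optimizer
scheduler = output.scheduler             # learning-rate scheduler

# As a NamedTuple, SetupOutput also unpacks positionally in field-declaration order.
(state, model, optimizer, scheduler,
 train_iter, valid_iter, test_iter,
 ckpt_context, pg_collection) = output
```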
- bridge.training.setup.setup(
- state: megatron.bridge.training.state.GlobalState,
- train_valid_test_datasets_provider: Callable[..., tuple[Optional[Any], Optional[Any], Optional[Any]]],
- get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
- get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
- restart_store: Optional[torch.distributed.Store] = None,
Initialize the training/evaluation environment using an existing GlobalState.
Performs all runtime setup using the provided state and its attached config (state.cfg). This includes:
- enabling Megatron-Core experimental features
- initializing async checkpoint workers (if enabled)
- logging setup
- torch.distributed and model-parallel initialization (via initialize_megatron)
- tokenizer/model/optimizer/scheduler construction
- optional checkpoint load
- dataloader setup
- Parameters:
state – The GlobalState instance to populate and use throughout setup.
train_valid_test_datasets_provider – Callable returning the train/valid/test datasets or iterators.
get_embedding_ranks – Optional function to determine embedding layer ranks for model-parallel init.
get_position_embedding_ranks – Optional function to determine positional embedding ranks.
restart_store – Optional torch.distributed Store used when in-process restart is enabled.
- Returns:
SetupOutput containing the populated state, model, optimizer, scheduler, dataloaders, and ckpt context.
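A hedged usage sketch. Only `setup()` itself and the (train, valid, test) return contract of the provider come from this page; the `GlobalState` construction, the provider's parameter name, and `build_my_datasets` are placeholders.

```python
from megatron.bridge.training.setup import setup

# Placeholder provider: the parameter name and dataset-building helper are
# assumptions; only the (train, valid, test) return contract is documented above.
def train_valid_test_datasets_provider(train_val_test_num_samples):
    train_ds, valid_ds, test_ds = build_my_datasets(train_val_test_num_samples)  # hypothetical helper
    return train_ds, valid_ds, test_ds

# `state` is assumed to be a GlobalState that already carries its config (state.cfg).
output = setup(
    state=state,
    train_valid_test_datasets_provider=train_valid_test_datasets_provider,
)
model, optimizer, scheduler = output.model, output.optimizer, output.scheduler
```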
- bridge.training.setup._update_model_config_funcs(
- model: megatron.core.transformer.MegatronModule,
- model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
- optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
- *,
- align_grad_reduce: bool = True,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
Update model config synchronization functions based on the initialized model.
- bridge.training.setup._create_peft_pre_wrap_hook(
- cfg: megatron.bridge.training.config.ConfigContainer,
- state: megatron.bridge.training.state.GlobalState,
Create a pre-wrap hook that handles PEFT logic.
This hook is executed before the model is wrapped with DDP/FSDP and handles:
- Loading pretrained checkpoints for PEFT
- Applying PEFT transformation to the model
- Parameters:
cfg – Configuration container
state – Global state object containing timers and other state
- Returns:
A callable hook that can be registered with the model provider
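A hedged sketch of how the returned hook might be wired up. The `register_pre_wrap_hook` call and the hook's call signature are assumptions for illustration, not documented API.

```python
from megatron.bridge.training.setup import _create_peft_pre_wrap_hook

# Build the hook from the config container and global state described above.
pre_wrap_hook = _create_peft_pre_wrap_hook(cfg, state)

# Hypothetical registration point on the model provider; the method name
# `register_pre_wrap_hook` is an assumption.
cfg.model.register_pre_wrap_hook(pre_wrap_hook)

# Conceptually, before DDP/FSDP wrapping the hook receives the unwrapped model
# chunks, loads the pretrained checkpoint if configured, applies PEFT, and
# returns the transformed model (call signature assumed):
# base_model = pre_wrap_hook(base_model)
```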
- bridge.training.setup._apply_peft_transformation(
- peft,
- base_model: list[megatron.core.transformer.MegatronModule],
Apply PEFT transformation to the base model.
- Parameters:
peft – PEFT configuration/object
base_model – Base model before PEFT transformation
- Returns:
Model with PEFT transformation applied
- bridge.training.setup._validate_and_set_vocab_size(
- model_vocab_size: Optional[int],
- tokenizer_vocab_size: int,
Validate and determine the correct vocab size for the model.
- Parameters:
model_vocab_size – Vocab size set in model config (can be None)
tokenizer_vocab_size – Unpadded tokenizer vocab size
- Returns:
The validated unpadded vocab size and padding flag:
- vocab_size: The validated unpadded vocab size to use for the model
- should_pad_vocab: True if the vocab should be padded, False otherwise
- Return type:
tuple[int, bool]
- Raises:
ValueError – If model vocab size is invalid
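The validation rules are not spelled out on this page; the sketch below illustrates one plausible implementation consistent with the documented contract (a `(vocab_size, should_pad_vocab)` tuple, with `ValueError` raised on an invalid model vocab size).

```python
from typing import Optional

def validate_and_set_vocab_size_sketch(
    model_vocab_size: Optional[int], tokenizer_vocab_size: int
) -> tuple[int, bool]:
    """Illustrative only: the branching below is an assumption."""
    if model_vocab_size is None:
        # No explicit model vocab size: adopt the tokenizer's and allow padding.
        return tokenizer_vocab_size, True
    if model_vocab_size < tokenizer_vocab_size:
        # A model vocab smaller than the tokenizer's cannot index every token.
        raise ValueError(
            f"model vocab size ({model_vocab_size}) is smaller than the "
            f"tokenizer vocab size ({tokenizer_vocab_size})"
        )
    # An explicitly configured size is kept as-is and not padded further.
    return model_vocab_size, False
```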
- bridge.training.setup.maybe_log_and_save_config(
- cfg: megatron.bridge.training.config.ConfigContainer,
Save configuration to disk and log non-default values on rank 0.
Instead of printing the full config YAML, this now logs only the values that differ from Megatron Core defaults, making it easier to spot unintended configuration deviations.
The full config can still be saved to a file via logger.save_config_filepath.
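A minimal usage sketch; `cfg` is assumed to be a fully populated `ConfigContainer`, and `cfg.logger.save_config_filepath` is the assumed location of the `save_config_filepath` setting mentioned above.

```python
from megatron.bridge.training.setup import maybe_log_and_save_config

cfg.logger.save_config_filepath = "/results/run_config.yaml"  # hypothetical path
maybe_log_and_save_config(cfg)  # rank 0 logs non-default values and writes the full config
```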