bridge.training.setup#

Module Contents#

Classes#

SetupOutput

Represents the output of the main setup function.

Functions#

setup

Initializes the training/evaluation environment.

_update_model_config_funcs

Update the model config's synchronization functions based on the initialized model.

_create_peft_pre_wrap_hook

Create a pre-wrap hook that handles PEFT logic.

_apply_peft_transformation

Apply PEFT transformation to the base model.

API#

class bridge.training.setup.SetupOutput#

Bases: typing.NamedTuple

Represents the output of the main setup function.

Contains all the initialized components necessary for training or evaluation.

Attributes:
  • state – The global state object holding configuration and runtime information.

  • model – The initialized Megatron model.

  • optimizer – The initialized optimizer.

  • scheduler – The initialized learning rate scheduler.

  • train_data_iterator – The data iterator for the training dataset, if applicable.

  • valid_data_iterator – The data iterator for the validation dataset, if applicable.

  • test_data_iterator – The data iterator for the testing dataset, if applicable.

  • checkpointing_context – A dictionary holding context for checkpointing operations, especially for non-persistent local checkpointing.

state: megatron.bridge.training.state.GlobalState#

None

model: megatron.core.transformer.MegatronModule#

None

optimizer: megatron.core.optimizer.MegatronOptimizer#

None

scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler#

None

train_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#

None

valid_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#

None

test_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#

None

checkpointing_context: dict[str, Any]#

None
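Because SetupOutput is a NamedTuple, its components can be read by attribute name or unpacked positionally in declaration order. A minimal sketch, assuming `setup_output` is the value returned by a prior `setup(...)` call:

```python
# Illustrative only: `setup_output` is assumed to be the result of a
# prior call to bridge.training.setup.setup(...).
model = setup_output.model  # access by field name
state = setup_output.state

# Fields can also be unpacked positionally, in declaration order:
(state, model, optimizer, scheduler,
 train_iter, valid_iter, test_iter, ckpt_ctx) = setup_output
```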

bridge.training.setup.setup(
cfg: megatron.bridge.training.config.ConfigContainer,
train_valid_test_datasets_provider: Callable[..., tuple[Optional[Any], Optional[Any], Optional[Any]]],
get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
) → bridge.training.setup.SetupOutput#

Initializes the training/evaluation environment.

Sets up logging, initializes Megatron core components (distributed, timers), builds the tokenizer, creates the model, optimizer, and scheduler, loads checkpoints if specified, and prepares data iterators.

Parameters:
  • cfg – The main configuration container holding all sub-configurations (model, training, optimizer, etc.).

  • train_valid_test_datasets_provider – A callable that takes the configuration (and potentially a tokenizer) and returns a tuple containing the training, validation, and test datasets, any of which may be None.

  • get_embedding_ranks – Optional callable to determine ranks for embedding layers, used during Megatron initialization.

  • get_position_embedding_ranks – Optional callable to determine ranks for position embedding layers, used during Megatron initialization.

Returns:

A SetupOutput named tuple containing the initialized state, model, optimizer, scheduler, data iterators, and checkpointing context.
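A hedged usage sketch: the empty datasets provider below is a placeholder, and `cfg` is assumed to be a fully populated ConfigContainer built elsewhere; neither is prescribed by this module.

```python
from megatron.bridge.training.setup import setup

def train_valid_test_datasets_provider(*args, **kwargs):
    # Placeholder body: build and return (train_ds, valid_ds, test_ds);
    # any of the three may be None when that split is unused.
    return None, None, None

# `cfg` is assumed to be a fully populated ConfigContainer built elsewhere.
setup_output = setup(cfg, train_valid_test_datasets_provider)

model = setup_output.model
optimizer = setup_output.optimizer
scheduler = setup_output.scheduler
```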

bridge.training.setup._update_model_config_funcs(
model: megatron.core.transformer.MegatronModule,
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
*,
align_grad_reduce: bool = True,
) → None#

Update the model config's synchronization functions based on the initialized model.
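An illustrative call mirroring the signature above; `model`, `model_config`, `ddp_config`, and `optimizer` are assumed to be the objects produced by earlier setup steps:

```python
# Sketch only: all four positional arguments are assumed to come from
# earlier steps of setup(); align_grad_reduce is keyword-only.
_update_model_config_funcs(
    model,
    model_config,
    ddp_config,
    optimizer,
    align_grad_reduce=True,
)
```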

bridge.training.setup._create_peft_pre_wrap_hook(
cfg: megatron.bridge.training.config.ConfigContainer,
state: megatron.bridge.training.state.GlobalState,
) → Callable[[list[megatron.core.transformer.MegatronModule]], list[megatron.core.transformer.MegatronModule]]#

Create a pre-wrap hook that handles PEFT logic.

This hook is executed before the model is wrapped with DDP/FSDP and handles:

  1. Loading pretrained checkpoints for PEFT

  2. Applying PEFT transformation to the model

Parameters:
  • cfg – Configuration container

  • state – Global state object containing timers and other state

Returns:

A callable hook that can be registered with the model provider
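A sketch of invoking the hook directly, consistent with the returned callable's type; `cfg`, `state`, and `model_chunks` are assumed to exist, and the exact mechanism for registering the hook with a model provider is not shown here:

```python
# Sketch: build the hook, then apply it to the unwrapped model chunks
# before they are wrapped with DDP/FSDP.
pre_wrap_hook = _create_peft_pre_wrap_hook(cfg, state)

# The hook maps list[MegatronModule] -> list[MegatronModule], loading the
# pretrained checkpoint and applying the PEFT transformation as it goes.
model_chunks = pre_wrap_hook(model_chunks)
```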

bridge.training.setup._apply_peft_transformation(
peft,
base_model: list[megatron.core.transformer.MegatronModule],
) → list[megatron.core.transformer.MegatronModule]#

Apply PEFT transformation to the base model.

Parameters:
  • peft – PEFT configuration/object

  • base_model – Base model before PEFT transformation

Returns:

Model with PEFT transformation applied
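A minimal sketch of the call; `peft` and `base_model` (a list of model chunks) are assumed to come from the calling scope:

```python
# Sketch: returns the model list with the PEFT transformation applied;
# `peft` and `base_model` are assumed to exist in the calling scope.
peft_model = _apply_peft_transformation(peft, base_model)
```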