bridge.training.setup#

Module Contents#

Classes#

SetupOutput

Represents the output of the main setup function.

Functions#

setup

Initialize the training/evaluation environment using an existing GlobalState.

_update_model_config_funcs

Update the model config's synchronization functions based on the initialized model.

_create_peft_pre_wrap_hook

Create a pre-wrap hook that handles PEFT logic.

_apply_peft_transformation

Apply PEFT transformation to the base model.

_validate_and_set_vocab_size

Validate and determine the correct vocab size for the model.

API#

class bridge.training.setup.SetupOutput#

Bases: typing.NamedTuple

Represents the output of the main setup function.

Contains all the initialized components necessary for training or evaluation.

Attributes:
  • state – The global state object holding configuration and runtime information.

  • model – The initialized Megatron model.

  • optimizer – The initialized optimizer.

  • scheduler – The initialized learning rate scheduler.

  • train_data_iterator – The data iterator for the training dataset, if applicable.

  • valid_data_iterator – The data iterator for the validation dataset, if applicable.

  • test_data_iterator – The data iterator for the testing dataset, if applicable.

  • checkpointing_context – A dictionary holding context for checkpointing operations, especially for non-persistent local checkpointing.

state: megatron.bridge.training.state.GlobalState#

None

model: megatron.core.transformer.MegatronModule#

None

optimizer: megatron.core.optimizer.MegatronOptimizer#

None

scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler#

None

train_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#

None

valid_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#

None

test_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#

None

checkpointing_context: dict[str, Any]#

None
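Because SetupOutput is a typing.NamedTuple, its fields can be read by name or unpacked positionally in the declared order. A minimal sketch of consuming it, assuming `setup_output` already holds the value returned by setup() (documented below):

```python
# Minimal sketch, assuming `setup_output` is the SetupOutput returned by setup().
# Fields can be accessed by name ...
model = setup_output.model
optimizer = setup_output.optimizer
scheduler = setup_output.scheduler

# ... or unpacked positionally in the declared field order.
(state, model, optimizer, scheduler,
 train_iter, valid_iter, test_iter, checkpointing_context) = setup_output
```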

bridge.training.setup.setup(
state: megatron.bridge.training.state.GlobalState,
train_valid_test_datasets_provider: Callable[..., tuple[Optional[Any], Optional[Any], Optional[Any]]],
get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
restart_store: Optional[torch.distributed.Store] = None,
) → bridge.training.setup.SetupOutput#

Initialize the training/evaluation environment using an existing GlobalState.

Performs all runtime setup using the provided state and its attached config (state.cfg). This includes:

  • enabling Megatron-Core experimental features

  • initializing async checkpoint workers (if enabled)

  • logging setup

  • torch.distributed and model-parallel initialization (via initialize_megatron)

  • tokenizer/model/optimizer/scheduler construction

  • optional checkpoint load

  • dataloader setup

Parameters:
  • state – The GlobalState instance to populate and use throughout setup.

  • train_valid_test_datasets_provider – Callable returning the train/valid/test datasets or iterators.

  • get_embedding_ranks – Optional function to determine embedding layer ranks for model-parallel init.

  • get_position_embedding_ranks – Optional function to determine positional embedding ranks.

  • restart_store – Optional torch.distributed Store used when in-process restart is enabled.

Returns:

SetupOutput containing the populated state, model, optimizer, scheduler, data iterators, and checkpointing context.
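A hedged sketch of a typical invocation follows. The GlobalState construction and the dataset-building logic are placeholders (`build_global_state`, `build_my_config`, and `build_my_datasets` are hypothetical helpers); only setup() itself and the provider's (train, valid, test) return shape are taken from this page.

```python
from typing import Any, Optional

from megatron.bridge.training.setup import setup


def datasets_provider(*args: Any, **kwargs: Any) -> tuple[Optional[Any], Optional[Any], Optional[Any]]:
    # Build and return (train, valid, test) datasets; the exact arguments the
    # training loop passes to this provider are not specified on this page.
    train_ds, valid_ds, test_ds = build_my_datasets()  # hypothetical helper
    return train_ds, valid_ds, test_ds


# Hypothetical helper: returns a GlobalState whose ConfigContainer is already
# attached as state.cfg (construction details are outside this module).
state = build_global_state()

setup_output = setup(
    state=state,
    train_valid_test_datasets_provider=datasets_provider,
)
```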

bridge.training.setup._update_model_config_funcs(
model: megatron.core.transformer.MegatronModule,
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
*,
align_grad_reduce: bool = True,
) → None#

Update the model config's synchronization functions based on the initialized model.

bridge.training.setup._create_peft_pre_wrap_hook(
cfg: megatron.bridge.training.config.ConfigContainer,
state: megatron.bridge.training.state.GlobalState,
) → Callable[[list[megatron.core.transformer.MegatronModule]], list[megatron.core.transformer.MegatronModule]]#

Create a pre-wrap hook that handles PEFT logic.

This hook is executed before the model is wrapped with DDP/FSDP and handles:

  1. Loading pretrained checkpoints for PEFT

  2. Applying PEFT transformation to the model

Parameters:
  • cfg – Configuration container

  • state – Global state object containing timers and other state

Returns:

A callable hook that can be registered with the model provider
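The returned hook takes the list of unwrapped model chunks and returns the (possibly transformed) list before DDP/FSDP wrapping. Below is a hedged sketch of that shape only; the body is illustrative, and `load_pretrained_for_peft` / `apply_peft` are hypothetical stand-ins for the checkpoint-loading and _apply_peft_transformation steps.

```python
from megatron.core.transformer import MegatronModule


def example_pre_wrap_hook(model_chunks: list[MegatronModule]) -> list[MegatronModule]:
    # Runs before the model is wrapped with DDP/FSDP, on the bare model chunks.
    # 1. Load pretrained base weights needed for PEFT (hypothetical helper).
    load_pretrained_for_peft(model_chunks)
    # 2. Apply the PEFT transformation and hand the transformed chunks back to
    #    the wrapping code (hypothetical stand-in for _apply_peft_transformation).
    return apply_peft(model_chunks)
```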

bridge.training.setup._apply_peft_transformation(
peft,
base_model: list[megatron.core.transformer.MegatronModule],
) → list[megatron.core.transformer.MegatronModule]#

Apply PEFT transformation to the base model.

Parameters:
  • peft – PEFT configuration/object

  • base_model – Base model before PEFT transformation

Returns:

Model with PEFT transformation applied

bridge.training.setup._validate_and_set_vocab_size(
model_vocab_size: Optional[int],
tokenizer_vocab_size: int,
) → tuple[int, bool]#

Validate and determine the correct vocab size for the model.

Parameters:
  • model_vocab_size – Vocab size set in model config (can be None)

  • tokenizer_vocab_size – Unpadded tokenizer vocab size

Returns:

A tuple of (vocab_size, should_pad_vocab):
  • vocab_size – The validated unpadded vocab size to use for the model.

  • should_pad_vocab – True if the vocab should be padded, False otherwise.

Return type:

tuple[int, bool]

Raises:

ValueError – If model vocab size is invalid
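One plausible reading of this contract, written as a standalone illustrative function; the real implementation may differ in its details (for example, in exactly when it decides padding is needed).

```python
from typing import Optional


def validate_and_set_vocab_size_sketch(
    model_vocab_size: Optional[int],
    tokenizer_vocab_size: int,
) -> tuple[int, bool]:
    if model_vocab_size is None:
        # No explicit model vocab size: fall back to the tokenizer's size and
        # signal that it may be padded to an efficient multiple later.
        return tokenizer_vocab_size, True
    if model_vocab_size < tokenizer_vocab_size:
        # A configured vocab smaller than the tokenizer's cannot index all tokens.
        raise ValueError(
            f"model vocab size ({model_vocab_size}) is smaller than the "
            f"tokenizer vocab size ({tokenizer_vocab_size})"
        )
    # An explicitly configured size is taken as-is, with no further padding.
    return model_vocab_size, False
```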