bridge.training.setup#

Module Contents#

Classes#

SetupOutput

Represents the output of the main setup function.

Functions#

setup

Initialize the training/evaluation environment using an existing GlobalState.

_update_model_config_funcs

Update model config synchronization functions based on the initialized model.

_create_peft_pre_wrap_hook

Create a pre-wrap hook that handles PEFT logic.

_apply_peft_transformation

Apply PEFT transformation to the base model.

_validate_and_set_vocab_size

Validate and determine the correct vocab size for the model.

maybe_log_and_save_config

Save configuration to disk and log non-default values on rank 0.

API#

class bridge.training.setup.SetupOutput#

Bases: typing.NamedTuple

Represents the output of the main setup function.

Contains all the initialized components necessary for training or evaluation.

Attributes:
  • state – The global state object holding configuration and runtime information.

  • model – The initialized Megatron model.

  • optimizer – The initialized optimizer.

  • scheduler – The initialized learning rate scheduler.

  • train_data_iterator – The data iterator for the training dataset, if applicable.

  • valid_data_iterator – The data iterator for the validation dataset, if applicable.

  • test_data_iterator – The data iterator for the testing dataset, if applicable.

  • checkpointing_context – A dictionary holding context for checkpointing operations, especially for non-persistent local checkpointing.

  • pg_collection – The process group collection initialized for this run.

state: megatron.bridge.training.state.GlobalState#

None

model: megatron.core.transformer.MegatronModule#

None

optimizer: megatron.core.optimizer.MegatronOptimizer#

None

scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler#

None

train_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#

None

valid_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#

None

test_data_iterator: Optional[megatron.core.rerun_state_machine.RerunDataIterator | list[megatron.core.rerun_state_machine.RerunDataIterator]]#

None

checkpointing_context: dict[str, Any]#

None

pg_collection: megatron.core.process_groups_config.ProcessGroupCollection#

None
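Because SetupOutput is a NamedTuple, its components can be read by attribute name or unpacked positionally in the field order listed above. A minimal sketch, assuming setup_output is a value returned by setup:

```python
# Assumes `setup_output` is a SetupOutput instance returned by setup(...).
model = setup_output.model          # access by field name
optimizer = setup_output.optimizer

# Positional unpacking follows the declared field order:
(state, model, optimizer, scheduler,
 train_iter, valid_iter, test_iter,
 ckpt_context, pg_collection) = setup_output
```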

bridge.training.setup.setup(
state: megatron.bridge.training.state.GlobalState,
train_valid_test_datasets_provider: Callable[..., tuple[Optional[Any], Optional[Any], Optional[Any]]],
get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
restart_store: Optional[torch.distributed.Store] = None,
) → bridge.training.setup.SetupOutput#

Initialize the training/evaluation environment using an existing GlobalState.

Performs all runtime setup using the provided state and its attached config (state.cfg). This includes:

  • enabling Megatron-Core experimental features

  • initializing async checkpoint workers (if enabled)

  • logging setup

  • torch.distributed and model-parallel initialization (via initialize_megatron)

  • tokenizer/model/optimizer/scheduler construction

  • optional checkpoint load

  • dataloader setup

Parameters:
  • state – The GlobalState instance to populate and use throughout setup.

  • train_valid_test_datasets_provider – Callable returning the train/valid/test datasets or iterators.

  • get_embedding_ranks – Optional function to determine embedding layer ranks for model-parallel init.

  • get_position_embedding_ranks – Optional function to determine positional embedding ranks.

  • restart_store – Optional torch.distributed Store used when in-process restart is enabled.

Returns:

SetupOutput containing the populated state, model, optimizer, scheduler, data iterators, checkpointing context, and process group collection.
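A minimal calling sketch; the dataset-provider body and the way the GlobalState is obtained are illustrative assumptions, not part of this module:

```python
from megatron.bridge.training.setup import setup

def train_valid_test_datasets_provider(*args, **kwargs):
    # Hypothetical provider: returns (train, valid, test), any of which may be None.
    train_ds, valid_ds = build_train_ds(), build_valid_ds()  # assumed helpers
    return train_ds, valid_ds, None

state = ...  # an existing GlobalState with its config attached at state.cfg

setup_output = setup(
    state=state,
    train_valid_test_datasets_provider=train_valid_test_datasets_provider,
)
model, optimizer = setup_output.model, setup_output.optimizer
```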

bridge.training.setup._update_model_config_funcs(
model: megatron.core.transformer.MegatronModule,
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
*,
align_grad_reduce: bool = True,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
) → None#

Update model config synchronization functions based on the initialized model.

bridge.training.setup._create_peft_pre_wrap_hook(
cfg: megatron.bridge.training.config.ConfigContainer,
state: megatron.bridge.training.state.GlobalState,
) → Callable[[list[megatron.core.transformer.MegatronModule]], list[megatron.core.transformer.MegatronModule]]#

Create a pre-wrap hook that handles PEFT logic.

This hook is executed before the model is wrapped with DDP/FSDP and handles:

  1. Loading pretrained checkpoints for PEFT

  2. Applying PEFT transformation to the model

Parameters:
  • cfg – Configuration container

  • state – Global state object containing timers and other state

Returns:

A callable hook that can be registered with the model provider
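A sketch of how the returned hook could be invoked; the cfg, state, and module list below are assumed to exist already, and the actual registration with the model provider is not shown:

```python
# The hook maps a list of unwrapped MegatronModules to the (possibly
# PEFT-transformed) list, and is meant to run before DDP/FSDP wrapping.
pre_wrap_hook = _create_peft_pre_wrap_hook(cfg, state)

base_model_chunks = []  # placeholder for list[MegatronModule]
transformed_chunks = pre_wrap_hook(base_model_chunks)
```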

bridge.training.setup._apply_peft_transformation(
peft,
base_model: list[megatron.core.transformer.MegatronModule],
) → list[megatron.core.transformer.MegatronModule]#

Apply PEFT transformation to the base model.

Parameters:
  • peft – PEFT configuration/object

  • base_model – Base model before PEFT transformation

Returns:

Model with PEFT transformation applied

bridge.training.setup._validate_and_set_vocab_size(
model_vocab_size: Optional[int],
tokenizer_vocab_size: int,
) → tuple[int, bool]#

Validate and determine the correct vocab size for the model.

Parameters:
  • model_vocab_size – Vocab size set in model config (can be None)

  • tokenizer_vocab_size – Unpadded tokenizer vocab size

Returns:

A tuple of:
  • vocab_size – The validated unpadded vocab size to use for the model.

  • should_pad_vocab – True if the vocab should be padded, False otherwise.

Return type:

tuple[int, bool]

Raises:

ValueError – If model vocab size is invalid
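An illustrative call, with a made-up tokenizer vocab size; the returned values are interpreted per the description above:

```python
vocab_size, should_pad_vocab = _validate_and_set_vocab_size(
    model_vocab_size=None,       # vocab size not pinned in the model config
    tokenizer_vocab_size=32000,  # unpadded tokenizer vocab size (illustrative)
)
# vocab_size is the validated unpadded size to use for the model;
# should_pad_vocab indicates whether the vocab should be padded afterwards.
```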

bridge.training.setup.maybe_log_and_save_config(
cfg: megatron.bridge.training.config.ConfigContainer,
) → None#

Save configuration to disk and log non-default values on rank 0.

Instead of printing the full config YAML, this now logs only the values that differ from Megatron Core defaults, making it easier to spot unintended configuration deviations.

The full config can still be saved to a file via logger.save_config_filepath.
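A minimal sketch, assuming cfg is an already-built ConfigContainer; the exact form of logger.save_config_filepath used here is an assumption:

```python
# Assumes `cfg` is an existing megatron.bridge.training.config.ConfigContainer.
cfg.logger.save_config_filepath = "/results/run_config.yaml"  # illustrative path

# On rank 0, logs only values that differ from Megatron Core defaults and,
# if a filepath is set, saves the full config to that file.
maybe_log_and_save_config(cfg)
```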