bridge.training.pretrain

Module Contents

Functions

pretrain

Main function to run the training pipeline.

_pretrain

Internal function containing the actual pretrain logic.

API

bridge.training.pretrain.pretrain(
config: megatron.bridge.training.config.ConfigContainer,
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
) -> None

Main function to run the training pipeline.

Sets up the environment, model, optimizer, scheduler, and data iterators. Performs training, validation, and optionally testing based on the provided configuration.
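The order of phases described above (setup, then a training loop with periodic validation, then an optional test pass) can be pictured with the schematic below. It is a minimal sketch with everything stubbed out and is not the library's implementation: SketchConfig, its fields train_iters, eval_interval and do_test, and sketch_pretrain are all hypothetical names chosen for illustration, and the real configuration lives in ConfigContainer.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterator, List


@dataclass
class SketchConfig:
    """Hypothetical knobs for illustration; NOT the real ConfigContainer fields."""
    train_iters: int = 4
    eval_interval: int = 2
    do_test: bool = True


def sketch_pretrain(
    cfg: SketchConfig,
    forward_step: Callable[[Iterator[int], Any], float],
) -> List[str]:
    """Hypothetical outline of the pipeline phases; returns a log of events."""
    events: List[str] = []
    # 1. Setup: environment, model, optimizer, scheduler, data iterators (all stubbed).
    model: Any = None
    train_iter = iter(range(cfg.train_iters))
    events.append("setup")
    # 2. Training loop, validating every `eval_interval` iterations.
    for it in range(1, cfg.train_iters + 1):
        loss = forward_step(train_iter, model)
        events.append(f"train:{it}:{loss}")
        if it % cfg.eval_interval == 0:
            events.append(f"valid:{it}")
    # 3. Optional test pass, controlled by the configuration.
    if cfg.do_test:
        events.append("test")
    return events


if __name__ == "__main__":
    print(sketch_pretrain(SketchConfig(), lambda data_iterator, model: float(next(data_iterator))))
```

Here the stub forward step simply returns the next item from the iterator as its "loss"; a real forward step would run the model and compute gradients.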

Parameters:
  • config – The main configuration container holding all necessary parameters.

  • forward_step_func –

    A callable (function or functor) that performs a single forward and backward step, returning the loss and any computed metrics. Supports the following signatures:

    • 2 args: (data_iterator, model)

    • 3 args: (data_iterator, model, return_schedule_plan=False) OR (state: GlobalState, data_iterator, model)

    • 4 args: (state: GlobalState, data_iterator, model, return_schedule_plan=False)
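The four call shapes listed above can be told apart by inspecting the callable's parameters. The sketch below is an illustration under stated assumptions, not Megatron Bridge's dispatch code: GlobalState is a local stand-in class, and classify_forward_step and _asks_for_state are hypothetical helpers. Following the note on state injection, a first parameter "asks for state" if it is annotated as GlobalState or is named state.

```python
import inspect
from typing import Callable


class GlobalState:
    """Stand-in for megatron.bridge.training.state.GlobalState (illustration only)."""


def _asks_for_state(param: inspect.Parameter) -> bool:
    # Hypothetical rule: a GlobalState annotation (class or string form) or the name "state".
    return param.annotation in (GlobalState, "GlobalState") or param.name == "state"


def classify_forward_step(fn: Callable) -> str:
    """Hypothetical helper: name the documented call shape that `fn` matches."""
    params = list(inspect.signature(fn).parameters.values())
    with_state = bool(params) and _asks_for_state(params[0])
    if len(params) == 2 and not with_state:
        return "(data_iterator, model)"
    if len(params) == 3:
        if with_state:
            return "(state, data_iterator, model)"
        return "(data_iterator, model, return_schedule_plan)"
    if len(params) == 4 and with_state:
        return "(state, data_iterator, model, return_schedule_plan)"
    raise TypeError(f"unsupported forward_step signature: {inspect.signature(fn)}")


# One example per documented shape (bodies omitted).
def fwd_2(data_iterator, model): ...
def fwd_3_plan(data_iterator, model, return_schedule_plan=False): ...
def fwd_3_state(state: GlobalState, data_iterator, model): ...
def fwd_4(state: GlobalState, data_iterator, model, return_schedule_plan=False): ...


if __name__ == "__main__":
    for f in (fwd_2, fwd_3_plan, fwd_3_state, fwd_4):
        print(f.__name__, "->", classify_forward_step(f))
```

The two 3-argument shapes differ only in their first parameter, which is why the type hint or the parameter name matters for disambiguation.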

.. note::

Use a signature with a GlobalState type hint for full access to the configuration, timers, and training state. State is injected automatically, based on type hints or parameter names. Functors (classes that define __call__) are fully supported.
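A functor works because inspecting a callable instance reports the parameters of its __call__ method with self already bound, so it looks the same as a plain function. The sketch below illustrates this together with a simplified form of state injection. GlobalState is a local stand-in, and ScaledForwardStep and call_with_state are hypothetical names; the real injection logic in Megatron Bridge may differ.

```python
import inspect
from typing import Any, Callable, Iterator


class GlobalState:
    """Stand-in for megatron.bridge.training.state.GlobalState (illustration only)."""

    def __init__(self, step: int = 0) -> None:
        self.step = step


class ScaledForwardStep:
    """Hypothetical functor: instances are callable via __call__."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def __call__(self, state: GlobalState, data_iterator: Iterator[float], model: Any) -> float:
        # `self` is bound here, so it does not appear in the visible signature.
        return self.scale * next(data_iterator) + state.step


def call_with_state(fn: Callable, state: GlobalState, data_iterator: Iterator[float], model: Any) -> Any:
    """Hypothetical, simplified injection: pass `state` first only if fn asks for it."""
    params = list(inspect.signature(fn).parameters.values())
    if params and (params[0].annotation is GlobalState or params[0].name == "state"):
        return fn(state, data_iterator, model)
    return fn(data_iterator, model)


if __name__ == "__main__":
    functor = ScaledForwardStep(2.0)
    print(list(inspect.signature(functor).parameters))  # state, data_iterator, model
    print(call_with_state(functor, GlobalState(step=10), iter([1.5]), None))
    print(call_with_state(lambda data_iterator, model: next(data_iterator), GlobalState(), iter([7.0]), None))
```

Keeping per-run settings (such as scale above) on the functor instance avoids global variables while still matching one of the supported call shapes.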

.. warning::

This is an experimental API and is subject to change in backwards-incompatible ways without notice.

bridge.training.pretrain._pretrain(
state: megatron.bridge.training.state.GlobalState,
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
store: Optional[torch.distributed.Store] = None,
inprocess_call_wrapper: Optional[nvidia_resiliency_ext.inprocess.CallWrapper] = None,
) -> None

Internal function containing the actual pretrain logic.

Parameters:
  • state – Global training state containing the validated configuration and runtime objects.

  • forward_step_func – Function or functor that performs a single forward/backward step.

  • store – Optional distributed Store used by in-process restart for coordination.

  • inprocess_call_wrapper – Optional wrapper injected by nvrx (NVIDIA Resiliency Extension) to expose the current restart iteration.