# bridge.training.pretrain

## Module Contents

### API
- `bridge.training.pretrain.pretrain(config: megatron.bridge.training.config.ConfigContainer, forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable)`
Main function to run the training pipeline.
Sets up the environment, model, optimizer, scheduler, and data iterators. Performs training, validation, and optionally testing based on the provided configuration.
- Parameters:
config – The main configuration container holding all necessary parameters.
  forward_step_func – A callable (function or functor) that performs a single forward and backward step, returning the loss and any computed metrics. Supports the following signatures:
    - 2 args: `(data_iterator, model)`
    - 3 args: `(data_iterator, model, return_schedule_plan=False)` or `(state: GlobalState, data_iterator, model)`
    - 4 args: `(state: GlobalState, data_iterator, model, return_schedule_plan=False)`
  > **Note:** Use the signature with the `GlobalState` type hint for full access to configuration, timers, and training state. State injection is automatic, based on type hints or parameter names. Functors (classes with `__call__`) are fully supported.
  > **Warning:** This is an experimental API and is subject to change in backwards-incompatible ways without notice.
- `bridge.training.pretrain._pretrain(state: megatron.bridge.training.state.GlobalState, forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable, store: Optional[torch.distributed.Store] = None, inprocess_call_wrapper: Optional[nvidia_resiliency_ext.inprocess.CallWrapper] = None)`
Internal function containing the actual pretrain logic.
- Parameters:
  state – Global training state containing the validated configuration and runtime objects.
  forward_step_func – Function or functor that performs a single forward/backward step.
  store – Optional distributed `Store` used by in-process restart for coordination.
  inprocess_call_wrapper – Optional wrapper injected by nvidia-resiliency-ext (nvrx) to expose the restart iteration.