bridge.training.pretrain

Module Contents

Functions

pretrain

Main function to run the training pipeline.

_pretrain

Internal function containing the actual pretrain logic.

API

bridge.training.pretrain.pretrain(
config: megatron.bridge.training.config.ConfigContainer,
forward_step_func: Callable,
) → None

Main function to run the training pipeline.

Sets up the environment, model, optimizer, scheduler, and data iterators. Performs training, validation, and optionally testing based on the provided configuration.

Parameters:
  • config – The main configuration container holding all necessary parameters.

  • forward_step_func – A callable that performs a single forward and backward step, returning the loss and any computed metrics.

Warning: This is an experimental API and may change in backward-incompatible ways without notice.
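The contract described above (a driver that takes a configuration object and a step callable, then calls that callable once per iteration) can be illustrated with a minimal, dependency-free sketch. Everything here is a hypothetical stand-in: `ToyConfig`, `toy_forward_step`, and `toy_pretrain` are not part of the library, the step function's arguments and return shape are assumptions because this section does not document them, and the toy driver returns its losses for inspection whereas the real `pretrain` returns `None`.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Iterator, List, Tuple

Batch = Tuple[float, float]
StepResult = Tuple[float, float, Dict[str, float]]


@dataclass
class ToyConfig:
    """Hypothetical stand-in for ConfigContainer; the real class holds far more."""

    train_iters: int = 3
    learning_rate: float = 0.1


def toy_forward_step(weight: float, batch: Batch) -> StepResult:
    """Return (loss, gradient, metrics) for the 1-D model y = weight * x.

    The (weight, batch) arguments are an assumption made for this sketch.
    """
    x, y = batch
    error = weight * x - y
    loss = error**2
    grad = 2.0 * error * x
    return loss, grad, {"abs_error": abs(error)}


def toy_pretrain(
    config: ToyConfig,
    forward_step_func: Callable[[float, Batch], StepResult],
    data: Iterator[Batch],
) -> List[float]:
    """Toy driver: one step-function call and one SGD update per iteration."""
    weight = 0.0
    losses = []
    for _ in range(config.train_iters):
        loss, grad, _metrics = forward_step_func(weight, next(data))
        weight -= config.learning_rate * grad
        losses.append(loss)
    return losses


def constant_batches() -> Iterator[Batch]:
    """Yield the same (x, y) pair forever so the run is deterministic."""
    while True:
        yield (1.0, 2.0)


losses = toy_pretrain(ToyConfig(), toy_forward_step, constant_batches())
```

Passing the step function as an argument keeps the loop independent of the model-specific forward and loss logic, which is the same separation the `forward_step_func` parameter suggests.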

bridge.training.pretrain._pretrain(
state: megatron.bridge.training.state.GlobalState,
forward_step_func: Callable,
store: Optional[torch.distributed.Store] = None,
inprocess_call_wrapper: Optional[nvidia_resiliency_ext.inprocess.CallWrapper] = None,
) → None

Internal function containing the actual pretrain logic.

Parameters:
  • state – Global training state containing the validated configuration and runtime objects.

  • forward_step_func – Function that performs a single forward/backward step.

  • store – Optional distributed Store used by in-process restart for coordination.

  • inprocess_call_wrapper – Optional wrapper injected by nvidia_resiliency_ext (nvrx) to expose the restart iteration.
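The two optional parameters default to `None`, so an inner function with this shape can run both with and without in-process restart. The sketch below shows one way such defaults might be handled. It is not the library's implementation: `ToyCallWrapper`, its `iteration` attribute, and `toy_inner_pretrain` are hypothetical stand-ins, and the function returns a summary dict purely so the behaviour can be inspected (the real `_pretrain` returns `None`).

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class ToyCallWrapper:
    """Hypothetical stand-in for a restart wrapper; `iteration` is an assumption."""

    iteration: int


def toy_inner_pretrain(
    state: Dict[str, int],
    forward_step_func: Callable[[Dict[str, int]], int],
    store: Optional[Any] = None,
    inprocess_call_wrapper: Optional[ToyCallWrapper] = None,
) -> Dict[str, Any]:
    """Report how the optional arguments were interpreted for this call."""
    # Without a wrapper there has been no restart, so treat this as iteration 0.
    if inprocess_call_wrapper is None:
        restart_iteration = 0
    else:
        restart_iteration = inprocess_call_wrapper.iteration
    return {
        "restart_iteration": restart_iteration,
        "coordinated": store is not None,  # True only when a store was supplied
        "step_result": forward_step_func(state),
    }


# Plain run: neither optional argument is provided.
plain = toy_inner_pretrain({"step": 0}, lambda s: s["step"] + 1)

# Restart-style run: a placeholder store and a wrapper reporting iteration 2.
restarted = toy_inner_pretrain(
    {"step": 5},
    lambda s: s["step"] + 1,
    store=object(),
    inprocess_call_wrapper=ToyCallWrapper(iteration=2),
)
```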