bridge.training.callbacks#
Training callbacks for Megatron-Bridge.
This module provides a lightweight callback system for injecting custom logic into the training loop without modifying framework code.
Two registration patterns are supported:
Class-based: Subclass Callback and override event methods:

```python
class MyCallback(Callback):
    def on_train_start(self, context):
        print("Training started!")

pretrain(config, forward_step_func, callbacks=[MyCallback()])
```

Functional: Register functions directly with CallbackManager:

```python
manager = CallbackManager()
manager.register("on_train_step_end", my_logging_fn)

pretrain(config, forward_step_func, callbacks=manager)
```
Both patterns can be mixed. Callbacks fire in registration order.
Module Contents#
Classes#
| Class | Description |
|---|---|
| CallbackContext | Context passed to callbacks. |
| Callback | Base class for organizing callbacks. |
| CallbackManager | Manages registration and execution of training callbacks. |
Functions#
| Function | Description |
|---|---|
| normalize_callbacks | Normalize callbacks argument to a CallbackManager. |
| should_fire | Check if callbacks should be fired for an event. |
Data#
API#
- bridge.training.callbacks.logger: logging.Logger#
'getLogger(…)'
- bridge.training.callbacks.VALID_EVENTS: frozenset[str]#
'frozenset(…)'
- class bridge.training.callbacks.CallbackContext#
Context passed to callbacks.
Contains framework state and a persistent user_state dict. Modifying framework objects is at the user’s own risk.
.. attribute:: state
Global training state (config, train_state, timers, loggers).
.. attribute:: model
List of model chunks.
.. attribute:: user_state
Mutable dict for storing user data across callback invocations.
.. attribute:: optimizer
Optimizer instance. Available during training events only.
.. attribute:: scheduler
Learning rate scheduler. Available during training events only.
.. attribute:: loss_dict
Reduced losses from training step. Available in on_train_step_end.
.. attribute:: grad_norm
Gradient norm. Available in on_train_step_end if computed.
.. attribute:: skipped_iter
Whether the iteration was skipped. Available in on_train_step_end.
.. attribute:: total_loss_dict
Aggregated eval losses. Available in on_eval_end.
Field Availability by Event:
- All events: state, model, user_state
- Training events: optimizer, scheduler
- on_train_step_end: loss_dict, grad_norm, skipped_iter
- on_eval_end, on_test_end: total_loss_dict
- state: megatron.bridge.training.state.GlobalState#
None
- model: list[megatron.core.transformer.MegatronModule]#
None
- user_state: dict#
‘field(…)’
- optimizer: megatron.core.optimizer.MegatronOptimizer | None#
None
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler | None#
None
- loss_dict: dict[str, torch.Tensor] | None#
None
- grad_norm: float | None#
None
- skipped_iter: bool | None#
None
- total_loss_dict: dict[str, torch.Tensor] | None#
None
- class bridge.training.callbacks.Callback#
Base class for organizing callbacks.
Subclass and override methods for events you want to handle. All methods are no-ops by default.
.. rubric:: Example
```python
class MyCallback(Callback):
    def on_train_start(self, context):
        context.user_state['start_time'] = time.time()

    def on_train_end(self, context):
        elapsed = time.time() - context.user_state['start_time']
        print(f"Training took {elapsed:.2f}s")

pretrain(config, forward_step_func, callbacks=[MyCallback()])
```
- on_train_start(context: bridge.training.callbacks.CallbackContext) → None#
Called after model.train(), before training loop begins.
- on_train_step_start(context: bridge.training.callbacks.CallbackContext) → None#
Called at the start of each training step.
- on_train_step_end(context: bridge.training.callbacks.CallbackContext) → None#
Called after each training step completes.
- on_train_end(context: bridge.training.callbacks.CallbackContext) → None#
Called after training loop completes.
- on_eval_start(context: bridge.training.callbacks.CallbackContext) → None#
Called after model.eval(), before evaluation loop begins.
- on_eval_step_start(context: bridge.training.callbacks.CallbackContext) → None#
Called at the start of each evaluation step.
- on_eval_step_end(context: bridge.training.callbacks.CallbackContext) → None#
Called after each evaluation step completes.
- on_eval_end(context: bridge.training.callbacks.CallbackContext) → None#
Called after evaluation completes, before model.train().
- on_test_start(context: bridge.training.callbacks.CallbackContext) → None#
Called after model.eval(), before test loop begins.
- on_test_step_start(context: bridge.training.callbacks.CallbackContext) → None#
Called at the start of each test step.
- on_test_step_end(context: bridge.training.callbacks.CallbackContext) → None#
Called after each test step completes.
- on_test_end(context: bridge.training.callbacks.CallbackContext) → None#
Called after test completes, before model.train().
- class bridge.training.callbacks.CallbackManager#
Manages registration and execution of training callbacks.
Supports two registration patterns:
Class-based: Use add() with Callback subclass instances

```python
manager.add(MyCallback())
manager.add([CallbackA(), CallbackB()])
```

Functional: Use register() with event name and callable

```python
manager.register("on_train_start", my_function)
```
Both patterns can be mixed. Callbacks fire in registration order.
The manager also owns a user_state dictionary that persists across all callback invocations, allowing callbacks to share state.

.. rubric:: Example

```python
manager = CallbackManager()
manager.add(MyCallback())
manager.register("on_eval_end", lambda ctx: print("Eval done!"))

pretrain(config, forward_step_func, callbacks=manager)
```
Initialization
Initialize the callback manager with empty callback lists and user state.
- property user_state: dict#
Mutable dictionary for storing user data across callback invocations.
- add(callback: bridge.training.callbacks.Callback | list[bridge.training.callbacks.Callback]) → None#
Register one or more Callback instances.
Scans for methods that override the Callback base class and registers them to their corresponding events.
- Parameters:
callback – Single Callback instance or list of Callback instances.
.. rubric:: Example
```python
manager.add(MyCallback())
manager.add([TimingCallback(), LoggingCallback()])
```
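The override scan described above can be illustrated with a small sketch. This is not the actual Megatron-Bridge implementation, just one way such a scan could work: a hook is registered only when the subclass replaces the base class no-op (the TimingCallback name and the reduced event tuple are hypothetical):

```python
# Hypothetical subset of events, for illustration.
EVENTS = ("on_train_start", "on_train_step_end", "on_train_end")

class Callback:
    # Base class: all hooks are no-ops.
    def on_train_start(self, context): pass
    def on_train_step_end(self, context): pass
    def on_train_end(self, context): pass

def overridden_events(cb: Callback) -> list[str]:
    # A hook counts as overridden only when the subclass method object
    # differs from the one defined on the Callback base class.
    return [name for name in EVENTS
            if getattr(type(cb), name) is not getattr(Callback, name)]

class TimingCallback(Callback):
    def on_train_start(self, context):
        pass  # would record a start time in context.user_state

print(overridden_events(TimingCallback()))
```

This is why an unmodified Callback subclass adds no overhead: nothing is registered for any event.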
- register(event_name: str, fn: collections.abc.Callable[[bridge.training.callbacks.CallbackContext], None]) → None#
Register a callback function for a specific event.
- Parameters:
event_name –
Event to register for. Valid events:
- "on_train_start"
- "on_train_step_start"
- "on_train_step_end"
- "on_train_end"
- "on_eval_start"
- "on_eval_step_start"
- "on_eval_step_end"
- "on_eval_end"
- "on_test_start"
- "on_test_step_start"
- "on_test_step_end"
- "on_test_end"
fn – Callback function with signature (CallbackContext) -> None.
- Raises:
ValueError – If event_name is not valid.
.. rubric:: Example
```python
manager.register("on_train_step_end", my_logging_fn)
```
- property events: frozenset[str]#
Set of valid event names for registration.
- list_callbacks(event_name: str) → list#
Return list of callbacks registered for an event.
- Parameters:
event_name – Name of the event.
- Returns:
List of registered callables (in execution order).
- Raises:
ValueError – If event_name is not valid.
- has_callbacks(event_name: str) → bool#
Check if any callbacks are registered for an event.
- Parameters:
event_name – Name of the event.
- Returns:
True if at least one callback is registered for the event.
- fire(event_name: str, context: bridge.training.callbacks.CallbackContext) → None#
Execute all callbacks for an event.
Exceptions from callbacks propagate to the caller.
- Parameters:
event_name – Name of the event to fire.
context – CallbackContext to pass to callbacks.
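The fire contract (registration-order execution, exceptions propagating unhandled) can be sketched with a minimal stand-in manager. MiniManager is hypothetical and much simpler than the real CallbackManager; it only demonstrates the two documented behaviors:

```python
from collections import defaultdict

# Minimal stand-in for CallbackManager (illustration of the fire contract,
# not the actual Megatron-Bridge implementation).
class MiniManager:
    def __init__(self):
        self._callbacks = defaultdict(list)

    def register(self, event_name, fn):
        self._callbacks[event_name].append(fn)

    def fire(self, event_name, context):
        # Callbacks run in registration order; any exception a callback
        # raises propagates to the caller unhandled.
        for fn in self._callbacks[event_name]:
            fn(context)

order = []
m = MiniManager()
m.register("on_train_start", lambda ctx: order.append("first"))
m.register("on_train_start", lambda ctx: order.append("second"))
m.fire("on_train_start", None)
print(order)
```

Because exceptions are not swallowed, a failing callback aborts the training step that fired it; callbacks that should never interrupt training must catch their own errors.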
- bridge.training.callbacks.normalize_callbacks(callbacks: list[bridge.training.callbacks.Callback] | bridge.training.callbacks.CallbackManager | None) → bridge.training.callbacks.CallbackManager | None#
Normalize callbacks argument to a CallbackManager.
This helper is used internally by pretrain() to accept multiple input formats.
- Parameters:
callbacks – Either a list of Callback instances, a CallbackManager, or None.
- Returns:
A CallbackManager instance, or None if callbacks was None.
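The normalization rules above amount to three cases. A sketch of that contract, using stripped-down stand-ins for Callback and CallbackManager (not the actual implementation):

```python
# Stand-ins for illustration only.
class Callback: ...

class CallbackManager:
    def __init__(self):
        self.added = []
    def add(self, cb):
        self.added.extend(cb if isinstance(cb, list) else [cb])

def normalize_callbacks(callbacks):
    # None -> None (no callback machinery at all).
    if callbacks is None:
        return None
    # An existing manager passes through unchanged.
    if isinstance(callbacks, CallbackManager):
        return callbacks
    # A list of Callback instances is wrapped in a fresh manager.
    manager = CallbackManager()
    manager.add(callbacks)
    return manager

print(type(normalize_callbacks([Callback()])).__name__)
```

Passing a manager through unchanged is what allows functional registrations made on a user-built CallbackManager to survive the pretrain() entry point.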
- bridge.training.callbacks.should_fire(callback_manager: bridge.training.callbacks.CallbackManager | None, event_name: str) → bool#
Check if callbacks should be fired for an event.
Combines the None check and has_callbacks check into a single call.
- Parameters:
callback_manager – The callback manager instance, or None.
event_name – Name of the event to check.
- Returns:
True if callback_manager exists and has callbacks for the event.
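The combined check reduces to a one-liner, shown here with a hypothetical StubManager standing in for the real CallbackManager (the real function may differ in detail):

```python
# Stand-in with only the method should_fire needs (illustration only).
class StubManager:
    def __init__(self, events=()):
        self._events = set(events)
    def has_callbacks(self, event_name):
        return event_name in self._events

def should_fire(callback_manager, event_name):
    # One call combines the None check and has_callbacks, so the training
    # loop can skip building a CallbackContext when nothing is registered.
    return callback_manager is not None and callback_manager.has_callbacks(event_name)

print(should_fire(None, "on_train_start"))
print(should_fire(StubManager({"on_train_start"}), "on_train_start"))
```

The point of the helper is the cheap early exit: constructing a CallbackContext per step is wasted work when no callback is listening.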