bridge.training.callbacks#

Training callbacks for Megatron-Bridge.

This module provides a lightweight callback system for injecting custom logic into the training loop without modifying framework code.

Two registration patterns are supported:

  1. Class-based: Subclass Callback and override event methods

    class MyCallback(Callback):
        def on_train_start(self, context):
            print("Training started!")
    
    pretrain(config, forward_step_func, callbacks=[MyCallback()])
    
  2. Functional: Register functions directly with CallbackManager

    manager = CallbackManager()
    manager.register("on_train_step_end", my_logging_fn)
    pretrain(config, forward_step_func, callbacks=manager)
    

Both patterns can be mixed. Callbacks fire in registration order.

Module Contents#

Classes#

CallbackContext

Context passed to callbacks.

Callback

Base class for organizing callbacks.

CallbackManager

Manages registration and execution of training callbacks.

Functions#

normalize_callbacks

Normalize callbacks argument to a CallbackManager.

should_fire

Check if callbacks should be fired for an event.

Data#

API#

bridge.training.callbacks.logger: logging.Logger#

‘getLogger(…)’

bridge.training.callbacks.VALID_EVENTS: frozenset[str]#

‘frozenset(…)’

class bridge.training.callbacks.CallbackContext#

Context passed to callbacks.

Contains framework state and a persistent user_state dict. Modifying framework objects is at the user’s own risk.

.. attribute:: state

Global training state (config, train_state, timers, loggers).

.. attribute:: model

List of model chunks.

.. attribute:: user_state

Mutable dict for storing user data across callback invocations.

.. attribute:: optimizer

Optimizer instance. Available during training events only.

.. attribute:: scheduler

Learning rate scheduler. Available during training events only.

.. attribute:: loss_dict

Reduced losses from training step. Available in on_train_step_end.

.. attribute:: grad_norm

Gradient norm. Available in on_train_step_end if computed.

.. attribute:: skipped_iter

Whether the iteration was skipped. Available in on_train_step_end.

.. attribute:: total_loss_dict

Aggregated eval losses. Available in on_eval_end.

Field Availability by Event:

  • All events: state, model, user_state

  • Training events: optimizer, scheduler

  • on_train_step_end: loss_dict, grad_norm, skipped_iter

  • on_eval_end, on_test_end: total_loss_dict

state: megatron.bridge.training.state.GlobalState#

None

model: list[megatron.core.transformer.MegatronModule]#

None

user_state: dict#

‘field(…)’

optimizer: megatron.core.optimizer.MegatronOptimizer | None#

None

scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler | None#

None

loss_dict: dict[str, torch.Tensor] | None#

None

grad_norm: float | None#

None

skipped_iter: bool | None#

None

total_loss_dict: dict[str, torch.Tensor] | None#

None

class bridge.training.callbacks.Callback#

Base class for organizing callbacks.

Subclass and override methods for events you want to handle. All methods are no-ops by default.

.. rubric:: Example

class MyCallback(Callback):
    def on_train_start(self, context):
        context.user_state['start_time'] = time.time()

    def on_train_end(self, context):
        elapsed = time.time() - context.user_state['start_time']
        print(f"Training took {elapsed:.2f}s")

pretrain(config, forward_step_func, callbacks=[MyCallback()])

on_train_start(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called after model.train(), before training loop begins.

on_train_step_start(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called at the start of each training step.

on_train_step_end(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called after each training step completes.

on_train_end(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called after training loop completes.

on_eval_start(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called after model.eval(), before evaluation loop begins.

on_eval_step_start(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called at the start of each evaluation step.

on_eval_step_end(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called after each evaluation step completes.

on_eval_end(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called after evaluation completes, before model.train().

on_test_start(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called after model.eval(), before test loop begins.

on_test_step_start(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called at the start of each test step.

on_test_step_end(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called after each test step completes.

on_test_end(
context: bridge.training.callbacks.CallbackContext,
) → None#

Called after test completes, before model.train().

class bridge.training.callbacks.CallbackManager#

Manages registration and execution of training callbacks.

Supports two registration patterns:

  1. Class-based: Use add() with Callback subclass instances

    manager.add(MyCallback())
    manager.add([CallbackA(), CallbackB()])
    
  2. Functional: Use register() with event name and callable

    manager.register("on_train_start", my_function)
    

Both patterns can be mixed. Callbacks fire in registration order.

The manager also owns a user_state dictionary that persists across all callback invocations, allowing callbacks to share state.

.. rubric:: Example

manager = CallbackManager()
manager.add(MyCallback())
manager.register("on_eval_end", lambda ctx: print("Eval done!"))
pretrain(config, forward_step_func, callbacks=manager)
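Because the manager reuses one user_state dict for every invocation, independently registered callbacks can coordinate through it. A standalone sketch (Ctx is a hypothetical stand-in for CallbackContext; in practice the two functions would be registered for "on_train_start" and "on_train_end"):

```python
import time

class Ctx:
    # Hypothetical minimal stand-in: only user_state matters for this sketch.
    def __init__(self, user_state):
        self.user_state = user_state

def record_start(context):
    context.user_state["start_time"] = time.time()

def report_elapsed(context):
    context.user_state["elapsed"] = time.time() - context.user_state["start_time"]

shared = {}  # stands in for the manager's persistent user_state dict
record_start(Ctx(shared))    # would fire on "on_train_start"
report_elapsed(Ctx(shared))  # would fire on "on_train_end"
print("elapsed" in shared)   # True
```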

Initialization

Initialize the callback manager with empty callback lists and user state.

property user_state: dict#

Mutable dictionary for storing user data across callback invocations.

add(
callback: bridge.training.callbacks.Callback | list[bridge.training.callbacks.Callback],
) → None#

Register one or more Callback instances.

Scans the instance for methods that override those of the Callback base class and registers each for its corresponding event.

Parameters:

callback – Single Callback instance or list of Callback instances.

.. rubric:: Example

manager.add(MyCallback())
manager.add([TimingCallback(), LoggingCallback()])
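The override scan could work roughly as follows; this is an illustrative sketch with a simplified base class and event subset, not necessarily the actual implementation:

```python
class Callback:
    # Simplified base class: no-op event methods (a subset of the real events).
    def on_train_start(self, context): pass
    def on_train_step_end(self, context): pass
    def on_train_end(self, context): pass

EVENTS = ("on_train_start", "on_train_step_end", "on_train_end")

def overridden_events(cb: Callback) -> list[str]:
    # An event method counts as registered only if the subclass replaced
    # the base-class no-op with its own implementation.
    return [
        name for name in EVENTS
        if getattr(type(cb), name) is not getattr(Callback, name)
    ]

class MyCallback(Callback):
    def on_train_start(self, context):
        context.user_state["started"] = True

print(overridden_events(MyCallback()))  # ['on_train_start']
```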

register(
event_name: str,
fn: collections.abc.Callable[[bridge.training.callbacks.CallbackContext], None],
) → None#

Register a callback function for a specific event.

Parameters:
  • event_name

    Event to register for. Valid events:

    • "on_train_start"

    • "on_train_step_start"

    • "on_train_step_end"

    • "on_train_end"

    • "on_eval_start"

    • "on_eval_step_start"

    • "on_eval_step_end"

    • "on_eval_end"

    • "on_test_start"

    • "on_test_step_start"

    • "on_test_step_end"

    • "on_test_end"

  • fn – Callback function with signature (CallbackContext) -> None.

Raises:

ValueError – If event_name is not valid.

.. rubric:: Example

manager.register("on_train_step_end", my_logging_fn)

property events: frozenset[str]#

Set of valid event names for registration.

list_callbacks(
event_name: str,
) → list[collections.abc.Callable[[bridge.training.callbacks.CallbackContext], None]]#

Return list of callbacks registered for an event.

Parameters:

event_name – Name of the event.

Returns:

List of registered callables (in execution order).

Raises:

ValueError – If event_name is not valid.

has_callbacks(event_name: str) → bool#

Check if any callbacks are registered for an event.

Parameters:

event_name – Name of the event.

Returns:

True if at least one callback is registered for the event.

fire(
event_name: str,
context: bridge.training.callbacks.CallbackContext,
) → None#

Execute all callbacks for an event.

Exceptions from callbacks propagate to the caller.

Parameters:
  • event_name – Name of the event to fire.

  • context – CallbackContext to pass to callbacks.
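The dispatch semantics can be sketched standalone (MiniManager is hypothetical; it only illustrates in-order execution and exception propagation, not the real class):

```python
class MiniManager:
    # Hypothetical minimal dispatcher illustrating fire() semantics.
    def __init__(self):
        self._callbacks = {}  # event_name -> list of callables

    def register(self, event_name, fn):
        self._callbacks.setdefault(event_name, []).append(fn)

    def fire(self, event_name, context):
        # Callbacks run in registration order; an exception aborts the
        # remaining callbacks for this event and propagates to the caller.
        for fn in self._callbacks.get(event_name, []):
            fn(context)

calls = []
m = MiniManager()
m.register("on_train_step_end", lambda ctx: calls.append("first"))
m.register("on_train_step_end", lambda ctx: calls.append("second"))
m.fire("on_train_step_end", context=None)
print(calls)  # ['first', 'second']
```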

bridge.training.callbacks.normalize_callbacks(
callbacks: list[bridge.training.callbacks.Callback] | bridge.training.callbacks.CallbackManager | None,
) → bridge.training.callbacks.CallbackManager | None#

Normalize callbacks argument to a CallbackManager.

This helper is used internally by pretrain() to accept multiple input formats.

Parameters:

callbacks – Either a list of Callback instances, a CallbackManager, or None.

Returns:

A CallbackManager instance, or None if callbacks was None.
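The three accepted shapes map to three outcomes, which the following standalone sketch mirrors (MiniManager is a hypothetical stand-in for CallbackManager):

```python
class MiniManager:
    # Hypothetical stand-in: just enough of CallbackManager to show normalization.
    def __init__(self):
        self.added = []

    def add(self, callback):
        items = callback if isinstance(callback, list) else [callback]
        self.added.extend(items)

def normalize_callbacks(callbacks):
    # None passes through; an existing manager is returned as-is;
    # a list of callbacks is wrapped in a fresh manager via add().
    if callbacks is None:
        return None
    if isinstance(callbacks, MiniManager):
        return callbacks
    manager = MiniManager()
    manager.add(callbacks)
    return manager

print(normalize_callbacks(["cb1", "cb2"]).added)  # ['cb1', 'cb2']
```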

bridge.training.callbacks.should_fire(
callback_manager: bridge.training.callbacks.CallbackManager | None,
event_name: str,
) → bool#

Check if callbacks should be fired for an event.

Combines the None check and has_callbacks check into a single call.

Parameters:
  • callback_manager – The callback manager instance, or None.

  • event_name – Name of the event to check.

Returns:

True if callback_manager exists and has callbacks for the event.
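The combined check reduces to a short-circuiting one-liner; a standalone sketch against a hypothetical manager stand-in:

```python
def should_fire(callback_manager, event_name):
    # Fires only when a manager exists and has at least one callback
    # registered for the event; short-circuits on None.
    return callback_manager is not None and callback_manager.has_callbacks(event_name)

class MiniManager:
    # Hypothetical stand-in exposing only has_callbacks().
    def __init__(self, events_with_callbacks):
        self._events = set(events_with_callbacks)

    def has_callbacks(self, event_name):
        return event_name in self._events

m = MiniManager({"on_train_step_end"})
print(should_fire(None, "on_train_start"))    # False
print(should_fire(m, "on_train_step_end"))    # True
```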