Callbacks
Megatron Bridge provides a lightweight callback system for injecting custom logic into the training and evaluation loop without modifying framework code. This is ideal for proprietary integrations, custom logging, and metrics tracking.
Quick Start
Class-Based Callbacks
Subclass `bridge.training.callbacks.Callback` and override event methods:
```python
import time

from megatron.bridge.training.callbacks import Callback
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.recipes.qwen import qwen25_500m_pretrain_config


class MyCallback(Callback):
    def on_train_start(self, context):
        # Stash the start time in user_state so later events can read it.
        context.user_state['start_time'] = time.time()
        print(f"Training started at step {context.state.train_state.step}")

    def on_train_step_end(self, context):
        # loss_dict may be empty, so guard before printing.
        if context.loss_dict:
            print(f"Step {context.state.train_state.step}: loss={context.loss_dict}")

    def on_train_end(self, context):
        elapsed = time.time() - context.user_state['start_time']
        print(f"Training completed in {elapsed:.2f}s")


# Create a config that fits on a single GPU
config = qwen25_500m_pretrain_config()

# Pass callbacks to pretrain
pretrain(config, forward_step, callbacks=[MyCallback()])
```
Functional Callbacks
Register functions directly with `bridge.training.callbacks.CallbackManager`:
```python
from megatron.bridge.training.callbacks import CallbackManager
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.recipes.qwen import qwen25_500m_pretrain_config


def log_step(context):
    # A functional callback takes a single CallbackContext argument.
    step = context.state.train_state.step
    if context.loss_dict:
        print(f"Step {step}: {context.loss_dict}")


callback_manager = CallbackManager()
callback_manager.register("on_train_step_end", log_step)

# Create a config that fits on a single GPU
config = qwen25_500m_pretrain_config()
pretrain(config, forward_step, callbacks=callback_manager)
```
Mixing Both Patterns
Both registration patterns can be combined on one `CallbackManager`:
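```python
from megatron.bridge.training.callbacks import Callback, CallbackManager
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.recipes.qwen import qwen25_500m_pretrain_config

# MyCallback is the class from the class-based example; the two classes
# below are empty placeholders standing in for your own Callback subclasses.
class TimingCallback(Callback):
    pass

class MetricsCallback(Callback):
    pass

manager = CallbackManager()
manager.add(MyCallback())                           # a single Callback instance
manager.add([TimingCallback(), MetricsCallback()])  # or a list of instances
manager.register("on_eval_end", lambda ctx: print("Evaluation complete!"))

# Create a config that fits on a single GPU
config = qwen25_500m_pretrain_config()
pretrain(config, forward_step, callbacks=manager)
```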
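Available Events
Training Events

| Event | When Fired | Available Context Fields |
|---|---|---|
| `on_train_start` | After … | `state`, `model`, `user_state`, `optimizer`, `scheduler` |
| `on_train_step_start` | Before each training step | `state`, `model`, `user_state`, `optimizer`, `scheduler` |
| `on_train_step_end` | After each training step | `state`, `model`, `user_state`, `optimizer`, `scheduler`, `loss_dict`, `grad_norm`, `skipped_iter` |
| `on_train_end` | After the training loop completes | `state`, `model`, `user_state`, `optimizer`, `scheduler` |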
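The paired step events can bracket each iteration to measure its duration. Below is a minimal sketch written as two functional callbacks, assuming the before-step event is named `on_train_step_start`; the function names and `user_state` keys are illustrative, not part of Megatron Bridge.

```python
import time


def step_timer_start(context):
    # Hypothetical handler for on_train_step_start: record a monotonic timestamp.
    context.user_state["step_t0"] = time.perf_counter()


def step_timer_end(context):
    # Hypothetical handler for on_train_step_end: compute elapsed time for this step.
    t0 = context.user_state.pop("step_t0", None)
    if t0 is None:
        return  # the start handler did not run for this step
    elapsed = time.perf_counter() - t0
    context.user_state.setdefault("step_times", []).append(elapsed)


# Usage (assuming `manager` is a CallbackManager):
# manager.register("on_train_step_start", step_timer_start)
# manager.register("on_train_step_end", step_timer_end)
```

CUDA kernels launch asynchronously, so unless something inside the step synchronizes with the GPU, these timings may reflect host-side time rather than completed GPU work.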
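Validation Events

| Event | When Fired | Available Context Fields |
|---|---|---|
| `on_eval_start` | After … | `state`, `model`, `user_state` |
| `on_eval_step_start` | Before each validation step | `state`, `model`, `user_state` |
| `on_eval_step_end` | After each validation step | `state`, `model`, `user_state` |
| `on_eval_end` | After validation completes | `state`, `model`, `user_state`, `total_loss_dict` |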
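`total_loss_dict` is the field to read when tracking validation progress. Here is a minimal sketch of an `on_eval_end` function that remembers the best validation loss. It assumes the loss is stored under the `"lm loss"` key (the key used in the Proprietary Metrics example) and that values are Python numbers or single-element tensors, so `float()` accepts them. `track_best_val_loss` and its `user_state` keys are hypothetical names.

```python
def track_best_val_loss(context, key="lm loss"):
    """Hypothetical on_eval_end handler: remember the lowest validation loss so far."""
    losses = context.total_loss_dict or {}
    if key not in losses:
        return  # assumed key is absent; nothing to track
    # float() accepts Python numbers and single-element tensors alike.
    val_loss = float(losses[key])
    best = context.user_state.get("best_val_loss")
    improved = best is None or val_loss < best
    if improved:
        context.user_state["best_val_loss"] = val_loss
        context.user_state["best_val_step"] = context.state.train_state.step
    # Other callbacks can read this flag to react to an improvement.
    context.user_state["val_improved"] = improved
```

Register it with `manager.register("on_eval_end", track_best_val_loss)`; the `key` default lets the manager call it with the context alone.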
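Test Events

| Event | When Fired | Available Context Fields |
|---|---|---|
| `on_test_start` | After … | `state`, `model`, `user_state` |
| `on_test_step_start` | Before each test step | `state`, `model`, `user_state` |
| `on_test_step_end` | After each test step | `state`, `model`, `user_state` |
| `on_test_end` | After the test run completes | `state`, `model`, `user_state`, `total_loss_dict` |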
CallbackContext
The `bridge.training.callbacks.CallbackContext` provides access to framework state:
Always Available
- `state`: `bridge.training.state.GlobalState`, which contains the config, `train_state`, timers, and loggers
- `model`: List of model chunks
- `user_state`: Mutable dict for storing data across callback invocations

Training Events Only
- `optimizer`: The optimizer instance
- `scheduler`: Learning rate scheduler

Event-Specific Fields
- `loss_dict` (`on_train_step_end`): Dictionary of reduced losses from the training step
- `grad_norm` (`on_train_step_end`): Gradient norm, if computed
- `skipped_iter` (`on_train_step_end`): Whether the iteration was skipped
- `total_loss_dict` (`on_eval_end`, `on_test_end`): Aggregated evaluation/test losses
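The event-specific fields of `on_train_step_end` can be read together in one function. This minimal sketch counts skipped iterations and tracks the largest gradient norm. It assumes `grad_norm` is `None` when no norm was computed and that `skipped_iter` is truthy for a skipped step. The name `grad_health` and its `user_state` key are illustrative, not part of Megatron Bridge.

```python
def grad_health(context):
    """Hypothetical on_train_step_end handler: count skips and track max grad norm."""
    stats = context.user_state.setdefault(
        "grad_health", {"skipped": 0, "max_grad_norm": None}
    )
    if context.skipped_iter:
        # A skipped iteration did not apply an update; count it and move on.
        stats["skipped"] += 1
        return
    if context.grad_norm is not None:
        # float() handles Python numbers and single-element tensors.
        norm = float(context.grad_norm)
        if stats["max_grad_norm"] is None or norm > stats["max_grad_norm"]:
            stats["max_grad_norm"] = norm
```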
User State
The CallbackManager owns a user_state dictionary that persists across all callback invocations during a training run. Use it to share data between callbacks or accumulate metrics:
```python
from megatron.bridge.training.callbacks import Callback


class StepCounterCallback(Callback):
    def on_train_start(self, context):
        # Initialize the counter once, before the first step.
        context.user_state['callback_step_count'] = 0

    def on_train_step_end(self, context):
        context.user_state['callback_step_count'] += 1

    def on_train_end(self, context):
        print(f"Callback saw {context.user_state['callback_step_count']} steps")
```
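`user_state` also works for functional callbacks, which is convenient when one function produces a value and another consumes it. This is a minimal sketch, assuming the per-step loss is stored under `"lm loss"` and is a number or single-element tensor. Both function names and the `loss_acc` key are hypothetical. Register `accumulate_loss` on `on_train_step_end` and `report_mean_loss` on `on_train_end` with the same `CallbackManager` so both see the same dictionary.

```python
def accumulate_loss(context):
    """Hypothetical on_train_step_end handler: add this step's loss to a running sum."""
    if not context.loss_dict or "lm loss" not in context.loss_dict:
        return
    acc = context.user_state.setdefault("loss_acc", {"total": 0.0, "count": 0})
    acc["total"] += float(context.loss_dict["lm loss"])
    acc["count"] += 1


def report_mean_loss(context):
    """Hypothetical on_train_end handler: read the sum written by accumulate_loss."""
    acc = context.user_state.get("loss_acc")
    if acc and acc["count"]:
        mean = acc["total"] / acc["count"]
        context.user_state["mean_train_loss"] = mean
        print(f"Mean training loss over {acc['count']} logged steps: {mean:.4f}")
```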
Distributed Training
Callbacks fire on all ranks without framework-level synchronization. If your callback should only run on specific ranks, add guards:
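```python
import torch.distributed as dist

from megatron.bridge.training.callbacks import Callback


class RankZeroCallback(Callback):
    def on_train_step_end(self, context):
        # Only global rank 0 prints; every other rank skips the body.
        if dist.get_rank() == 0:
            print(f"Step {context.state.train_state.step} complete")
```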
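For functional callbacks, the rank guard can live in a reusable decorator instead of in each function body. This is a sketch under these assumptions: `rank_zero_only` is a hypothetical helper, it prefers `torch.distributed.get_rank()` when a process group is initialized, and otherwise falls back to the `RANK` environment variable that launchers such as `torchrun` set (defaulting to 0 when it is absent).

```python
import functools
import os


def _current_rank() -> int:
    """Return the global rank, preferring torch.distributed when it is initialized."""
    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
    except ImportError:
        pass
    # Assumption: the launcher exported RANK; treat a missing value as rank 0.
    return int(os.environ.get("RANK", "0"))


def rank_zero_only(fn):
    """Hypothetical decorator: run a functional callback only on global rank 0."""
    @functools.wraps(fn)
    def wrapper(context):
        if _current_rank() == 0:
            return fn(context)
        return None
    return wrapper
```

Wrap the function before registering it, for example `manager.register("on_train_step_end", rank_zero_only(log_step))`. Check which ranks actually hold the data you need: with model parallelism, fields such as `loss_dict` may be empty on some ranks, and rank 0 is not necessarily one of the ranks that has them.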
Exception Handling
Exceptions from callbacks propagate to the caller. The framework does not catch or handle callback exceptions. If your callback might fail, wrap it in a try-except:
```python
def safe_callback(context):
    try:
        # Your logic here; external_service stands in for your own client.
        external_service.log(context.loss_dict)
    except Exception as e:
        # Log and continue; re-raising would stop training.
        print(f"Callback failed: {e}")
```
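The try/except pattern can be factored into a decorator so each callback does not repeat it. This is a minimal sketch using the standard `logging` module. `swallow_errors` is a hypothetical helper, not a Megatron Bridge API. Catching `Exception` (rather than `BaseException`) still lets `KeyboardInterrupt` and `SystemExit` stop the run.

```python
import functools
import logging

logger = logging.getLogger(__name__)


def swallow_errors(fn):
    """Hypothetical decorator: log a callback's exceptions instead of propagating them."""
    @functools.wraps(fn)
    def wrapper(context):
        try:
            return fn(context)
        except Exception:
            # logger.exception records the full traceback at ERROR level.
            logger.exception("Callback %s failed; continuing training", fn.__name__)
            return None
    return wrapper
```

Apply it at registration time, for example `manager.register("on_train_step_end", swallow_errors(my_fn))`, where `my_fn` is your own callback.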
Execution Order
Callbacks fire in registration order:
- Callbacks added via `add()` fire in the order they were added.
- Callbacks registered via `register()` fire in the order they were registered.
- If both methods are used, the order depends on when each was called.
Introspection
Query registered callbacks:
```python
from megatron.bridge.training.callbacks import CallbackManager

def my_fn(context):
    print("on_train_start fired")

manager = CallbackManager()
manager.register("on_train_start", my_fn)

# Check if any callbacks exist for an event
if manager.has_callbacks("on_train_start"):
    print("Callbacks registered for on_train_start")

# List all callbacks for an event
callbacks = manager.list_callbacks("on_train_start")
print(f"Found {len(callbacks)} callbacks")

# Get all valid event names
print(manager.events)  # frozenset of valid event names
```
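`manager.events` can also catch misspelled event names before training starts. This sketch assumes only the documented `events` and `register` members. Whether `register` already rejects unknown names on its own is not specified here, so this helper mainly adds a spelling suggestion via the standard library's `difflib`. `register_checked` is a hypothetical name.

```python
import difflib


def register_checked(manager, event, fn):
    """Hypothetical helper: validate `event` against manager.events, then register."""
    if event not in manager.events:
        # Suggest the closest valid event name, if any is similar enough.
        hint = difflib.get_close_matches(event, sorted(manager.events), n=1)
        suggestion = f" Did you mean {hint[0]!r}?" if hint else ""
        raise ValueError(f"Unknown callback event {event!r}.{suggestion}")
    manager.register(event, fn)
```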
Design Principles
The callback system follows these principles:
- First-Party Isolation: Framework code never uses callbacks for its own logic. Callbacks are strictly for third-party extensions.
- Zero Overhead: When no callbacks are registered, there is zero performance overhead.
- Safety: Callbacks receive framework state, but modifying it is at the user’s own risk. The framework makes no guarantees about the effects of modifications.
Examples
Proprietary Metrics
```python
import os

from megatron.bridge.training.callbacks import Callback


class ProprietaryMetricsCallback(Callback):
    """Send metrics to internal monitoring system."""

    def __init__(self, endpoint: str):
        # InternalMetricsClient is a placeholder for your organization's client.
        self.client = InternalMetricsClient(endpoint)

    def on_train_step_end(self, context):
        if context.loss_dict:
            self.client.send({
                "step": context.state.train_state.step,
                "loss": context.loss_dict.get("lm loss"),
                "grad_norm": context.grad_norm,
                "cluster_id": os.environ.get("CLUSTER_ID"),
            })
```
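If `send` is a blocking network call, issuing one per step adds that latency to every iteration. One alternative buffers records and ships them in batches. This sketch rests on stated assumptions: `send_batch` is a hypothetical bulk-upload function you supply, and the import fallback to `object` exists only so the sketch runs without Megatron Bridge installed.

```python
import os
from typing import Any, Callable, Dict, List

try:
    from megatron.bridge.training.callbacks import Callback
except ImportError:  # fallback only so this sketch runs standalone
    Callback = object


class BufferedMetricsCallback(Callback):
    """Hypothetical variant: buffer per-step metrics and flush every `flush_every` steps."""

    def __init__(self, send_batch: Callable[[List[Dict[str, Any]]], None], flush_every: int = 50):
        self.send_batch = send_batch  # assumed bulk-upload function
        self.flush_every = flush_every
        self.buffer: List[Dict[str, Any]] = []

    def on_train_step_end(self, context):
        if not context.loss_dict:
            return
        loss = context.loss_dict.get("lm loss")
        self.buffer.append({
            "step": context.state.train_state.step,
            "loss": None if loss is None else float(loss),
            "grad_norm": None if context.grad_norm is None else float(context.grad_norm),
            "cluster_id": os.environ.get("CLUSTER_ID"),
        })
        if len(self.buffer) >= self.flush_every:
            self._flush()

    def on_train_end(self, context):
        # Ship whatever is left after a normal finish.
        self._flush()

    def _flush(self):
        if self.buffer:
            self.send_batch(self.buffer)
            self.buffer = []  # start a fresh list; the sent batch is left untouched
```

The trade-off is that up to `flush_every - 1` buffered records are lost if the process dies before the next flush.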