bridge.training.forward_step_func_types#
Type definitions for forward step functions.
This module provides comprehensive type definitions for forward step functions used in Megatron Bridge training. Forward step functions are the core of the training loop, responsible for performing a single forward pass and returning both the output tensor and a loss function.
Key Types:

- ForwardStepCallable: Union of all supported forward step signatures (functions and functors)
- LossFunction: The partial function returned by forward step functions
- LossFunctionReturn: The possible return types when calling a loss function
Example Usage:

```python
>>> from functools import partial
>>> from megatron.bridge.training.state import GlobalState
>>>
>>> def my_forward_step(state: GlobalState, data_iterator, model, return_schedule_plan=False):
...     # Access configuration, timers, and training state
...     timers = state.timers
...     config = state.cfg
...
...     # Get batch data
...     batch = next(data_iterator)
...
...     # Forward pass with timing
...     timers("forward-step").start()
...     output_tensor = model(batch["input_ids"])
...     timers("forward-step").stop()
...
...     # Create loss function
...     def loss_func(output_tensor):
...         loss = compute_loss(output_tensor, batch["labels"])
...         num_tokens = batch["labels"].numel()
...         loss_reduced = {"lm_loss": loss.detach()}
...         return loss, num_tokens, loss_reduced  # ThreeTupleLossReturn
...
...     return output_tensor, partial(loss_func)
...
>>> # State injection is automatic - no manual binding needed!
>>> pretrain(config, my_forward_step)
>>>
>>> # Functor example (for stateful forward steps)
>>> class StatefulForwardStep:
...     def __init__(self, loss_scale: float = 1.0):
...         self.loss_scale = loss_scale
...         self.step_count = 0
...
...     def __call__(self, state: GlobalState, data_iterator, model, return_schedule_plan=False):
...         self.step_count += 1
...         # ... forward step logic with state tracking ...
...         return output_tensor, partial(loss_func)
...
>>> functor = StatefulForwardStep(loss_scale=2.0)
>>> pretrain(config, functor)
```
Module Contents#
Classes#
- TwoArgForwardStep: Protocol for forward step functions with 2 arguments.
- ThreeArgStateForwardStep: Protocol for forward step functions with 3 arguments including state.
- ThreeArgForwardStep: Protocol for forward step functions with 3 arguments.
- FourArgForwardStep: Protocol for forward step functions with 4 arguments.
- ForwardStepFunctor: Protocol for forward step functors (callable classes).
Data#
API#
- bridge.training.forward_step_func_types.LossReduced#
None
- bridge.training.forward_step_func_types.TwoTupleLossReturn#
None
- bridge.training.forward_step_func_types.ThreeTupleLossReturn#
None
- bridge.training.forward_step_func_types.InferenceLossReturn#
None
- bridge.training.forward_step_func_types.LossFunctionReturn#
None
- bridge.training.forward_step_func_types.LossFunction#
None
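The tuple shapes behind these aliases can be sketched with plain Python values; `compute_loss` and the list-based "tensors" below are illustrative stand-ins for the real torch-based code:

```python
from functools import partial

# Hypothetical stand-in for a real torch loss computation.
def compute_loss(output_tensor, labels):
    return sum((o - t) ** 2 for o, t in zip(output_tensor, labels))

def loss_func(labels, output_tensor):
    loss = compute_loss(output_tensor, labels)
    num_tokens = len(labels)
    loss_reduced = {"lm_loss": loss}       # LossReduced: dict of reduced metrics
    return loss, num_tokens, loss_reduced  # ThreeTupleLossReturn shape

# A LossFunction is a partial that awaits only the output tensor.
bound = partial(loss_func, [1.0, 2.0])
result = bound([1.5, 2.5])  # (0.5, 2, {"lm_loss": 0.5})
```

The two-tuple variant presumably omits the token count; the real aliases are defined in this module.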
- class bridge.training.forward_step_func_types.TwoArgForwardStep#
Bases: typing.Protocol

Protocol for forward step functions with 2 arguments.
This represents forward step functions that don’t need access to GlobalState and don’t support schedule plan return mode.
- Parameters:
data_iterator – Iterator providing training data batches
model – The GPT model to train
- Returns:
Tuple of (output_tensor, loss_function)
- __call__(data_iterator: Iterable, model: megatron.core.models.gpt.GPTModel)
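A minimal sketch of a conforming two-argument forward step; the toy iterator and identity "model" below stand in for a real data iterator and GPTModel:

```python
def simple_forward_step(data_iterator, model):
    # Pull one batch and run the forward pass.
    batch = next(data_iterator)
    output_tensor = model(batch)

    def loss_func(output_tensor):
        loss = sum(output_tensor)
        return loss, {"lm_loss": loss}

    return output_tensor, loss_func

# Toy usage with an identity "model" over list-based batches.
output, loss_fn = simple_forward_step(iter([[1.0, 2.0, 3.0]]), lambda b: b)
loss, reduced = loss_fn(output)
```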
- class bridge.training.forward_step_func_types.ThreeArgStateForwardStep#
Bases: typing.Protocol

Protocol for forward step functions with 3 arguments including state.
This represents forward step functions that need access to GlobalState but don’t support schedule plan return mode.
- Parameters:
state – Global training state containing configuration and runtime objects
data_iterator – Iterator providing training data batches
model – The GPT model to train
- Returns:
Tuple of (output_tensor, loss_function)
- __call__(state: megatron.bridge.training.state.GlobalState, data_iterator: Iterable, model: megatron.core.models.gpt.GPTModel)
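A sketch of the state-aware variant, with `types.SimpleNamespace` standing in for GlobalState; the `loss_scale` config field is invented for illustration:

```python
from types import SimpleNamespace

def state_forward_step(state, data_iterator, model):
    batch = next(data_iterator)
    # Read a (hypothetical) config value off the injected state.
    scaled = [x * state.cfg.loss_scale for x in batch]
    output_tensor = model(scaled)

    def loss_func(output_tensor):
        loss = sum(output_tensor)
        return loss, {"lm_loss": loss}

    return output_tensor, loss_func

# SimpleNamespace stands in for megatron.bridge.training.state.GlobalState.
state = SimpleNamespace(cfg=SimpleNamespace(loss_scale=2.0))
output, loss_fn = state_forward_step(state, iter([[1.0, 2.0]]), lambda b: b)
```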
- class bridge.training.forward_step_func_types.ThreeArgForwardStep#
Bases: typing.Protocol

Protocol for forward step functions with 3 arguments.
This represents forward step functions that don’t need access to GlobalState but support schedule plan return mode. These are typically 4-arg functions that have had GlobalState pre-bound via functools.partial.
- Parameters:
data_iterator – Iterator providing training data batches
model – The GPT model to train
return_schedule_plan – Whether to return schedule plan instead of output tensor
- Returns:
Tuple of (output_tensor, loss_function) or (schedule_plan, loss_function)
- __call__(data_iterator: Iterable, model: megatron.core.models.gpt.GPTModel, return_schedule_plan: bool = False)
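The pre-binding described above can be sketched with `functools.partial`; `four_arg_step` is an illustrative function, and `inspect.signature` confirms that the bound callable exposes the three-argument shape:

```python
from functools import partial
import inspect

def four_arg_step(state, data_iterator, model, return_schedule_plan=False):
    batch = next(data_iterator)
    return model(batch), lambda out: (sum(out), {"lm_loss": sum(out)})

state = object()  # stand-in for a GlobalState instance
three_arg_step = partial(four_arg_step, state)

# The bound callable now matches the 3-argument signature.
remaining = list(inspect.signature(three_arg_step).parameters)
```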
- class bridge.training.forward_step_func_types.FourArgForwardStep#
Bases: typing.Protocol

Protocol for forward step functions with 4 arguments.
This represents forward step functions that need access to GlobalState and support schedule plan return mode. These are the most complete forward step function signatures.
- Parameters:
state – Global training state containing configuration and runtime objects
data_iterator – Iterator providing training data batches
model – The GPT model to train
return_schedule_plan – Whether to return schedule plan instead of output tensor
- Returns:
Tuple of (output_tensor, loss_function) or (schedule_plan, loss_function)
- __call__(state: megatron.bridge.training.state.GlobalState, data_iterator: Iterable, model: megatron.core.models.gpt.GPTModel, return_schedule_plan: bool = False)
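A sketch of the full four-argument shape, including the `return_schedule_plan` branch; the dict "plan" is a placeholder, since the real schedule-plan type is framework-defined:

```python
def full_forward_step(state, data_iterator, model, return_schedule_plan=False):
    batch = next(data_iterator)

    def loss_func(output_tensor):
        loss = sum(output_tensor)
        return loss, len(batch), {"lm_loss": loss}

    if return_schedule_plan:
        # Defer execution: hand back a plan instead of an output tensor.
        return {"op": "forward", "batch": batch}, loss_func
    return model(batch), loss_func

plan, _ = full_forward_step(None, iter([[1.0]]), lambda b: b, return_schedule_plan=True)
output, loss_fn = full_forward_step(None, iter([[1.0, 2.0]]), lambda b: b)
```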
- class bridge.training.forward_step_func_types.ForwardStepFunctor#
Bases: typing.Protocol

Protocol for forward step functors (callable classes).

This protocol represents classes that implement __call__ with one of the supported forward step function signatures. Functors are useful when you need to maintain state between forward step calls or implement complex forward step logic that benefits from object-oriented design.
The __call__ method must match one of the supported signatures:

- (data_iterator, model)
- (state: GlobalState, data_iterator, model)
- (data_iterator, model, return_schedule_plan=False)
- (state: GlobalState, data_iterator, model, return_schedule_plan=False)
RECOMMENDED: Use the GlobalState type hint on the state parameter for automatic state injection and full access to configuration, timers, and training state.
.. rubric:: Examples
```python
class MyForwardFunctor:
    def __init__(self, loss_scale: float = 1.0):
        self.loss_scale = loss_scale
        self.call_count = 0

    def __call__(self, state: GlobalState, data_iterator, model, return_schedule_plan=False):
        self.call_count += 1
        # Access training infrastructure
        timers = state.timers
        config = state.cfg
        # ... forward step logic ...
        return output_tensor, loss_function

functor = MyForwardFunctor(loss_scale=2.0)
pretrain(config, functor)  # State injection is automatic!
```
- __call__(*args, **kwargs)
Execute the forward step.
The actual implementation must match one of the overloaded signatures above. This fallback signature is required by the Protocol but should not be used directly - type checkers will use the @overload signatures for validation.
- bridge.training.forward_step_func_types.ForwardStepFunc#
None
- bridge.training.forward_step_func_types.ForwardStepCallable#
None
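Automatic state injection implies that the framework can tell these signatures apart at runtime. One plausible mechanism (a sketch, not necessarily what Megatron Bridge actually does) is arity inspection, which works for both plain functions and functors:

```python
import inspect

def positional_arity(fn):
    """Count positional parameters of a function or a functor's __call__."""
    target = fn if inspect.isroutine(fn) else type(fn).__call__
    return sum(
        1
        for name, p in inspect.signature(target).parameters.items()
        if name != "self" and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    )

# Illustrative callables covering the supported shapes.
def two_arg(data_iterator, model): ...
def four_arg(state, data_iterator, model, return_schedule_plan=False): ...

class Functor:
    def __call__(self, state, data_iterator, model): ...
```

A dispatcher could then decide from the arity (and the GlobalState annotation on the first parameter) whether to bind state before handing the callable to the training loop.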