bridge.training.forward_step_func_types#

Type definitions for forward step functions.

This module provides comprehensive type definitions for forward step functions used in Megatron Bridge training. Forward step functions are the core of the training loop, responsible for performing a single forward pass and returning both the output tensor and a loss function.

Key Types:

  • ForwardStepCallable: Union of all supported forward step signatures (functions + functors)

  • LossFunction: The partial function returned by forward step functions

  • LossFunctionReturn: The possible return types when calling a loss function

Example Usage:

```python
>>> from functools import partial
>>> from megatron.bridge.training.state import GlobalState
>>>
>>> def my_forward_step(state: GlobalState, data_iterator, model, return_schedule_plan=False):
...     # Access configuration, timers, and training state
...     timers = state.timers
...     config = state.cfg
...
...     # Get batch data
...     batch = next(data_iterator)
...
...     # Forward pass with timing
...     timers("forward-step").start()
...     output_tensor = model(batch["input_ids"])
...     timers("forward-step").stop()
...
...     # Create loss function
...     def loss_func(output_tensor):
...         loss = compute_loss(output_tensor, batch["labels"])
...         num_tokens = batch["labels"].numel()
...         loss_reduced = {"lm_loss": loss.detach()}
...         return loss, num_tokens, loss_reduced  # ThreeTupleLossReturn
...
...     return output_tensor, partial(loss_func)
...
>>> # State injection is automatic - no manual binding needed!
>>> pretrain(config, my_forward_step)
>>>
>>> # Functor example (for stateful forward steps)
>>> class StatefulForwardStep:
...     def __init__(self, loss_scale: float = 1.0):
...         self.loss_scale = loss_scale
...         self.step_count = 0
...
...     def __call__(self, state: GlobalState, data_iterator, model, return_schedule_plan=False):
...         self.step_count += 1
...         # ... forward step logic with state tracking ...
...         return output_tensor, partial(loss_func)
...
>>> functor = StatefulForwardStep(loss_scale=2.0)
>>> pretrain(config, functor)
```

Module Contents#

Classes#

TwoArgForwardStep

Protocol for forward step functions with 2 arguments.

ThreeArgStateForwardStep

Protocol for forward step functions with 3 arguments including state.

ThreeArgForwardStep

Protocol for forward step functions with 3 arguments.

FourArgForwardStep

Protocol for forward step functions with 4 arguments.

ForwardStepFunctor

Protocol for forward step functors (callable classes).

Data#

API#

bridge.training.forward_step_func_types.LossReduced#

None

bridge.training.forward_step_func_types.TwoTupleLossReturn#

None

bridge.training.forward_step_func_types.ThreeTupleLossReturn#

None

bridge.training.forward_step_func_types.InferenceLossReturn#

None

bridge.training.forward_step_func_types.LossFunctionReturn#

None

bridge.training.forward_step_func_types.LossFunction#

None
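The alias values are not rendered on this page. For orientation, the following is a plausible reconstruction based solely on the docstrings in this module; the exact definitions in the source may differ, and `Tensor` is a dependency-free stand-in for `torch.Tensor`:

```python
from collections.abc import Callable
from typing import Any, Union

Tensor = Any  # stand-in for torch.Tensor so this sketch needs no torch install

# Dict of reduced losses keyed by name, e.g. {"lm_loss": loss.detach()}
LossReduced = dict[str, Tensor]

# (loss, loss_reduced): loss averaged over the microbatch
TwoTupleLossReturn = tuple[Tensor, LossReduced]

# (loss, num_tokens, loss_reduced): per-token loss accounting
ThreeTupleLossReturn = tuple[Tensor, int, LossReduced]

# Inference mode: the raw output is passed through
InferenceLossReturn = Tensor

LossFunctionReturn = Union[TwoTupleLossReturn, ThreeTupleLossReturn, InferenceLossReturn]

# The (usually functools.partial-wrapped) callable a forward step returns:
# it takes the output tensor and produces one of the return shapes above.
LossFunction = Callable[[Tensor], LossFunctionReturn]
```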

class bridge.training.forward_step_func_types.TwoArgForwardStep#

Bases: typing.Protocol

Protocol for forward step functions with 2 arguments.

This represents forward step functions that don’t need access to GlobalState and don’t support schedule plan return mode.

Parameters:
  • data_iterator – Iterator providing training data batches

  • model – The GPT model to train

Returns:

Tuple of (output_tensor, loss_function)

__call__(
data_iterator: Iterable,
model: megatron.core.models.gpt.GPTModel,
) → tuple[torch.Tensor, bridge.training.forward_step_func_types.LossFunction]#
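A minimal, dependency-free sketch of a function satisfying this two-argument protocol; the model and loss computation here are stand-ins for illustration, not the real Megatron API:

```python
from functools import partial

def simple_forward_step(data_iterator, model):
    # Pull one batch and run the (stand-in) model
    batch = next(data_iterator)
    output_tensor = model(batch["input_ids"])

    def loss_func(output_tensor):
        # Placeholder loss; returns a TwoTupleLossReturn-shaped value
        loss = output_tensor
        return loss, {"lm_loss": loss}

    # Forward steps conventionally return the loss function wrapped in partial
    return output_tensor, partial(loss_func)

def fake_model(input_ids):
    # Stand-in model: sums token ids to produce a scalar "output tensor"
    return sum(input_ids)

data = iter([{"input_ids": [1, 2, 3]}])
output, loss_fn = simple_forward_step(data, fake_model)
```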
class bridge.training.forward_step_func_types.ThreeArgStateForwardStep#

Bases: typing.Protocol

Protocol for forward step functions with 3 arguments including state.

This represents forward step functions that need access to GlobalState but don’t support schedule plan return mode.

Parameters:
  • state – Global training state containing configuration and runtime objects

  • data_iterator – Iterator providing training data batches

  • model – The GPT model to train

Returns:

Tuple of (output_tensor, loss_function)

__call__(
state: megatron.bridge.training.state.GlobalState,
data_iterator: Iterable,
model: megatron.core.models.gpt.GPTModel,
) → tuple[torch.Tensor, bridge.training.forward_step_func_types.LossFunction]#
class bridge.training.forward_step_func_types.ThreeArgForwardStep#

Bases: typing.Protocol

Protocol for forward step functions with 3 arguments.

This represents forward step functions that don’t need access to GlobalState but support schedule plan return mode. These are typically 4-arg functions that have had GlobalState pre-bound via functools.partial.

Parameters:
  • data_iterator – Iterator providing training data batches

  • model – The GPT model to train

  • return_schedule_plan – Whether to return schedule plan instead of output tensor

Returns:

Tuple of (output_tensor, loss_function) or (schedule_plan, loss_function)

__call__(
data_iterator: Iterable,
model: megatron.core.models.gpt.GPTModel,
return_schedule_plan: bool = False,
) → tuple[torch.Tensor, bridge.training.forward_step_func_types.LossFunction]#
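As the description notes, a three-argument callable typically arises by pre-binding `GlobalState` onto a four-argument function. A sketch, where the state, model, and batch are stand-in objects:

```python
from functools import partial

# Illustrative four-argument forward step; `state` here is a stand-in,
# not the real GlobalState.
def four_arg_step(state, data_iterator, model, return_schedule_plan=False):
    batch = next(data_iterator)
    output_tensor = model(batch)

    def loss_func(output_tensor):
        return output_tensor, {"lm_loss": output_tensor}

    return output_tensor, partial(loss_func)

# Pre-binding state produces a callable matching the three-argument
# protocol: (data_iterator, model, return_schedule_plan=False)
state = {"step": 0}  # stand-in for GlobalState
three_arg_step = partial(four_arg_step, state)

output, loss_fn = three_arg_step(iter([10]), lambda batch: batch * 2)
```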
class bridge.training.forward_step_func_types.FourArgForwardStep#

Bases: typing.Protocol

Protocol for forward step functions with 4 arguments.

This represents forward step functions that need access to GlobalState and support schedule plan return mode. These are the most complete forward step function signatures.

Parameters:
  • state – Global training state containing configuration and runtime objects

  • data_iterator – Iterator providing training data batches

  • model – The GPT model to train

  • return_schedule_plan – Whether to return schedule plan instead of output tensor

Returns:

Tuple of (output_tensor, loss_function) or (schedule_plan, loss_function)

__call__(
state: megatron.bridge.training.state.GlobalState,
data_iterator: Iterable,
model: megatron.core.models.gpt.GPTModel,
return_schedule_plan: bool = False,
) → tuple[torch.Tensor, bridge.training.forward_step_func_types.LossFunction]#
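The two return modes can be sketched as follows. Note that `build_schedule_plan` is a hypothetical method name used only for illustration, and the model and state are stand-ins:

```python
from functools import partial

def full_forward_step(state, data_iterator, model, return_schedule_plan=False):
    batch = next(data_iterator)

    def loss_func(output_tensor):
        return output_tensor, {"lm_loss": output_tensor}

    if return_schedule_plan:
        # Return a schedule plan instead of running the forward pass
        # (hypothetical method; the real API may differ)
        return model.build_schedule_plan(batch), partial(loss_func)

    return model(batch), partial(loss_func)

# Stand-in model exercising both branches
class FakeModel:
    def __call__(self, batch):
        return batch + 1

    def build_schedule_plan(self, batch):
        return ("plan", batch)

state = {}  # stand-in for GlobalState
out, _ = full_forward_step(state, iter([1]), FakeModel())
plan, _ = full_forward_step(state, iter([2]), FakeModel(), return_schedule_plan=True)
```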
class bridge.training.forward_step_func_types.ForwardStepFunctor#

Bases: typing.Protocol

Protocol for forward step functors (callable classes).

This protocol represents classes that implement `__call__` with one of the supported forward step function signatures. Functors are useful when you need to maintain state between forward step calls or implement complex forward step logic that benefits from object-oriented design.

The `__call__` method must match one of the supported signatures:

  • (data_iterator, model)

  • (data_iterator, model, return_schedule_plan=False) OR (state: GlobalState, data_iterator, model)

  • (state: GlobalState, data_iterator, model, return_schedule_plan=False)

RECOMMENDED: Use GlobalState type hint for automatic state injection and full access to configuration, timers, and training state.

Examples

```python
class MyForwardFunctor:
    def __init__(self, loss_scale: float = 1.0):
        self.loss_scale = loss_scale
        self.call_count = 0

    def __call__(self, state: GlobalState, data_iterator, model, return_schedule_plan=False):
        self.call_count += 1
        # Access training infrastructure
        timers = state.timers
        config = state.cfg
        # ... forward step logic ...
        return output_tensor, loss_function

functor = MyForwardFunctor(loss_scale=2.0)
pretrain(config, functor)  # State injection is automatic!
```

__call__(
*args,
**kwargs,
) → tuple[torch.Tensor, bridge.training.forward_step_func_types.LossFunction]#

Execute the forward step.

The actual implementation must match one of the overloaded signatures above. This fallback signature is required by the Protocol but should not be used directly - type checkers will use the @overload signatures for validation.

bridge.training.forward_step_func_types.ForwardStepFunc#

None

bridge.training.forward_step_func_types.ForwardStepCallable#

None