bridge.training.mimo_step#
MIMO-specific forward step function for use with pipeline schedules.
This module provides the forward step function for MIMO model training. Key design notes (per PR 3212):

- The schedule expects dict-based outputs: `{module_name: tensor}` instead of single tensors.
- The MimoModel's forward returns output tensors that the schedule sends via `MultiModulePipelineCommunicator`.
- The schedule's `backward_step_multimodule()` handles the dict-based backward pass automatically.
- Only the LLM module produces a loss; encoders just produce activations.
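The dict-based flow above can be sketched in plain Python, with lists standing in for tensors. The module names and transformations here are illustrative assumptions, not the actual MimoModel API:

```python
# Sketch of the dict-based output convention, using plain Python lists
# as stand-in "tensors". Module names ("vision_encoder") are hypothetical.

def encoder_forward(pixels):
    # Encoders only produce activations -- they have no loss of their own.
    return [p * 0.5 for p in pixels]

def llm_forward(token_ids, encoder_outputs):
    # The LLM consumes encoder activations and produces per-token losses.
    extra = sum(len(v) for v in encoder_outputs.values())
    return [float(t + extra) for t in token_ids]

# Intermediate stages hand the schedule a {module_name: tensor} dict,
# which MultiModulePipelineCommunicator routes between pipeline ranks:
outputs = {"vision_encoder": encoder_forward([2.0, 4.0])}

# The terminal (LLM) stage turns those activations into per-token losses:
per_token_loss = llm_forward([1, 2, 3], outputs)
```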
Module Contents#
Functions#
| Function | Description |
|---|---|
| `loss_func` | Loss function for MIMO model training. |
| `get_batch` | Get batch from data iterator. |
| `forward_step` | Forward step for MIMO model training. |
Data#
API#
- bridge.training.mimo_step.logger#
‘getLogger(…)’
- bridge.training.mimo_step.loss_func(
- loss_mask: torch.Tensor,
- output_tensor: torch.Tensor,
- )
Loss function for MIMO model training.
Called at the terminal stage (LLM’s last PP stage).
- Parameters:
loss_mask – Mask indicating which tokens contribute to the loss.
output_tensor – Model output tensor (losses per token).
- Returns:
Tuple of (total_loss, num_tokens, {'lm loss': reporting_loss}).
Note: Only the LLM module produces a loss. Encoders produce activations that are consumed by the LLM, but don't have their own loss.
- bridge.training.mimo_step.get_batch(
- data_iterator: Iterable,
- )
Get batch from data iterator.
Returns dict with:

- `input_ids`, `labels`, `loss_mask`, `position_ids` (for the LLM)
- `modality_inputs`: `{modality_name: preprocessed_tensors}` (for encoders)
Uses existing MimoDataset format from Phase 3.
- Parameters:
data_iterator – Iterator over the dataset.
- Returns:
Batch dictionary or None if iterator is exhausted.
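The batch contract above can be sketched as follows; field contents are placeholders and the `StopIteration` handling mirrors the "None if exhausted" behavior, but the helper name and exact field construction are assumptions:

```python
# Hypothetical sketch of get_batch: pull one sample from the iterator and
# shape it into the dict the LLM and encoder modules expect.

def get_batch_sketch(data_iterator):
    try:
        sample = next(data_iterator)
    except StopIteration:
        return None  # iterator exhausted
    return {
        "input_ids": sample["input_ids"],
        "labels": sample["labels"],
        "loss_mask": sample["loss_mask"],
        "position_ids": list(range(len(sample["input_ids"]))),
        # Encoder inputs keyed by modality name:
        "modality_inputs": sample.get("modality_inputs", {}),
    }

it = iter([{"input_ids": [5, 6], "labels": [6, 7], "loss_mask": [1, 1]}])
batch = get_batch_sketch(it)
```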
- bridge.training.mimo_step.forward_step(
- state: megatron.bridge.training.state.GlobalState,
- data_iterator: Iterable,
- model: megatron.core.models.mimo.MimoModel,
- )
Forward step for MIMO model training.
Uses a 3-arg signature with GlobalState for Bridge compatibility. The training loop wraps this with `prepare_forward_step_func()`, which:

- Injects GlobalState automatically if `forward_step` accepts it
- Provides access to `state.timers`, `state.cfg`, `state.train_state`
The MimoModel handles dict-based tensor flow internally:

- Encoder modules produce activations sent via `BridgeCommunicator`
- The LLM module receives encoder outputs and produces the loss

At the terminal stage, returns `(loss_tensor, loss_func)`. At intermediate stages, returns `(output_dict, None)`; the schedule handles communication.
GUARDRAIL: At last stage, assert output is scalar tensor (not dict) to catch misconfigurations early with a clear error message.
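The guardrail described above amounts to a type check on the last stage's output. A minimal sketch, assuming a hypothetical helper name (the real check lives inside `forward_step`):

```python
# Sketch of the last-stage guardrail: the terminal stage must yield a
# single scalar-like loss, never a {module_name: tensor} dict.

def check_last_stage_output(output):
    assert not isinstance(output, dict), (
        "Last pipeline stage returned a dict; expected a scalar loss "
        "tensor. Check the pipeline/module configuration."
    )
    return output

check_last_stage_output(3.14)  # a scalar loss passes the check

caught = False
try:
    check_last_stage_output({"llm": 3.14})  # a dict trips the assertion
except AssertionError:
    caught = True
```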
- Parameters:
state – GlobalState containing timers, config, train_state.
data_iterator – Iterator over the dataset.
model – MimoModel instance.
- Returns:
Tuple of (output_tensor, loss_function or None).
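Putting the pieces together, the control flow of `forward_step` can be sketched end to end. Names like `is_terminal_stage` and the lambda model are assumptions for illustration, not the real Bridge API:

```python
# Hypothetical sketch of forward_step's branching: terminal stages return
# (loss_tensor, loss_func); intermediate stages return (output_dict, None)
# and let the schedule handle inter-stage communication.

def forward_step_sketch(batch, model_fn, is_terminal_stage, loss_fn):
    output = model_fn(batch)
    if is_terminal_stage:
        # Guardrail: the terminal stage must not emit a module dict.
        assert not isinstance(output, dict)
        return output, loss_fn
    return output, None

# Intermediate stage: a module dict goes out, with no loss function.
out, fn = forward_step_sketch({"x": 1}, lambda b: {"llm": [1.0]}, False, None)

# Terminal stage: a scalar loss is paired with the loss function.
loss, f = forward_step_sketch({"x": 1}, lambda b: 0.5, True, "loss_fn")
```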