bridge.training.mimo_step#

MIMO-specific forward step function for use with pipeline schedules.

This module provides the forward step function for MIMO model training. Key design notes (per PR 3212):

  • The schedule expects dict-based outputs: {module_name: tensor} instead of single tensors

  • The MimoModel’s forward returns output tensors that the schedule sends via MultiModulePipelineCommunicator

  • The schedule’s backward_step_multimodule() handles dict-based backward pass automatically

  • Only the LLM module produces a loss; encoders just produce activations
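The dict-based output contract above can be sketched as follows. This is illustrative only: the module names and tensor shapes are assumptions, not taken from the real model.

```python
import torch

# Sketch of the dict-based stage output described above: each pipeline stage
# emits {module_name: tensor} rather than a single tensor. Encoder entries
# carry activations only; no loss is attached to them.
stage_output = {
    "vision_encoder": torch.randn(2, 64, 1024),  # hypothetical module name
    "audio_encoder": torch.randn(2, 32, 1024),   # hypothetical module name
}

# The schedule ships each entry to its consumer; only the LLM stage later
# reduces its output into a scalar loss.
all_activations = all(torch.is_tensor(t) for t in stage_output.values())
```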

Module Contents#

Functions#

loss_func

Loss function for MIMO model training.

get_batch

Get batch from data iterator.

forward_step

Forward step for MIMO model training.

Data#

API#

bridge.training.mimo_step.logger#

‘getLogger(…)’

bridge.training.mimo_step.loss_func(
loss_mask: torch.Tensor,
output_tensor: torch.Tensor,
) Tuple#

Loss function for MIMO model training.

Called at the terminal stage (LLM’s last PP stage).

Parameters:
  • loss_mask – Mask indicating which tokens contribute to the loss.

  • output_tensor – Model output tensor (losses per token).

Returns:

Tuple of (total_loss, num_tokens, {'lm loss': reporting_loss}).

Return type:

Tuple

.. note::

Only the LLM module produces a loss. Encoders produce activations that are consumed by the LLM, but don’t have their own loss.
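A minimal sketch of a masked per-token loss reduction matching the documented signature and return shape. This is not the actual implementation; the reduction details are assumptions based on the description above (per-token losses in, masked sum and token count out).

```python
import torch

def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    # output_tensor holds per-token losses; mask out tokens that
    # should not contribute (padding, prompt tokens, etc.)
    losses = output_tensor.view(-1).float()
    mask = loss_mask.view(-1).float()
    total_loss = torch.sum(losses * mask)
    num_tokens = mask.sum().to(torch.int)
    # reporting dict keyed 'lm loss', per the documented return shape
    return total_loss, num_tokens, {"lm loss": total_loss.detach()}

mask = torch.tensor([1, 1, 0, 1])
per_token = torch.tensor([0.5, 1.5, 9.0, 2.0])
loss, n, report = loss_func(mask, per_token)
# masked token at index 2 is excluded: loss = 0.5 + 1.5 + 2.0 = 4.0 over 3 tokens
```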

bridge.training.mimo_step.get_batch(
data_iterator: Iterable,
) Optional[Dict[str, torch.Tensor]]#

Get batch from data iterator.

Returns dict with:

  • input_ids, labels, loss_mask, position_ids (for LLM)

  • modality_inputs: {modality_name: preprocessed_tensors} (for encoders)

Uses existing MimoDataset format from Phase 3.

Parameters:

data_iterator – Iterator over the dataset.

Returns:

Batch dictionary or None if iterator is exhausted.
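The documented batch layout can be illustrated with a hand-built example. Field names follow the contract above; the tensor shapes and the `"images"` modality name are assumptions for illustration.

```python
import torch

batch_size, seq_len = 2, 128

# Illustrative batch matching the documented get_batch() return contract
batch = {
    # LLM fields
    "input_ids": torch.randint(0, 32000, (batch_size, seq_len)),
    "labels": torch.randint(0, 32000, (batch_size, seq_len)),
    "loss_mask": torch.ones(batch_size, seq_len),
    "position_ids": torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1),
    # encoder fields: {modality_name: preprocessed_tensors}
    "modality_inputs": {
        "images": torch.randn(batch_size, 3, 224, 224),  # hypothetical modality
    },
}
```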

bridge.training.mimo_step.forward_step(
state: megatron.bridge.training.state.GlobalState,
data_iterator: Iterable,
model: megatron.core.models.mimo.MimoModel,
) Tuple[torch.Tensor, Optional[functools.partial]]#

Forward step for MIMO model training.

Uses a 3-arg signature with GlobalState for Bridge compatibility. The training loop wraps this with prepare_forward_step_func(), which:

  • Injects GlobalState automatically if forward_step accepts it

  • Provides access to state.timers, state.cfg, state.train_state

The MimoModel handles dict-based tensor flow internally:

  • Encoder modules produce activations sent via BridgeCommunicator

  • LLM module receives encoder outputs and produces loss

At terminal stage: returns (loss_tensor, loss_func).

At intermediate stages: returns (output_dict, None); the schedule handles communication.

GUARDRAIL: At last stage, assert output is scalar tensor (not dict) to catch misconfigurations early with a clear error message.

Parameters:
  • state – GlobalState containing timers, config, train_state.

  • data_iterator – Iterator over the dataset.

  • model – MimoModel instance.

Returns:

Tuple of (output_tensor, loss_function or None).
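The branching and guardrail described above can be sketched in isolation. The stand-in loss_func, the is_last_stage flag, and the pre-computed model output are simplifications; the real function takes (state, data_iterator, model) and derives these itself.

```python
import functools
import torch

# Stand-in for the real loss_func from this module (masked sum sketch)
def loss_func(loss_mask, output_tensor):
    mask = loss_mask.view(-1).float()
    total = torch.sum(output_tensor.view(-1).float() * mask)
    return total, mask.sum().to(torch.int), {"lm loss": total.detach()}

def forward_step_sketch(batch, model_output, is_last_stage):
    """Sketch of the documented return contract: terminal stage returns
    (loss_tensor, bound loss_func); intermediate stages return (dict, None)."""
    if is_last_stage:
        # GUARDRAIL from the docs: last-stage output must be a tensor, not a
        # dict, so misconfigurations fail early with a clear message
        assert torch.is_tensor(model_output), "expected loss tensor at last stage"
        return model_output, functools.partial(loss_func, batch["loss_mask"])
    return model_output, None

# intermediate stage: dict of per-module activations, no loss function
acts = {"vision_encoder": torch.randn(2, 16)}  # hypothetical module name
out, fn = forward_step_sketch({}, acts, is_last_stage=False)

# terminal stage: per-token losses plus a bound loss function
batch = {"loss_mask": torch.ones(2, 4)}
out2, fn2 = forward_step_sketch(batch, torch.ones(2, 4), is_last_stage=True)
```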