nemo_automodel.components.moe.router_replay

Rollout Routing Replay (R3) for MoE policy-gradient training.

In on-policy RL on a Mixture-of-Experts model, the rollout (inference) engine and the training engine compute the router’s top-k expert selection independently. Numerical differences between the two backends flip a small fraction of routing decisions per layer, which compounds across layers until most tokens are routed to a different set of experts than they were during rollout. That mismatch breaks the importance-sampling assumption behind GRPO/GSPO and destabilizes training.
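The compounding effect can be made concrete with a rough calculation. The sketch below assumes, purely for illustration, that each MoE layer flips a token's routing independently and with the same probability; the flip rate and layer count are hypothetical values, not measurements from any model.

```python
def fraction_mismatched(per_layer_flip_rate: float, num_moe_layers: int) -> float:
    """Fraction of tokens whose expert set differs in at least one MoE layer.

    Assumes flips are independent across layers, which is an illustrative
    simplification rather than a property of any real backend pair.
    """
    return 1.0 - (1.0 - per_layer_flip_rate) ** num_moe_layers


# Hypothetical numbers: a 2% per-layer flip rate over 48 MoE layers.
mismatch = fraction_mismatched(0.02, 48)
print(f"{mismatch:.1%} of tokens hit at least one flipped layer")
```

Under these assumptions a per-layer flip rate of a few percent is enough for the majority of tokens to diverge somewhere in the stack.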

Routing replay removes the mismatch by capturing the top-k expert selection during one forward pass (the rollout-equivalent forward) and replaying that exact selection during the training forward. Only the discrete selection is replayed: the router logits and their softmax/sigmoid are still recomputed from the live router weights, so the gradient continues to flow into the router. This mirrors Megatron-LM’s moe_enable_routing_replay integration.
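To illustrate what "only the discrete selection is replayed" means, here is a minimal sketch in plain Python, with lists standing in for tensors. The `route` function and its arguments are hypothetical names for this example and do not reflect the library's gate implementation; real gates may also renormalize or scale the selected weights.

```python
import math
from typing import List, Optional, Tuple


def softmax(logits: List[float]) -> List[float]:
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route(
    logits: List[float],
    topk: int,
    replay_indices: Optional[List[int]] = None,
) -> Tuple[List[int], List[float]]:
    """Pick experts for one token; optionally reuse a captured selection."""
    scores = softmax(logits)  # always recomputed from the live logits
    if replay_indices is None:
        # Live top-k: experts sorted by score, highest first.
        indices = sorted(range(len(scores)), key=lambda e: scores[e], reverse=True)[:topk]
    else:
        # Replay: the discrete choice comes from the captured selection.
        indices = list(replay_indices)
    weights = [scores[e] for e in indices]  # weights still come from live scores
    return indices, weights


rollout_logits = [2.00, 1.99, 0.50, -1.00]
training_logits = [1.99, 2.00, 0.50, -1.00]  # tiny numerical drift

rollout_indices, _ = route(rollout_logits, topk=1)
live_indices, _ = route(training_logits, topk=1)
replayed_indices, replayed_weights = route(
    training_logits, topk=1, replay_indices=rollout_indices
)
```

Here the drift flips the live choice from expert 0 to expert 1, while the replayed call keeps expert 0 and weights it with the score computed from the training-side logits.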

Usage::

    from nemo_automodel.components.moe.router_replay import RouterReplay

    # Capture the selection on the rollout-equivalent forward.
    with RouterReplay.record():
        model(batch)
    captured = RouterReplay.collect()  # one tensor per MoE layer, in layer order

    # Replay it on the training forward over the same tokens.
    with RouterReplay.replay(captured):
        loss = model(batch)
        loss.backward()

Each :class:`Gate` constructed with routing replay enabled owns one :class:`RouterReplay` instance and registers it in a process-global list at construction time. The global order is the construction order, which matches the layer order, so collect() and replay() line the per-layer tensors up by position. This assumes single-threaded model construction (the norm for recipe training); call :meth:`RouterReplay.clear_registry` before building a second model in the same process.
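The registration pattern described above can be sketched with a small stand-in class. `ReplayHandle` and its fields are hypothetical and exist only for this example; the sketch shows why construction order determines position and why a second model needs a cleared registry.

```python
from typing import ClassVar, List, Optional


class ReplayHandle:
    """Hypothetical stand-in for a per-gate handle; not the library class."""

    _registry: ClassVar[List["ReplayHandle"]] = []

    def __init__(self, name: str) -> None:
        self.name = name
        self.recorded: Optional[List[int]] = None
        # Registering at construction makes list position equal layer position.
        ReplayHandle._registry.append(self)

    @staticmethod
    def collect() -> List[Optional[List[int]]]:
        return [handle.recorded for handle in ReplayHandle._registry]

    @staticmethod
    def clear_registry() -> None:
        ReplayHandle._registry.clear()


# First model: three MoE layers built in order.
first_model = [ReplayHandle(f"layer{i}") for i in range(3)]
for position, handle in enumerate(first_model):
    handle.recorded = [position]
first_order = [handle.name for handle in ReplayHandle._registry]
first_collected = ReplayHandle.collect()

# Without this call, the second model's handles would be appended after the
# first model's, and positions would no longer match its layers.
ReplayHandle.clear_registry()
second_model = [ReplayHandle(f"layer{i}") for i in range(2)]
second_count = len(ReplayHandle._registry)
```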

Module Contents

Classes

| Name | Description |
| --- | --- |
| RouterReplay | Per-gate handle that records or replays a single MoE layer’s top-k selection. |
| RouterReplayMode | Active mode of a :class:`RouterReplay` instance. |

Functions

| Name | Description |
| --- | --- |
| replay_selection | Route indices through router_replay when routing replay is enabled. |

Data

__all__

API

class nemo_automodel.components.moe.router_replay.RouterReplay()

Per-gate handle that records or replays a single MoE layer’s top-k selection.

Instances register themselves in a process-global list on construction. The static helpers drive every registered instance at once so a caller toggles record/replay for the whole model with a single call (or the record / replay context managers).

_registry: List[RouterReplay] = []
mode: Optional[RouterReplayMode] = None
recorded_indices: Optional[Tensor] = None
target_indices: Optional[Tensor] = None

nemo_automodel.components.moe.router_replay.RouterReplay.apply(
indices: torch.Tensor
) -> torch.Tensor

Record or replay indices according to the current mode.

Parameters:

indices
torch.Tensor

The top-k expert indices the gate just selected, shape [num_tokens, topk].

Returns: torch.Tensor

indices unchanged when no mode is active or while recording; the target selection set for this layer while replaying.
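The mode-dependent behaviour of apply can be sketched as follows. `Mode` and `LayerHandle` are hypothetical names, nested lists stand in for a [num_tokens, topk] tensor, and raising on a missing target is a choice made for this sketch rather than documented library behaviour.

```python
from enum import Enum
from typing import List, Optional

Selection = List[List[int]]  # [num_tokens][topk], lists standing in for a tensor


class Mode(Enum):
    RECORD = "record"
    REPLAY = "replay"


class LayerHandle:
    """Hypothetical per-layer handle; not the library's RouterReplay."""

    def __init__(self) -> None:
        self.mode: Optional[Mode] = None
        self.recorded_indices: Optional[Selection] = None
        self.target_indices: Optional[Selection] = None

    def apply(self, indices: Selection) -> Selection:
        if self.mode is Mode.RECORD:
            # Keep a copy of what the gate chose, then let it through unchanged.
            self.recorded_indices = [row[:] for row in indices]
            return indices
        if self.mode is Mode.REPLAY:
            # Assumption of this sketch: a missing target is treated as an error.
            if self.target_indices is None:
                raise RuntimeError("replay mode is active but no target was set")
            return self.target_indices
        return indices  # no mode active: plain pass-through


handle = LayerHandle()
live = [[0, 3], [1, 2]]

passthrough = handle.apply(live)

handle.mode = Mode.RECORD
recorded_out = handle.apply(live)

handle.mode = Mode.REPLAY
handle.target_indices = handle.recorded_indices
replayed = handle.apply([[2, 3], [0, 1]])  # the live choice has drifted
```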

nemo_automodel.components.moe.router_replay.RouterReplay.clear() -> None

Drop both the recorded and the target selection for this layer.

nemo_automodel.components.moe.router_replay.RouterReplay.clear_indices() -> None
staticmethod

Drop recorded and target selections on every registered instance.

nemo_automodel.components.moe.router_replay.RouterReplay.clear_registry() -> None
staticmethod

Forget every registered instance (use between independently built models).

nemo_automodel.components.moe.router_replay.RouterReplay.collect() -> typing.List[torch.Tensor]
staticmethod

Collect the recorded selection from every registered instance, in layer order.

Raises:

  • RuntimeError: If any instance has no recorded selection (i.e. a forward pass was not run under :meth:`record`).

staticmethod

Return the registered instances in construction (layer) order.

nemo_automodel.components.moe.router_replay.RouterReplay.record() -> typing.Iterator[None]
classmethod

Record the top-k selection of every gate for the duration of the block.

nemo_automodel.components.moe.router_replay.RouterReplay.replay(
all_layers_indices: typing.List[torch.Tensor]
) -> typing.Iterator[None]
classmethod

Replay all_layers_indices (one tensor per layer) for the duration of the block.

Target selections are cleared on exit so a stale replay never leaks into a later forward pass.
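Clearing on exit is the kind of guarantee a context manager with a `finally` clause provides. The sketch below uses a hypothetical module-level `state` dictionary in place of the registered instances, and assumes the cleanup also runs when the block raises, which the description above does not spell out.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional

# Hypothetical shared state standing in for the per-layer target selections.
state: Dict[str, Optional[List[List[int]]]] = {"targets": None}


@contextmanager
def replay(all_layers_indices: List[List[int]]) -> Iterator[None]:
    state["targets"] = all_layers_indices
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions, so no stale target survives.
        state["targets"] = None


with replay([[0, 2]]):
    inside = state["targets"]
after = state["targets"]

try:
    with replay([[1, 3]]):
        raise RuntimeError("simulated failure in the training forward")
except RuntimeError:
    pass
after_error = state["targets"]
```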

nemo_automodel.components.moe.router_replay.RouterReplay.set_mode(
mode: typing.Optional[nemo_automodel.components.moe.router_replay.RouterReplayMode]
) -> None
staticmethod

Set the mode on every registered instance (None disables replay).

nemo_automodel.components.moe.router_replay.RouterReplay.set_replay_indices(
all_layers_indices: typing.List[torch.Tensor]
) -> None
staticmethod

Distribute one selection tensor per layer to the registered instances.

Parameters:

all_layers_indices
List[torch.Tensor]

One [num_tokens, topk] tensor per MoE layer, in the same order the layers were constructed.

Raises:

  • ValueError: If the number of tensors does not match the number of registered instances.
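
The length check matters because pairing layers with tensors by position would otherwise silently drop the unmatched tail. A minimal sketch, assuming dictionaries as hypothetical stand-ins for the registered instances:

```python
from typing import Dict, List, Optional


def distribute(
    handles: List[Dict[str, Optional[List[int]]]],
    all_layers_indices: List[List[int]],
) -> None:
    """Assign one selection per handle, refusing mismatched lengths."""
    if len(all_layers_indices) != len(handles):
        raise ValueError(
            f"got {len(all_layers_indices)} selections "
            f"for {len(handles)} registered layers"
        )
    # zip alone would stop at the shorter list without complaining.
    for handle, indices in zip(handles, all_layers_indices):
        handle["target"] = indices


handles: List[Dict[str, Optional[List[int]]]] = [{"target": None} for _ in range(3)]
distribute(handles, [[0], [1], [2]])

try:
    distribute(handles, [[0], [1]])
    mismatch_error = None
except ValueError as exc:
    mismatch_error = str(exc)
```
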
nemo_automodel.components.moe.router_replay.RouterReplay.set_target(
indices: torch.Tensor
) -> None

Set the selection to replay for this layer.

class nemo_automodel.components.moe.router_replay.RouterReplayMode

Bases: enum.Enum

Active mode of a :class:RouterReplay instance.

RECORD = 'record'
REPLAY = 'replay'

nemo_automodel.components.moe.router_replay.replay_selection(
router_replay: typing.Optional[nemo_automodel.components.moe.router_replay.RouterReplay],
indices: torch.Tensor
) -> torch.Tensor

Route indices through router_replay when routing replay is enabled.

Returns indices unchanged when router_replay is None (replay disabled) or when no mode is active, so the gate’s default path is a true no-op.
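
The guard described above amounts to an early return. The following sketch uses a hypothetical `StubHandle` that always replays a fixed selection; it is meant to show the shape of the no-op path, not the library's actual code.

```python
from typing import List, Optional

Selection = List[List[int]]


class StubHandle:
    """Hypothetical handle that always replays a fixed selection."""

    def __init__(self, target: Selection) -> None:
        self.target = target

    def apply(self, indices: Selection) -> Selection:
        return self.target


def replay_selection_sketch(
    router_replay: Optional[StubHandle], indices: Selection
) -> Selection:
    if router_replay is None:
        return indices  # replay disabled: the gate's own choice is untouched
    return router_replay.apply(indices)


live = [[0, 1]]
disabled = replay_selection_sketch(None, live)
enabled = replay_selection_sketch(StubHandle([[2, 3]]), live)
```

When replay is disabled the very same object comes back, so a gate that calls the helper unconditionally pays no cost for the feature.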

nemo_automodel.components.moe.router_replay.__all__ = ['RouterReplayMode', 'RouterReplay', 'replay_selection']