nemo_automodel.components.moe.router_replay
Rollout Routing Replay (R3) for MoE policy-gradient training.
In on-policy RL on a Mixture-of-Experts model, the rollout (inference) engine and the training engine compute the router’s top-k expert selection independently. Numerical differences between the two backends flip a small fraction of routing decisions per layer, which compounds across layers until most tokens are routed to a different set of experts than they were during rollout. That mismatch breaks the importance-sampling assumption behind GRPO/GSPO and destabilizes training.
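The compounding claim can be made concrete with a little arithmetic. This sketch assumes something the text does not state: each token's top-k choice flips independently in each layer with the same probability. The 2% flip rate and 48-layer depth are hypothetical numbers chosen for illustration, not measurements.

```python
def fraction_with_identical_routing(flip_rate: float, num_layers: int) -> float:
    """Probability that a token keeps the same top-k selection in every layer,
    assuming each layer flips it independently with probability `flip_rate`."""
    return (1.0 - flip_rate) ** num_layers


# Hypothetical numbers: a 2% per-layer flip rate over 48 MoE layers.
same = fraction_with_identical_routing(0.02, 48)
print(f"identical routing in all layers: {same:.1%}")      # 37.9%
print(f"routed differently somewhere:    {1 - same:.1%}")  # 62.1%
```

Under these assumptions, a per-layer mismatch that looks negligible still leaves most tokens with at least one layer whose expert set differs from the rollout.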
Routing replay removes the mismatch by capturing the top-k expert selection during
one forward pass (the rollout-equivalent forward) and replaying that exact selection
during the training forward. Only the discrete selection is replayed: the router
logits and their softmax/sigmoid are still recomputed from the live router weights,
so the gradient continues to flow into the router. This mirrors Megatron-LM’s
moe_enable_routing_replay integration.
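A toy illustration of why replaying only the discrete selection keeps the router trainable. The `route` function, the shapes, and the softmax scoring are hypothetical (a real gate may use sigmoid scoring, bias terms, or renormalization). The point is that the `torch.topk` indices carry no gradient, while the gating weights gathered from the live scores do, regardless of which indices are used.

```python
import torch


def route(x, router_weight, topk, replay_indices=None):
    """Toy top-k router returning (expert indices, gating weights).

    If replay_indices is given, it replaces the discrete top-k choice, but the
    gating weights are still gathered from scores computed by the live router.
    """
    logits = x @ router_weight.t()          # [num_tokens, num_experts]
    scores = torch.softmax(logits, dim=-1)  # recomputed on every forward
    if replay_indices is None:
        indices = torch.topk(scores, topk, dim=-1).indices
    else:
        indices = replay_indices            # only the discrete choice is replayed
    weights = torch.gather(scores, -1, indices)  # differentiable w.r.t. router_weight
    return indices, weights


torch.manual_seed(0)
num_tokens, hidden, num_experts, topk = 4, 8, 6, 2
x = torch.randn(num_tokens, hidden)
router_weight = torch.randn(num_experts, hidden, requires_grad=True)

# "Rollout-equivalent" forward: capture the selection, no gradients needed.
with torch.no_grad():
    captured, _ = route(x, router_weight, topk)

# "Training" forward: replay the captured selection and backpropagate.
indices, weights = route(x, router_weight, topk, replay_indices=captured)
weights.sum().backward()

print(torch.equal(indices, captured))                 # True
print((router_weight.grad.abs().sum() > 0).item())    # True
```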
Usage::

    from nemo_automodel.components.moe.router_replay import RouterReplay

    # Capture the selection on the rollout-equivalent forward.
    with RouterReplay.record():
        model(batch)
    captured = RouterReplay.collect()  # one tensor per MoE layer, in layer order

    # Replay it on the training forward over the same tokens.
    with RouterReplay.replay(captured):
        loss = model(batch)
        loss.backward()
Each Gate constructed with routing replay enabled owns one
RouterReplay instance, which registers itself in a process-global list at
construction time. The global order is the construction order, which matches the
layer order, so collect() and replay() line the per-layer tensors up by
position. This assumes single-threaded model construction (the norm for recipe
training); call RouterReplay.clear_registry() before building a second
model in the same process.
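A minimal sketch of the registration scheme, assuming a class-level list that each handle appends to in `__init__`. `ToyReplayHandle` and `build_toy_model` are hypothetical names for illustration, not part of this module. The sketch shows why position-based alignment depends on the registry holding exactly one model's layers.

```python
from typing import List


class ToyReplayHandle:
    """Hypothetical stand-in for a per-gate handle that self-registers."""

    _registry: List["ToyReplayHandle"] = []

    def __init__(self, layer_id: int) -> None:
        self.layer_id = layer_id
        ToyReplayHandle._registry.append(self)  # construction order == layer order

    @staticmethod
    def clear_registry() -> None:
        ToyReplayHandle._registry.clear()


def build_toy_model(num_layers: int) -> List[ToyReplayHandle]:
    # Layers are built one after another, so registration order matches layer order.
    return [ToyReplayHandle(i) for i in range(num_layers)]


first = build_toy_model(3)
print([h.layer_id for h in ToyReplayHandle._registry])  # [0, 1, 2]

# Building a second model without clearing appends to the same global list,
# so position i no longer refers to layer i of either model.
second = build_toy_model(3)
print(len(ToyReplayHandle._registry))  # 6

ToyReplayHandle.clear_registry()
second = build_toy_model(3)
print([h.layer_id for h in ToyReplayHandle._registry])  # [0, 1, 2]
```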
Module Contents
API
Per-gate handle that records or replays a single MoE layer’s top-k selection.
Instances register themselves in a process-global list on construction. The
static helpers drive every registered instance at once so a caller toggles
record/replay for the whole model with a single call (or the record() and
replay() context managers).
Record or replay indices according to the current mode.
Parameters:
The top-k expert indices the gate just selected, shape [num_tokens, topk].
Returns: torch.Tensor
indices unchanged when no mode is active or while recording; the target selection set for this layer while replaying.
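The per-layer dispatch can be sketched as a small callable. `ToyMode` and `ToyLayerReplay` are hypothetical stand-ins, not this module's classes, and raising when replay is active without a target is this sketch's own choice. The documented behavior is only that indices pass through unchanged when no mode is active or while recording, and that the stored selection is used while replaying.

```python
import enum
from typing import Optional

import torch


class ToyMode(enum.Enum):
    """Hypothetical stand-in for the module's mode enum."""
    RECORD = "record"
    REPLAY = "replay"


class ToyLayerReplay:
    """Hypothetical per-layer handle showing record/replay dispatch."""

    def __init__(self) -> None:
        self.mode: Optional[ToyMode] = None
        self.recorded: Optional[torch.Tensor] = None
        self.target: Optional[torch.Tensor] = None

    def __call__(self, indices: torch.Tensor) -> torch.Tensor:
        if self.mode is ToyMode.RECORD:
            # Keep a detached copy and pass the live selection through.
            self.recorded = indices.detach().clone()
            return indices
        if self.mode is ToyMode.REPLAY:
            if self.target is None:  # this sketch's own error choice
                raise RuntimeError("replay mode is set but no target selection")
            return self.target.to(indices.device)
        return indices  # no mode active: a no-op


handle = ToyLayerReplay()
live = torch.tensor([[0, 3], [1, 2]])

handle.mode = ToyMode.RECORD
out = handle(live)  # passes through; handle.recorded now equals live

handle.mode = ToyMode.REPLAY
handle.target = torch.tensor([[0, 3], [2, 1]])
print(handle(live).tolist())  # [[0, 3], [2, 1]]
```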
Drop both the recorded and the target selection for this layer.
Drop recorded and target selections on every registered instance.
Forget every registered instance (use between independently built models).
Collect the recorded selection from every registered instance, in layer order.
Raises:
RuntimeError: If any instance has no recorded selection (i.e. no forward pass was run under record()).
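A sketch of the collect step, assuming each handle exposes its recorded selection as an attribute that stays None until a forward runs under record(). `toy_collect` and the `recorded` attribute are hypothetical, and plain lists stand in for index tensors. Listing the missing positions in the error message is a design choice that makes a partially recorded forward easier to diagnose.

```python
from types import SimpleNamespace
from typing import Any, List, Sequence


def toy_collect(handles: Sequence[Any]) -> List[Any]:
    """Return each handle's recorded selection in registration (layer) order."""
    missing = [pos for pos, h in enumerate(handles) if h.recorded is None]
    if missing:
        # Some layer never ran under record(), so positional alignment would break.
        raise RuntimeError(f"no recorded selection for layer position(s) {missing}")
    return [h.recorded for h in handles]


layers = [SimpleNamespace(recorded=[[0, 1]]), SimpleNamespace(recorded=None)]
try:
    toy_collect(layers)
except RuntimeError as err:
    print(err)  # no recorded selection for layer position(s) [1]

layers[1].recorded = [[2, 3]]
print(toy_collect(layers))  # [[[0, 1]], [[2, 3]]]
```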
Return the registered instances in construction (layer) order.
Record the top-k selection of every gate for the duration of the block.
Replay all_layers_indices (one tensor per layer) for the duration of the block.
Target selections are cleared on exit so a stale replay never leaks into a later forward pass.
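A sketch of a replay context manager built with `contextlib.contextmanager`, assuming a global list of per-layer handles. `toy_replay`, `ToyHandle`, and `TOY_REGISTRY` are hypothetical, and strings stand in for index tensors. The length check mirrors the documented ValueError; the `try`/`finally` is what guarantees targets are cleared even when the forward pass raises.

```python
import contextlib
from typing import Any, Iterator, List, Optional, Sequence


class ToyHandle:
    """Hypothetical per-layer state: a mode flag and a replay target."""

    def __init__(self) -> None:
        self.mode: Optional[str] = None
        self.target: Any = None


TOY_REGISTRY: List[ToyHandle] = [ToyHandle() for _ in range(3)]  # stand-in for the global list


@contextlib.contextmanager
def toy_replay(all_layers_indices: Sequence[Any]) -> Iterator[None]:
    # Positional alignment: selection i goes to the i-th registered layer.
    if len(all_layers_indices) != len(TOY_REGISTRY):
        raise ValueError(
            f"got {len(all_layers_indices)} selections for {len(TOY_REGISTRY)} layers"
        )
    for handle, indices in zip(TOY_REGISTRY, all_layers_indices):
        handle.target = indices
        handle.mode = "replay"
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions, so no stale target
        # leaks into a later forward pass.
        for handle in TOY_REGISTRY:
            handle.mode = None
            handle.target = None


try:
    with toy_replay(["l0", "l1", "l2"]):
        assert all(h.mode == "replay" for h in TOY_REGISTRY)
        raise RuntimeError("simulated failure mid-forward")
except RuntimeError:
    pass

print([h.target for h in TOY_REGISTRY])  # [None, None, None]
```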
Set the mode on every registered instance (None disables replay).
Distribute one selection tensor per layer to the registered instances.
Parameters:
One [num_tokens, topk] tensor per MoE layer, in the same order the layers were constructed.
Raises:
ValueError: If the number of tensors does not match the number of registered instances.
Set the selection to replay for this layer.
Bases: enum.Enum
Active mode of a RouterReplay instance.
Route indices through router_replay when routing replay is enabled.
Returns indices unchanged when router_replay is None (replay disabled)
or when no mode is active, so the gate’s default path is a true no-op.