nemo_automodel.components.training.ema
Exponential moving average of a BAGEL model’s parameters.
BAGEL’s training recipe, like diffusion-family models in general, uses an EMA
copy of the model as the final saved checkpoint: averaging the noisy SGD
trajectory typically yields a model with measurably better generation quality
than the raw training endpoint. Upstream BAGEL uses decay=0.9999 and
performs the update after every optimizer step (see
train/fsdp_utils.py::fsdp_ema_update in upstream).
Update rule (per param, in-place):
ema = decay * ema + (1 - decay) * train
This module provides the math; FSDP2-aware wiring (sharded-tensor walking, checkpoint save/load through DCP) is layered on top by the recipe.
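As a quick illustration of the update rule above, here is a minimal scalar sketch in plain Python. The function names `ema_update` and `ema_after_constant_target` are hypothetical and exist only for this example; they are not part of this module. The second helper shows how slowly an EMA with decay=0.9999 moves: starting from 0.0 and tracking a constant value of 1.0, it follows the closed form 1 - decay**steps.

```python
def ema_update(ema: float, train: float, decay: float = 0.9999) -> float:
    """One step of: ema = decay * ema + (1 - decay) * train."""
    return decay * ema + (1.0 - decay) * train


def ema_after_constant_target(steps: int, decay: float = 0.9999) -> float:
    """EMA value after `steps` updates, starting at 0.0 with train fixed at 1.0."""
    ema = 0.0
    for _ in range(steps):
        ema = ema_update(ema, 1.0, decay)
    return ema


if __name__ == "__main__":
    # After 1 / (1 - decay) = 10_000 steps the EMA has covered roughly 63%
    # of the distance to the target, matching 1 - decay ** steps.
    print(ema_after_constant_target(10_000))
    print(1.0 - 0.9999 ** 10_000)
```

With decay=0.9999 the average therefore has an effective horizon on the order of 10,000 optimizer steps, which is why the update is applied after every step rather than occasionally.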
Module Contents
Classes
API
Tracks an exponential moving average of a model’s trainable parameters.
The shadow params are stored in a dict keyed by the same parameter names
as model.named_parameters(). Walking the model on each update means we
don’t pin to a specific param-list ordering; sharded / re-wrapped models
work as long as the names are stable.
Tensors are stored on the same device and dtype as the source params.
Load shadow tensors from state. Shapes/dtypes must match.
Return the shadow tensors keyed by param name. Caller-owned copies.
Apply one EMA update step using the model’s current parameters.
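The behaviour described above (a name-keyed shadow dict, copy-returning state export, shape/dtype-checked load, and an in-place update) can be sketched roughly as follows. This is an illustrative approximation, not this module's implementation: the class name `ShadowEMA`, its constructor signature, and the error type are all assumptions made for the example.

```python
from typing import Dict

import torch
from torch import nn


class ShadowEMA:  # hypothetical name, not the class documented here
    def __init__(self, model: nn.Module, decay: float = 0.9999) -> None:
        self.decay = decay
        # Shadow tensors are clones, so they share device and dtype with the source.
        self.shadow: Dict[str, torch.Tensor] = {
            name: p.detach().clone()
            for name, p in model.named_parameters()
            if p.requires_grad
        }

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        # Walk the model by name on every call; no param-list ordering is assumed.
        for name, p in model.named_parameters():
            if name in self.shadow:
                # In place: ema = decay * ema + (1 - decay) * train
                self.shadow[name].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

    def state_dict(self) -> Dict[str, torch.Tensor]:
        # Clones, so the caller can mutate the result without touching the shadow.
        return {name: t.clone() for name, t in self.shadow.items()}

    def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
        for name, t in state.items():
            target = self.shadow[name]
            if t.shape != target.shape or t.dtype != target.dtype:
                raise ValueError(f"shape/dtype mismatch for {name!r}")
            target.copy_(t)
```

A typical loop under these assumptions would call `optimizer.step()` and then `ema.update(model)`, and would save `ema.state_dict()` alongside the regular checkpoint.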
Tracks EMA weights in a separately sharded model.
Upstream BAGEL keeps EMA as a frozen model that is FSDP-wrapped like the train model. That keeps the EMA footprint sharded instead of materializing a dense shadow copy on every rank.
Load EMA model state.
Return the EMA model state dict for DCP-backed checkpointing.
Apply one EMA update step using the model’s current sharded params.
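The model-backed variant can be approximated with two ordinary modules. The sketch below deliberately omits FSDP wrapping and DCP, so plain unsharded `nn.Module` instances stand in for the sharded train and EMA models; the helper names `make_ema_model` and `ema_model_update` are hypothetical. It assumes both models expose identical parameter names.

```python
import copy

import torch
from torch import nn


def make_ema_model(model: nn.Module) -> nn.Module:
    """Build a frozen copy of `model` to hold the EMA weights."""
    ema_model = copy.deepcopy(model)
    ema_model.requires_grad_(False)  # EMA weights are never trained directly
    return ema_model.eval()


@torch.no_grad()
def ema_model_update(ema_model: nn.Module, model: nn.Module, decay: float = 0.9999) -> None:
    """In-place EMA step over parameters matched by name."""
    train_params = dict(model.named_parameters())
    for name, ema_p in ema_model.named_parameters():
        # lerp_(end, w) computes ema + w * (train - ema),
        # which equals decay * ema + (1 - decay) * train for w = 1 - decay.
        ema_p.lerp_(train_params[name], 1.0 - decay)
```

Because the EMA weights live in a regular module here, `ema_model.state_dict()` and `ema_model.load_state_dict(...)` provide the save and load paths in this simplified setting.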