nemo_automodel.components.training.ema


Exponential moving average of a BAGEL model’s parameters.

BAGEL’s training recipe (and diffusion-family models in general) uses an EMA copy of the model as the final saved checkpoint — averaging the noisy SGD trajectory typically yields a model with measurably better generation quality than the raw training endpoint. Upstream BAGEL uses decay=0.9999 and performs the update after every optimizer step (see train/fsdp_utils.py::fsdp_ema_update in upstream).
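The decay value controls how far back the average effectively reaches. The snippet below uses standard EMA arithmetic, which upstream BAGEL does not state, to compute two numbers for decay=0.9999: the rule-of-thumb window 1 / (1 - decay), and the half-life of a single step's contribution. `ema_horizon` is a hypothetical helper name.

```python
import math


def ema_horizon(decay: float) -> tuple[float, float]:
    """Return (effective window, half-life) in optimizer steps for an EMA decay."""
    # In a long-running EMA, the train value from k steps back carries weight
    # (1 - decay) * decay**k; 1 / (1 - decay) is the usual "effective window".
    window = 1.0 / (1.0 - decay)
    # Half-life: steps after which an old contribution's weight has halved.
    half_life = math.log(0.5) / math.log(decay)
    return window, half_life


window, half_life = ema_horizon(0.9999)
print(f"window ~ {window:.0f} steps, half-life ~ {half_life:.0f} steps")
# window ~ 10000 steps, half-life ~ 6931 steps
```

At decay=0.9999 the EMA therefore mostly reflects the last several thousand optimizer steps, not just the most recent few.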

Update rule (per param, in-place):

ema = decay * ema + (1 - decay) * train
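For a single tensor, the rule maps onto two in-place PyTorch ops. This is a minimal sketch that assumes PyTorch. `ema_update_` is a hypothetical name, not this module's API. `torch.Tensor.lerp_` computes the same result, because ema + w * (train - ema) with w = 1 - decay expands to the rule above.

```python
import torch


@torch.no_grad()
def ema_update_(ema: torch.Tensor, train: torch.Tensor, decay: float) -> None:
    """In-place EMA step: ema = decay * ema + (1 - decay) * train."""
    ema.mul_(decay).add_(train, alpha=1.0 - decay)


ema = torch.zeros(3)
train = torch.ones(3)
ema_update_(ema, train, decay=0.9)
print(ema)  # tensor([0.1000, 0.1000, 0.1000])

# The same update expressed as a linear interpolation toward the train value.
ema2 = torch.zeros(3)
ema2.lerp_(train, 1.0 - 0.9)
assert torch.allclose(ema, ema2)
```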

This module provides the math; FSDP2-aware wiring (sharded-tensor walking, checkpoint save/load through DCP) is layered on top by the recipe.

Module Contents

Classes

| Name | Description |
| --- | --- |
| EMAManager | Tracks an exponential moving average of the model’s trainable parameters. |
| ShardedModelEMAManager | Tracks EMA weights in a separately sharded model. |

API

class nemo_automodel.components.training.ema.EMAManager(
model: torch.nn.Module,
decay: float = 0.9999
)

Tracks an exponential moving average of the model’s trainable parameters.

The shadow params are stored in a dict keyed by the same parameter names as model.named_parameters(). Because the model is walked by name on each update, the tracker does not depend on a specific parameter-list ordering; sharded or re-wrapped models work as long as the names are stable.

Tensors are stored on the same device and dtype as the source params.
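A name-keyed tracker along these lines can be sketched in a few lines of PyTorch. `SimpleEMA` is a hypothetical class, not the library's EMAManager. The sketch makes two assumptions. First, only parameters with `requires_grad=True` count as trainable. Second, each shadow starts as a clone of the current parameter, which keeps it on the same device and dtype.

```python
from typing import Dict

import torch
from torch import nn


class SimpleEMA:
    """Hypothetical sketch of a name-keyed EMA tracker (not the library class)."""

    def __init__(self, model: nn.Module, decay: float = 0.9999) -> None:
        self.decay = float(decay)
        # Shadow copies keyed by parameter name; clone() keeps device and dtype.
        self._shadow: Dict[str, torch.Tensor] = {
            name: p.detach().clone()
            for name, p in model.named_parameters()
            if p.requires_grad
        }

    def __contains__(self, name: str) -> bool:
        return name in self._shadow

    def __len__(self) -> int:
        return len(self._shadow)

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        # Look parameters up by name on every call, so their order is irrelevant.
        for name, p in model.named_parameters():
            if name in self._shadow:
                shadow = self._shadow[name]
                shadow.mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)


model = nn.Linear(4, 2)
ema = SimpleEMA(model, decay=0.9)
print(len(ema), "weight" in ema)  # 2 True
```

Because `update` looks parameters up by name, reordering or re-wrapping the model between calls does no harm as long as `named_parameters()` keeps yielding the same names.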

_shadow: Dict[str, Tensor] = {}
decay = float(decay)
nemo_automodel.components.training.ema.EMAManager.__contains__(
name: str
) -> bool
nemo_automodel.components.training.ema.EMAManager.__len__() -> int
nemo_automodel.components.training.ema.EMAManager.load_state_dict(
state: typing.Dict[str, torch.Tensor],
strict: bool = True
) -> None

Load shadow tensors from state. Shapes and dtypes must match the existing shadow tensors.

nemo_automodel.components.training.ema.EMAManager.state_dict() -> typing.Dict[str, torch.Tensor]

Return the shadow tensors keyed by parameter name. The returned tensors are copies owned by the caller.
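The checkpoint round trip can be sketched as two plain functions over the shadow dict. These are hypothetical helpers that follow the docstrings above: saving returns caller-owned copies, and loading requires matching shapes and dtypes. The strict key check and the exception types are assumptions about how `strict` behaves, not confirmed library behavior.

```python
from typing import Dict

import torch


def shadow_state_dict(shadow: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Clone so the caller owns the tensors; mutating them won't touch the EMA.
    return {name: t.detach().clone() for name, t in shadow.items()}


def load_shadow_state_dict(
    shadow: Dict[str, torch.Tensor],
    state: Dict[str, torch.Tensor],
    strict: bool = True,
) -> None:
    if strict:
        # Assumed semantics: strict mode requires exactly the same key set.
        missing = shadow.keys() - state.keys()
        unexpected = state.keys() - shadow.keys()
        if missing or unexpected:
            raise KeyError(f"missing={sorted(missing)} unexpected={sorted(unexpected)}")
    for name, src in state.items():
        if name not in shadow:
            continue  # only reachable when strict=False
        dst = shadow[name]
        if src.shape != dst.shape or src.dtype != dst.dtype:
            raise ValueError(
                f"{name}: expected {dst.shape}/{dst.dtype}, got {src.shape}/{src.dtype}"
            )
        dst.copy_(src)  # copy in place, keeping the shadow's existing device


shadow = {"w": torch.zeros(2)}
saved = shadow_state_dict(shadow)
saved["w"].fill_(5.0)  # edits the caller's copy only
load_shadow_state_dict(shadow, saved)
print(shadow["w"])  # tensor([5., 5.])
```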

nemo_automodel.components.training.ema.EMAManager.update(
model: torch.nn.Module
) -> None

Apply one EMA update step using the model’s current parameters.
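In a training loop, the update goes right after each optimizer step, matching the upstream schedule described at the top of this page. The sketch below uses a dict of shadow tensors as a stand-in for EMAManager. It then loads the averaged weights into a copy of the model for evaluation or saving. The stand-in, the toy model, and `decay=0.99` are illustrative assumptions.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
decay = 0.99  # smaller than 0.9999 so the effect shows up in a short run

# Hypothetical stand-in for EMAManager: shadow tensors keyed by parameter name.
shadow = {n: p.detach().clone() for n, p in model.named_parameters()}

x, y = torch.randn(64, 8), torch.randn(64, 1)
for _ in range(100):
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    # EMA update after every optimizer step.
    with torch.no_grad():
        for n, p in model.named_parameters():
            shadow[n].lerp_(p, 1.0 - decay)

# To evaluate or save the EMA weights, load them into a copy of the model.
ema_model = copy.deepcopy(model)
ema_model.load_state_dict(shadow)
```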

class nemo_automodel.components.training.ema.ShardedModelEMAManager(
ema_model: torch.nn.Module,
train_model: torch.nn.Module,
decay: float = 0.9999
)

Tracks EMA weights in a separately sharded model.

Upstream BAGEL keeps EMA as a frozen model that is FSDP-wrapped like the train model. That keeps the EMA footprint sharded instead of materializing a dense shadow copy on every rank.
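The paired-model approach can be sketched without any sharding. Deep-copy the train model, freeze the copy, pair parameters by name once, and update each EMA parameter in place. `make_param_pairs` and `ema_step` are hypothetical helpers, not the library's `_param_pairs` logic. The sketch runs on unsharded modules. Under FSDP2, the same idea assumes both models are sharded identically (for example, both wrapped with `fully_shard` on the same mesh). In that case each pair would hold matching local shards, and the elementwise update would stay local.

```python
import copy
from typing import List, Tuple

import torch
from torch import nn

Pairs = List[Tuple[torch.Tensor, torch.Tensor]]


def make_param_pairs(ema_model: nn.Module, train_model: nn.Module) -> Pairs:
    """Pair EMA and train parameters by name (hypothetical helper)."""
    train_params = dict(train_model.named_parameters())
    pairs = []
    for name, ema_p in ema_model.named_parameters():
        ema_p.requires_grad_(False)  # the EMA model is frozen, never optimized
        pairs.append((ema_p, train_params[name]))  # KeyError if names diverge
    return pairs


@torch.no_grad()
def ema_step(pairs: Pairs, decay: float) -> None:
    # ema = decay * ema + (1 - decay) * train, written as an in-place lerp.
    for ema_p, train_p in pairs:
        ema_p.lerp_(train_p, 1.0 - decay)


train_model = nn.Linear(4, 2)
ema_model = copy.deepcopy(train_model)  # same structure, same parameter names
pairs = make_param_pairs(ema_model, train_model)

with torch.no_grad():
    train_model.weight.add_(1.0)  # pretend an optimizer step moved the weights
ema_step(pairs, decay=0.9)
```

Pairing once up front is one plausible way to avoid a name lookup on every step. The trade-off is that the pairs must be rebuilt if either model's parameter objects are replaced.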

_param_pairs
_tracked_names
decay = float(decay)
nemo_automodel.components.training.ema.ShardedModelEMAManager.__contains__(
name: str
) -> bool
nemo_automodel.components.training.ema.ShardedModelEMAManager.__len__() -> int
nemo_automodel.components.training.ema.ShardedModelEMAManager.load_state_dict(
state: typing.Dict[str, torch.Tensor],
strict: bool = True
) -> None

Load EMA model state.

nemo_automodel.components.training.ema.ShardedModelEMAManager.state_dict() -> typing.Dict[str, torch.Tensor]

Return the EMA model state dict for DCP-backed checkpointing.

nemo_automodel.components.training.ema.ShardedModelEMAManager.update(
model: torch.nn.Module
) -> None

Apply one EMA update step using the model’s current sharded parameters.