# nemo_automodel.components.training.ema

Exponential moving average of a BAGEL model's parameters.

BAGEL's training recipe, like those of diffusion-family models in general,
uses an EMA copy of the model as the final saved checkpoint. Averaging the
noisy SGD trajectory typically yields a model with measurably better
generation quality than the raw training endpoint. Upstream BAGEL uses
`decay=0.9999` and performs the update after every optimizer step (see
`train/fsdp_utils.py::fsdp_ema_update` in the upstream repository).

Update rule (per param, in-place):

`ema = decay * ema + (1 - decay) * train`

This module provides the math; FSDP2-aware wiring (sharded-tensor walking,
checkpoint save/load through DCP) is layered on top by the recipe.
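
To get a feel for what `decay=0.9999` means in practice, the update rule above can be applied to plain Python floats. This is a sketch of the arithmetic only: `ema_step` and `weight_on_past_step` are hypothetical helper names, not part of this module, and the zero starting value is chosen purely to make the slow convergence visible.

```python
import math


def ema_step(ema: float, train: float, decay: float) -> float:
    """One application of the update rule to a single scalar."""
    return decay * ema + (1.0 - decay) * train


def weight_on_past_step(k: int, decay: float) -> float:
    """Weight the current EMA places on the value seen k updates ago."""
    return (1.0 - decay) * decay**k


decay = 0.9999

# Number of updates after which an old value's weight has halved.
half_life = math.log(0.5) / math.log(decay)

# The weights form a geometric series that sums to 1, so the mean "age"
# of the values contributing to the average is decay / (1 - decay) steps.
mean_age = decay / (1.0 - decay)

# Track a constant target of 1.0 from a zero start (illustrative only).
# After n steps the EMA equals 1 - decay**n, so 1000 steps is far from enough.
ema = 0.0
for _ in range(1000):
    ema = ema_step(ema, 1.0, decay)

print(f"half-life ~ {half_life:.0f} steps, mean age ~ {mean_age:.0f} steps")
print(f"EMA of constant 1.0 after 1000 steps: {ema:.4f}")
```

With this decay the average effectively spans several thousand optimizer steps, which is worth keeping in mind for short fine-tuning runs where the EMA may still be dominated by its starting value.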

## Module Contents

### Classes

| Name                                                                                       | Description                                                             |
| ------------------------------------------------------------------------------------------ | ----------------------------------------------------------------------- |
| [`EMAManager`](#nemo_automodel-components-training-ema-EMAManager)                         | Tracks an exponential moving average of `model`'s trainable parameters. |
| [`ShardedModelEMAManager`](#nemo_automodel-components-training-ema-ShardedModelEMAManager) | Tracks EMA weights in a separately sharded model.                       |

### API

```python
class nemo_automodel.components.training.ema.EMAManager(
    model: torch.nn.Module,
    decay: float = 0.9999
)
```

Tracks an exponential moving average of `model`'s trainable parameters.

The shadow params are stored in a dict keyed by the same parameter names
as `model.named_parameters()`. Walking the model on each update means we
don't pin to a specific param-list ordering; sharded / re-wrapped models
work as long as the names are stable.

Tensors are stored on the same device and dtype as the source params.
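
The name-keyed design described above can be sketched in a few lines of PyTorch. The class below, `NameKeyedEMA`, is a hypothetical stand-in written for illustration and is not the `EMAManager` implementation; it assumes that only parameters with `requires_grad=True` are tracked and that the shadow starts as a copy of the model's current values.

```python
from typing import Dict

import torch
from torch import nn


class NameKeyedEMA:
    """Hypothetical stand-in, not the EMAManager implementation."""

    def __init__(self, model: nn.Module, decay: float = 0.9999) -> None:
        self.decay = decay
        # Cloning keeps the source device and dtype but shares no storage.
        self.shadow: Dict[str, torch.Tensor] = {
            name: p.detach().clone()
            for name, p in model.named_parameters()
            if p.requires_grad
        }

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        # Params are looked up by name on every call, so ordering never matters.
        for name, p in model.named_parameters():
            if name in self.shadow:
                # In place: ema = decay * ema + (1 - decay) * train
                self.shadow[name].mul_(self.decay).add_(
                    p.detach(), alpha=1.0 - self.decay
                )

    def state_dict(self) -> Dict[str, torch.Tensor]:
        # Copies, so callers cannot mutate the live shadow tensors.
        return {name: t.clone() for name, t in self.shadow.items()}


torch.manual_seed(0)
model = nn.Linear(4, 2)
ema = NameKeyedEMA(model, decay=0.5)
before = ema.state_dict()

# Stand-in for an optimizer step: shift every parameter by 1.0.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

ema.update(model)
after = ema.state_dict()
```

With `decay=0.5` and every parameter shifted by 1.0, each shadow tensor moves exactly halfway toward the new value, which makes the rule easy to check by hand.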

```python
nemo_automodel.components.training.ema.EMAManager.__contains__(
    name: str
) -> bool
```

```python
nemo_automodel.components.training.ema.EMAManager.__len__() -> int
```

```python
nemo_automodel.components.training.ema.EMAManager.load_state_dict(
    state: typing.Dict[str, torch.Tensor],
    strict: bool = True
) -> None
```

Load shadow tensors from `state`. Shapes and dtypes must match those of the tracked shadow tensors.

```python
nemo_automodel.components.training.ema.EMAManager.state_dict() -> typing.Dict[str, torch.Tensor]
```

Return the shadow tensors keyed by parameter name. The returned tensors are copies owned by the caller.

```python
nemo_automodel.components.training.ema.EMAManager.update(
    model: torch.nn.Module
) -> None
```

Apply one EMA update step using `model`'s current parameters.

```python
class nemo_automodel.components.training.ema.ShardedModelEMAManager(
    ema_model: torch.nn.Module,
    train_model: torch.nn.Module,
    decay: float = 0.9999
)
```

Tracks EMA weights in a separately sharded model.

Upstream BAGEL keeps EMA as a frozen model that is FSDP-wrapped like the
train model. That keeps the EMA footprint sharded instead of materializing
a dense shadow copy on every rank.

```python
nemo_automodel.components.training.ema.ShardedModelEMAManager.__contains__(
    name: str
) -> bool
```

```python
nemo_automodel.components.training.ema.ShardedModelEMAManager.__len__() -> int
```

```python
nemo_automodel.components.training.ema.ShardedModelEMAManager.load_state_dict(
    state: typing.Dict[str, torch.Tensor],
    strict: bool = True
) -> None
```

Load EMA model state.

```python
nemo_automodel.components.training.ema.ShardedModelEMAManager.state_dict() -> typing.Dict[str, torch.Tensor]
```

Return the EMA model state dict for DCP-backed checkpointing.

```python
nemo_automodel.components.training.ema.ShardedModelEMAManager.update(
    model: torch.nn.Module
) -> None
```

Apply one EMA update step using `model`'s current sharded params.
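
The two-model arrangement can be illustrated on a single process, without FSDP. The sketch below is not the `ShardedModelEMAManager` implementation: `ema_update_between_models` is a hypothetical helper, the EMA model is assumed to start as a frozen deep copy of the train model, and both models are assumed to expose identical parameter names. Sharding, wrapping, and DCP checkpointing are deliberately left out.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update_between_models(
    ema_model: nn.Module, train_model: nn.Module, decay: float = 0.9999
) -> None:
    """Hypothetical helper: update ema_model's params in place from train_model."""
    train_params = dict(train_model.named_parameters())
    for name, ema_p in ema_model.named_parameters():
        # In place: ema = decay * ema + (1 - decay) * train
        ema_p.mul_(decay).add_(train_params[name].detach(), alpha=1.0 - decay)


torch.manual_seed(0)
train_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

# Assumption: the EMA model starts as a frozen copy of the train model.
ema_model = copy.deepcopy(train_model).requires_grad_(False)

optimizer = torch.optim.SGD(train_model.parameters(), lr=0.1)
for _ in range(3):
    loss = train_model(torch.randn(16, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Update after every optimizer step, as upstream BAGEL does.
    ema_update_between_models(ema_model, train_model, decay=0.9)
```

Because the update reads and writes each parameter in place, the EMA model never needs a second dense copy of the weights; how that interacts with sharded parameters depends on the FSDP2 wiring in the recipe, which this sketch does not model.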