
# nemo_automodel.components.models.bagel.state_dict_adapter

State-dict adapter for BAGEL HF checkpoints.

BAGEL-7B-MoT ships two on-disk files with complementary key namespaces:

* `ema.safetensors`: everything under the top-level module tree. This covers
  the **UND** (understanding) path, the **GEN** (`*_moe_gen`)
  Mixture-of-Transformers siblings, and the flow-matching scaffolding
  (`time_embedder`, `vae2llm`, `llm2vae`, `latent_pos_embed`).
* `ae.safetensors` — the VAE encoder/decoder weights. These live in a
  *separate* file and are loaded by the Stage 2 recipe via upstream
  `load_ae`; they are not parameters of `BagelForUnifiedMultimodal`.

Special cases:

* `vit_pos_embed.pos_embed` *is* present in the checkpoint. Upstream BAGEL
  declares it as a frozen `nn.Parameter(requires_grad=False)` (sinusoidal
  2D position embedding) — so it serializes like any other parameter. It is
  classified as UND since it feeds the understanding-side connector.
* The released BAGEL-7B-MoT checkpoint stores
  `vit_model.vision_model.embeddings.patch_embedding.weight` in the
  post-conversion linear layout `(out_channels, in_channels * P * P)`.
  AutoModel swaps the freshly constructed `Conv2d` module for a `Linear`
  before loading that checkpoint, so the tensor shape matches directly.
* `embed_tokens` / `lm_head` / the final `norm` are **UND-side tensors
  that are logically read by the GEN path** (GEN tokens use text embeddings
  and the shared LM head). No physical tensor sharing is required in the
  checkpoint — the module tree uses references, not copies.
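
The linear layout described for `patch_embedding.weight` can be illustrated with plain nested lists. This is a minimal sketch, assuming the linear layout is the row-major flattening of a `(out_channels, in_channels, P, P)` convolution kernel (the ordering a `reshape(out_channels, -1)` would produce). The helper name `conv_kernel_to_linear` is hypothetical and not part of the adapter.

```python
def conv_kernel_to_linear(kernel):
    """Flatten a nested-list kernel of shape (out, in, P, P) to (out, in * P * P).

    Assumption: row-major flattening over (in, P, P) for each output filter.
    """
    return [
        [value for channel in out_filter for row in channel for value in row]
        for out_filter in kernel
    ]


# Toy kernel with out_channels=2, in_channels=3, P=2.
# Each entry records its own (out, in, row, col) index so the ordering is visible.
out_channels, in_channels, patch = 2, 3, 2
kernel = [
    [
        [[(o, c, i, j) for j in range(patch)] for i in range(patch)]
        for c in range(in_channels)
    ]
    for o in range(out_channels)
]

linear = conv_kernel_to_linear(kernel)
print(len(linear), len(linear[0]))  # 2 12
```

Each output row holds `in_channels * P * P` values, so a checkpoint tensor stored this way fits a `Linear` weight without further reshaping.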

Stage 1 loads the UND subset only (`UND_PATTERNS`). Stage 2 additionally
loads the GEN subset (`GEN_PATTERNS`). VAE keys are recognized only so that
accidentally merged checkpoints can be reported cleanly instead of being
treated as unknown model weights.

## Module Contents

### Classes

| Name                                                                                                        | Description                                  |
| ----------------------------------------------------------------------------------------------------------- | -------------------------------------------- |
| [`BagelStateDictAdapter`](#nemo_automodel-components-models-bagel-state_dict_adapter-BagelStateDictAdapter) | HF \<-> NeMo state-dict converter for BAGEL. |

### Functions

| Name                                                                                                                              | Description                                                              |
| --------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------ |
| [`_compile`](#nemo_automodel-components-models-bagel-state_dict_adapter-_compile)                                                 | -                                                                        |
| [`_matches_any`](#nemo_automodel-components-models-bagel-state_dict_adapter-_matches_any)                                         | -                                                                        |
| [`_normalize_stage`](#nemo_automodel-components-models-bagel-state_dict_adapter-_normalize_stage)                                 | Normalize `stage` to one of `"stage1"` or `"stage2"`.                    |
| [`_partition`](#nemo_automodel-components-models-bagel-state_dict_adapter-_partition)                                             | Partition a flat checkpoint dict into UND / GEN / VAE / unknown buckets. |
| [`load_bagel_checkpoint_state_dict`](#nemo_automodel-components-models-bagel-state_dict_adapter-load_bagel_checkpoint_state_dict) | Load a BAGEL HF checkpoint directory into a NeMo-layout state dict.      |

### Data

[`GEN_PATTERNS`](#nemo_automodel-components-models-bagel-state_dict_adapter-GEN_PATTERNS)

[`SHARED_PATTERNS`](#nemo_automodel-components-models-bagel-state_dict_adapter-SHARED_PATTERNS)

[`UND_PATTERNS`](#nemo_automodel-components-models-bagel-state_dict_adapter-UND_PATTERNS)

[`VAE_PATTERNS`](#nemo_automodel-components-models-bagel-state_dict_adapter-VAE_PATTERNS)

[`_GEN_RES`](#nemo_automodel-components-models-bagel-state_dict_adapter-_GEN_RES)

[`_UND_RES`](#nemo_automodel-components-models-bagel-state_dict_adapter-_UND_RES)

[`_VAE_RES`](#nemo_automodel-components-models-bagel-state_dict_adapter-_VAE_RES)

[`logger`](#nemo_automodel-components-models-bagel-state_dict_adapter-logger)

### API

```python
class nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter(
    config: typing.Any = None,
    stage: typing.Any = 'stage1'
)
```

**Bases:** [StateDictAdapter](/nemo-automodel/nemo_automodel/components/checkpoint/state_dict_adapter#nemo_automodel-components-checkpoint-state_dict_adapter-StateDictAdapter)

HF \<-> NeMo state-dict converter for BAGEL.

Stage 1 returns only the UND subset from `from_hf`. Stage 2 additionally
returns the GEN subset. The VAE remains outside this module and is loaded
separately by the training recipe.

Because `BagelForUnifiedMultimodal` wraps the upstream BAGEL module
tree under `self.model`, HF checkpoint keys are unrooted
(`language_model...`) while native AutoModel keys are rooted
(`model.language_model...`). The adapter filters upstream UND/GEN keys
and handles this root mapping.

**Parameters:**

* `config`: `BagelConfig` or `None`. Currently used only for log context;
  no shape sanity checks are performed yet.
* `stage`: Default stage used when `from_hf` is called without an explicit
  `stage` kwarg. Accepts `"stage1"` / `"stage2"` or `1` / `2`.
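
The root mapping between unrooted HF keys and rooted native keys can be sketched as a pair of prefix helpers. This is an illustration only: the `model.` prefix is inferred from the `model.language_model...` example above, and the function names are hypothetical stand-ins rather than the adapter's private methods.

```python
NEMO_ROOT = "model."  # assumed prefix, taken from the `model.language_model...` example


def hf_to_nemo_key(hf_key: str) -> str:
    """Root an HF checkpoint key under the wrapper's `model` attribute."""
    return NEMO_ROOT + hf_key


def nemo_to_hf_key(nemo_key: str) -> str:
    """Strip the wrapper root, leaving keys without the prefix unchanged."""
    return nemo_key.removeprefix(NEMO_ROOT)  # requires Python 3.9+


hf_key = "language_model.model.embed_tokens.weight"
nemo_key = hf_to_nemo_key(hf_key)
print(nemo_key)  # model.language_model.model.embed_tokens.weight
```

Only the leading root is removed on the way back, so the inner `language_model.model...` segment of the key is preserved.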

```python
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter._hf_to_nemo_key(
    hf_key: str
) -> str
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter._nemo_to_hf_key(
    nemo_key: str
) -> str
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter._strip_nemo_root(
    key: str
) -> str
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter.convert_single_tensor_to_hf(
    fqn: str,
    tensor: 'torch.Tensor',
    kwargs: typing.Any = {}
) -> list[tuple[str, 'torch.Tensor']]
```

Return `[(hf_fqn, tensor)]` for a single NeMo tensor.

Identity for BAGEL checkpoint keys.

```python
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter.from_hf(
    hf_state_dict: dict[str, 'torch.Tensor'],
    device_mesh: typing.Optional['DeviceMesh'] = None,
    stage: typing.Optional[typing.Any] = None,
    strict: bool = True,
    kwargs: typing.Any = {}
) -> dict[str, 'torch.Tensor']
```

Convert an HF-layout BAGEL state dict to the NeMo module-tree layout.

**Parameters:**

* `hf_state_dict`: Flat dict loaded from `ema.safetensors`. If a caller
  accidentally passes merged VAE keys, they are recognized and excluded
  from the model state dict.
* `device_mesh`: Unused for BAGEL (no expert parallelism yet); kept for
  base-class signature compatibility.
* `stage`: `"stage1"` keeps only UND keys. `"stage2"` keeps UND and GEN
  keys. Defaults to `self.stage`.
* `strict`: If `True` (default), raise `KeyError` on any key that matches
  no known pattern. If `False`, log and drop such keys.

**Returns:** `dict[str, 'torch.Tensor']`

Filtered state dict in NeMo layout.

**Raises:**

* `KeyError`: When `strict=True` and one or more input keys match
  no UND/GEN/VAE pattern.
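
The stage filter and the `strict` behaviour described for `from_hf` can be sketched over pre-partitioned buckets. This is a simplified illustration, not the adapter's code: the bucket names (`"und"`, `"gen"`, `"vae"`, `"unknown"`), the function name `select_for_stage`, and the log messages are assumptions, and the root mapping of keys is left out.

```python
import logging

logger = logging.getLogger(__name__)


def select_for_stage(buckets, stage="stage1", strict=True):
    """Pick the tensors to load from hypothetical pre-partitioned buckets."""
    unknown = sorted(buckets.get("unknown", {}))
    if unknown:
        if strict:
            raise KeyError(
                f"{len(unknown)} key(s) match no UND/GEN/VAE pattern: {unknown}"
            )
        logger.warning("Dropping %d unmatched key(s): %s", len(unknown), unknown)

    vae = buckets.get("vae", {})
    if vae:
        # Recognized but never returned: the VAE is loaded separately.
        logger.warning("Ignoring %d VAE key(s) in the model checkpoint", len(vae))

    selected = dict(buckets.get("und", {}))
    if stage == "stage2":
        selected.update(buckets.get("gen", {}))
    return selected


# Strings stand in for tensors; the VAE key name is made up for the demo.
demo = {
    "und": {"language_model.model.embed_tokens.weight": "und-tensor"},
    "gen": {"language_model.model.layers.0.self_attn.q_proj_moe_gen.weight": "gen-tensor"},
    "vae": {"encoder.conv_in.weight": "vae-tensor"},
    "unknown": {},
}
print(len(select_for_stage(demo, "stage1")))  # 1
print(len(select_for_stage(demo, "stage2")))  # 2
```

Raising before any filtering happens means a strict caller never receives a partially converted state dict.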

```python
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter.to_hf(
    state_dict: dict[str, 'torch.Tensor'],
    kwargs: typing.Any = {}
) -> dict[str, 'torch.Tensor']
```

Convert a NeMo-layout state dict back to the HF BAGEL layout.

The VAE is not part of this module tree and should be saved/loaded
separately.

```python
nemo_automodel.components.models.bagel.state_dict_adapter._compile(
    patterns: list[str]
) -> list[re.Pattern]
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter._matches_any(
    key: str,
    patterns: list[re.Pattern]
) -> bool
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter._normalize_stage(
    stage: typing.Any
) -> str
```

Normalize `stage` to one of `"stage1"` or `"stage2"`.

Accepts either strings (`"stage1"` / `"stage2"`) or integers
(`1` / `2`) for backward compatibility with earlier callers.
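
A normalization of this kind might look like the sketch below. The accepted inputs follow the description above; the error type raised for anything else and the rejection of booleans are assumptions, since the page does not say how invalid values are handled.

```python
def normalize_stage(stage):
    """Map "stage1"/"stage2" or 1/2 to the canonical string form."""
    # bool is a subclass of int in Python, so exclude it explicitly (assumption).
    if isinstance(stage, int) and not isinstance(stage, bool) and stage in (1, 2):
        return f"stage{stage}"
    if isinstance(stage, str) and stage in ("stage1", "stage2"):
        return stage
    # Assumption: invalid values raise ValueError.
    raise ValueError(f"Unsupported stage: {stage!r}")


print(normalize_stage(1))         # stage1
print(normalize_stage("stage2"))  # stage2
```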

```python
nemo_automodel.components.models.bagel.state_dict_adapter._partition(
    state_dict: dict[str, typing.Any]
) -> dict[str, dict[str, typing.Any]]
```

Partition a flat checkpoint dict into UND / GEN / VAE / unknown buckets.
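
A regex-based partition along these lines is sketched below. The pattern lists are deliberately tiny: the UND and VAE entries are copied from the values shown on this page, while the GEN entry is an illustrative pattern because the real list is truncated here. The bucket names and the GEN-before-UND precedence are assumptions.

```python
import re

# Demo pattern lists, not the module's full UND_PATTERNS / GEN_PATTERNS.
UND_DEMO = [r"^language_model\.model\.embed_tokens\.weight$"]
GEN_DEMO = [r"^language_model\.model\.layers\.\d+\.self_attn\.q_proj_moe_gen\.weight$"]
VAE_DEMO = [r"^encoder\.", r"^decoder\."]


def compile_patterns(patterns):
    return [re.compile(p) for p in patterns]


def matches_any(key, compiled):
    return any(p.search(key) for p in compiled)


UND_RES = compile_patterns(UND_DEMO)
GEN_RES = compile_patterns(GEN_DEMO)
VAE_RES = compile_patterns(VAE_DEMO)


def partition(state_dict):
    """Split a flat checkpoint dict into und / gen / vae / unknown buckets."""
    buckets = {"und": {}, "gen": {}, "vae": {}, "unknown": {}}
    for key, value in state_dict.items():
        if matches_any(key, GEN_RES):  # assumed precedence: GEN first
            name = "gen"
        elif matches_any(key, UND_RES):
            name = "und"
        elif matches_any(key, VAE_RES):
            name = "vae"
        else:
            name = "unknown"
        buckets[name][key] = value
    return buckets


checkpoint = {
    "language_model.model.embed_tokens.weight": 0,
    "language_model.model.layers.3.self_attn.q_proj_moe_gen.weight": 1,
    "decoder.conv_out.bias": 2,  # made-up VAE key for the demo
    "mystery.weight": 3,
}
buckets = partition(checkpoint)
print({name: len(keys) for name, keys in buckets.items()})
# {'und': 1, 'gen': 1, 'vae': 1, 'unknown': 1}
```

Keeping an explicit `unknown` bucket is what lets a caller decide later whether unmatched keys are an error or can be dropped.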

```python
nemo_automodel.components.models.bagel.state_dict_adapter.load_bagel_checkpoint_state_dict(
    checkpoint_dir: str | pathlib.Path,
    stage: typing.Any = 'stage1',
    strict: bool = True,
    config: typing.Any = None
) -> dict[str, 'torch.Tensor']
```

Load a BAGEL HF checkpoint directory into a NeMo-layout state dict.

Reads `ema.safetensors` from `checkpoint_dir` and passes the result
through `BagelStateDictAdapter`. Stage 2 VAE weights live in
`ae.safetensors` but are loaded separately by the recipe because the VAE
is not an `nn.Module` child of `BagelForUnifiedMultimodal`.

**Parameters:**

* `checkpoint_dir`: Path to a directory containing `ema.safetensors`.
* `stage`: `"stage1"` or `"stage2"`.
* `strict`: Forwarded to `from_hf`; if `True`, raise on unmatched keys.
* `config`: Optional `BagelConfig` for the adapter's log context.

**Returns:** `dict[str, 'torch.Tensor']`

A flat `{key: Tensor}` dict in NeMo layout.
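
A possible way to call this loader for Stage 2 is sketched below, using only the signature documented above. The wrapper name `load_stage2_weights` and its pre-flight file check are illustrative additions, not part of the module. The import is deferred so the check can run even where NeMo AutoModel is not installed.

```python
from pathlib import Path


def load_stage2_weights(checkpoint_dir):
    """Load UND + GEN tensors from a BAGEL checkpoint directory (sketch)."""
    ckpt_dir = Path(checkpoint_dir)
    weights_file = ckpt_dir / "ema.safetensors"
    if not weights_file.is_file():
        # Fail early with a clear message before touching the loader.
        raise FileNotFoundError(f"Expected {weights_file} to exist")

    # Deferred import: only needed once the checkpoint file is known to exist.
    from nemo_automodel.components.models.bagel.state_dict_adapter import (
        load_bagel_checkpoint_state_dict,
    )

    return load_bagel_checkpoint_state_dict(ckpt_dir, stage="stage2", strict=True)
```

The VAE weights in `ae.safetensors` are not covered by this call and still need to be loaded by the recipe.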

```python
nemo_automodel.components.models.bagel.state_dict_adapter.GEN_PATTERNS = ['^language_model\\.model\\.layers\\.\\d+\\.self_attn\\.q_proj_moe_gen\\.(weight...
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter.SHARED_PATTERNS: list[str] = []
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter.UND_PATTERNS = ['^language_model\\.model\\.embed_tokens\\.weight$', '^language_model\\.model\\....
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter.VAE_PATTERNS = ['^encoder\\.', '^decoder\\.']
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter._GEN_RES = _compile(GEN_PATTERNS)
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter._UND_RES = _compile(UND_PATTERNS)
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter._VAE_RES = _compile(VAE_PATTERNS)
```

```python
nemo_automodel.components.models.bagel.state_dict_adapter.logger = logging.getLogger(__name__)
```