nemo_automodel.components.models.bagel.state_dict_adapter


State-dict adapter for BAGEL HF checkpoints.

BAGEL-7B-MoT ships two on-disk files with complementary key namespaces:

  • ema.safetensors: everything under the top-level module tree. This covers the UND (understanding) path, the GEN (*_moe_gen) Mixture-of-Transformers siblings, and the flow-matching scaffolding (time_embedder, vae2llm, llm2vae, latent_pos_embed).
  • ae.safetensors: the VAE encoder/decoder weights. These live in a separate file and are loaded by the Stage 2 recipe via upstream load_ae. They are not parameters of BagelForUnifiedMultimodal.

Special cases:

  • vit_pos_embed.pos_embed is present in the checkpoint. Upstream BAGEL declares this sinusoidal 2D position embedding as a frozen nn.Parameter(requires_grad=False), so it serializes like any other parameter. It is classified as UND because it feeds the understanding-side connector.
  • The released BAGEL-7B-MoT checkpoint stores vit_model.vision_model.embeddings.patch_embedding.weight in the post-conversion Linear layout (out_channels, in_channels * P * P). Before loading that checkpoint, AM replaces the freshly constructed Conv2d module with a Linear module so that the tensor shape matches directly.
  • embed_tokens, lm_head, and the final norm are UND-side tensors that the GEN path also reads (GEN tokens use the text embeddings and the shared LM head). The checkpoint needs no physical tensor sharing, because the module tree holds references to these modules, not copies.
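The Conv2d-to-Linear swap works because a patch-embedding convolution whose kernel size and stride both equal the patch size P computes exactly one dot product per non-overlapping patch. The pure-Python sketch below avoids torch, and its helper names are hypothetical. It checks that flattening a (out_channels, in_channels, P, P) kernel to (out_channels, in_channels * P * P) produces the same output as the convolution. This holds only if each patch is flattened in the same (C, P, P) order, which is an assumption about how the Linear variant extracts patches. In torch, the equivalent flattening would be conv.weight.reshape(out_channels, -1).

```python
# Pure-Python sketch (no torch): Conv2d with kernel_size == stride == P on one
# P x P patch equals a Linear layer over the flattened patch, provided the kernel
# and the patch are both flattened in the same (in_channels, P, P) row-major order.


def conv_patch(weight, bias, patch):
    """Conv2d output for a single patch.

    weight: nested list [out][in][P][P]; bias: [out]; patch: [in][P][P].
    """
    out = []
    for o, w_o in enumerate(weight):
        acc = bias[o]
        for c, w_c in enumerate(w_o):
            for i, w_row in enumerate(w_c):
                for j, w in enumerate(w_row):
                    acc += w * patch[c][i][j]
        out.append(acc)
    return out


def flatten_chw(t):
    """Flatten a [C][P][P] nested list in (C, P, P) row-major order."""
    return [x for plane in t for row in plane for x in row]


def linear(weight_2d, bias, x):
    """Linear layer: y[o] = bias[o] + sum_k weight_2d[o][k] * x[k]."""
    return [b + sum(w * v for w, v in zip(row, x)) for row, b in zip(weight_2d, bias)]


OUT, IN, P = 2, 3, 2
# Deterministic integer weights so the comparison is exact.
weight = [[[[o * 100 + c * 10 + i * P + j for j in range(P)] for i in range(P)]
           for c in range(IN)] for o in range(OUT)]
bias = [1, -1]
patch = [[[c * 4 + i * P + j for j in range(P)] for i in range(P)] for c in range(IN)]

# Post-conversion layout stored in the checkpoint: (out_channels, in_channels * P * P).
weight_linear = [flatten_chw(w_o) for w_o in weight]
assert len(weight_linear) == OUT and len(weight_linear[0]) == IN * P * P
assert conv_patch(weight, bias, patch) == linear(weight_linear, bias, flatten_chw(patch))
```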

Stage 1 loads the UND subset only (UND_PATTERNS). Stage 2 additionally loads the GEN subset (GEN_PATTERNS). VAE keys are recognized only so that accidentally merged checkpoints can be reported cleanly instead of being treated as unknown model weights.
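The pattern-based split described above can be sketched with the standard re module. This is a minimal illustration, not the adapter's code. The UND and GEN pattern lists are short hypothetical subsets, because the real UND_PATTERNS and GEN_PATTERNS are longer, and grouping the flow-matching scaffolding under GEN is an assumption. Only VAE mirrors the documented VAE_PATTERNS exactly.

```python
import re

# Illustrative subsets only (assumption): the real UND_PATTERNS / GEN_PATTERNS
# lists are longer. Grouping the flow-matching scaffolding under GEN is assumed.
UND = [
    r"^language_model\.model\.embed_tokens\.weight$",
    r"^language_model\.model\.layers\.\d+\.self_attn\.q_proj\.(weight|bias)$",
    r"^vit_pos_embed\.pos_embed$",
]
GEN = [
    r"^language_model\.model\.layers\.\d+\.self_attn\.q_proj_moe_gen\.(weight|bias)$",
    r"^(time_embedder|vae2llm|llm2vae|latent_pos_embed)\.",
]
VAE = [r"^encoder\.", r"^decoder\."]  # same as the documented VAE_PATTERNS

# Compile once at import time (dict order und -> gen -> vae is the match order).
GROUPS = {name: [re.compile(p) for p in pats]
          for name, pats in (("und", UND), ("gen", GEN), ("vae", VAE))}


def partition(state_dict):
    """Split a flat {key: tensor} dict into und / gen / vae / unknown buckets."""
    buckets = {"und": {}, "gen": {}, "vae": {}, "unknown": {}}
    for key, value in state_dict.items():
        for name, regexes in GROUPS.items():
            if any(r.search(key) for r in regexes):
                buckets[name][key] = value
                break
        else:  # no group matched this key
            buckets["unknown"][key] = value
    return buckets
```

With this split, Stage 1 would keep only buckets["und"], while Stage 2 would merge buckets["und"] and buckets["gen"]. The trailing "\." in the UND q_proj pattern is what keeps q_proj_moe_gen keys out of the UND bucket.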

Module Contents

Classes

  • BagelStateDictAdapter: HF <-> NeMo state-dict converter for BAGEL.

Functions

  • _compile
  • _matches_any
  • _normalize_stage: Normalize stage to one of "stage1" or "stage2".
  • _partition: Partition a flat checkpoint dict into UND / GEN / VAE / unknown buckets.
  • load_bagel_checkpoint_state_dict: Load a BAGEL HF checkpoint directory into a NeMo-layout state dict.

Data

GEN_PATTERNS

SHARED_PATTERNS

UND_PATTERNS

VAE_PATTERNS

_GEN_RES

_UND_RES

_VAE_RES

logger

API

class nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter(
config: typing.Any = None,
stage: typing.Any = 'stage1'
)

Bases: StateDictAdapter

HF <-> NeMo state-dict converter for BAGEL.

Stage 1 returns only the UND subset from from_hf. Stage 2 additionally returns the GEN subset. The VAE remains outside this module and is loaded separately by the training recipe.

Because BagelForUnifiedMultimodal wraps the upstream BAGEL module tree under self.model, HF checkpoint keys are unrooted (language_model...) while native AM keys are rooted (model.language_model...). The adapter filters upstream UND/GEN keys and handles this root mapping.
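The root mapping amounts to adding or stripping a "model." prefix. The sketch below shows one way to do it, using hypothetical helper names. It does not claim to reproduce the adapter's _hf_to_nemo_key / _nemo_to_hf_key. It assumes the root is exactly "model.", as in the rooted example above.

```python
ROOT = "model."  # assumed NeMo root, per the rooted-key example above


def hf_to_nemo_key(hf_key: str) -> str:
    """Add the module-tree root: language_model.x -> model.language_model.x."""
    return ROOT + hf_key


def nemo_to_hf_key(nemo_key: str) -> str:
    """Strip the module-tree root if present; leave unrooted keys unchanged."""
    return nemo_key[len(ROOT):] if nemo_key.startswith(ROOT) else nemo_key
```

Stripping only when the prefix is present makes the reverse mapping tolerant of keys that are already in HF layout. This is safe here because BAGEL's top-level HF keys (language_model, vit_model, and so on) do not themselves begin with "model.".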

Parameters:

config
Any, defaults to None

BagelConfig or None. Currently used only for log context; no shape sanity checks are performed yet.

stage
Any, defaults to 'stage1'

Default stage used when from_hf is called without an explicit stage kwarg. Accepts "stage1" / "stage2" or 1 / 2.

stage = _normalize_stage(stage)
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter._hf_to_nemo_key(
hf_key: str
) -> str
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter._nemo_to_hf_key(
nemo_key: str
) -> str
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter._strip_nemo_root(
key: str
) -> str
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter.convert_single_tensor_to_hf(
fqn: str,
tensor: 'torch.Tensor',
kwargs: typing.Any = {}
) -> list[tuple[str, 'torch.Tensor']]

Return [(hf_fqn, tensor)] for a single NeMo tensor.

Identity for BAGEL checkpoint keys.

nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter.from_hf(
hf_state_dict: dict[str, 'torch.Tensor'],
device_mesh: typing.Optional['DeviceMesh'] = None,
stage: typing.Optional[typing.Any] = None,
strict: bool = True,
kwargs: typing.Any = {}
) -> dict[str, 'torch.Tensor']

Convert an HF-layout BAGEL state dict to the NeMo module-tree layout.

Parameters:

hf_state_dict
dict[str, 'torch.Tensor']

Flat dict loaded from ema.safetensors. If a caller accidentally passes merged VAE keys, they are recognized and excluded from the model state dict.

device_mesh
Optional['DeviceMesh'], defaults to None

Unused for BAGEL (no expert parallelism yet); kept for base-class signature compatibility.

stage
Optional[Any], defaults to None

"stage1" keeps only UND keys. "stage2" keeps UND and GEN keys. Defaults to self.stage.

strict
bool, defaults to True

If True (default), raise KeyError on any key that matches no known pattern. If False, log and drop.

Returns: dict[str, 'torch.Tensor']

Filtered state dict in NeMo layout, ready to feed into

Raises:

  • KeyError: When strict=True and one or more input keys match no UND/GEN/VAE pattern.
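The stage and strict behavior of from_hf can be sketched as a selection step over already-partitioned keys. This is a hedged illustration. The function name, bucket names, and log messages are hypothetical. The "model." root prefixing is omitted, and stage is assumed to be normalized already.

```python
import logging

logger = logging.getLogger("bagel_adapter_sketch")

# Which partition buckets each stage keeps (per the stage descriptions above).
STAGE_BUCKETS = {"stage1": ("und",), "stage2": ("und", "gen")}


def select_for_stage(buckets, stage="stage1", strict=True):
    """Merge the buckets a stage needs; VAE keys are excluded, unknown keys raise or drop.

    buckets: {"und": {...}, "gen": {...}, "vae": {...}, "unknown": {...}}
    stage: already-normalized "stage1" or "stage2".
    """
    unknown = sorted(buckets.get("unknown", {}))
    if unknown:
        if strict:
            raise KeyError(f"keys match no UND/GEN/VAE pattern: {unknown}")
        logger.warning("dropping %d unrecognized keys: %s", len(unknown), unknown)
    vae = buckets.get("vae", {})
    if vae:
        # Merged VAE keys are recognized but never returned as model weights.
        logger.warning("excluding %d VAE keys; the VAE is loaded separately", len(vae))
    selected = {}
    for name in STAGE_BUCKETS[stage]:
        selected.update(buckets.get(name, {}))
    return selected
```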
nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter.to_hf(
state_dict: dict[str, 'torch.Tensor'],
kwargs: typing.Any = {}
) -> dict[str, 'torch.Tensor']

Convert a NeMo-layout state dict back to the HF BAGEL layout.

The VAE is not part of this module tree and should be saved/loaded separately.

nemo_automodel.components.models.bagel.state_dict_adapter._compile(
patterns: list[str]
) -> list[re.Pattern]
nemo_automodel.components.models.bagel.state_dict_adapter._matches_any(
key: str,
patterns: list[re.Pattern]
) -> bool
nemo_automodel.components.models.bagel.state_dict_adapter._normalize_stage(
stage: typing.Any
) -> str

Normalize stage to one of "stage1" or "stage2".

Accepts either strings ("stage1" / "stage2") or integers (1 / 2) for backward compatibility with earlier callers.
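A minimal sketch of this normalization follows. The helper name is hypothetical. Raising ValueError for unsupported inputs is an assumption, since the documentation does not say how invalid stages are rejected.

```python
_STAGE_ALIASES = {"stage1": "stage1", "stage2": "stage2", 1: "stage1", 2: "stage2"}


def normalize_stage(stage) -> str:
    """Map "stage1" / "stage2" / 1 / 2 onto "stage1" or "stage2"."""
    # bool is a subclass of int (True == 1), so reject it explicitly;
    # values that are neither str nor int (e.g. lists) are rejected up front.
    if isinstance(stage, bool) or not isinstance(stage, (str, int)):
        raise ValueError(f"unsupported stage: {stage!r}")
    try:
        return _STAGE_ALIASES[stage]
    except KeyError:
        raise ValueError(f"unsupported stage: {stage!r}") from None
```

The explicit bool check matters because True would otherwise silently map to "stage1" through the integer alias.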

nemo_automodel.components.models.bagel.state_dict_adapter._partition(
state_dict: dict[str, typing.Any]
) -> dict[str, dict[str, typing.Any]]

Partition a flat checkpoint dict into UND / GEN / VAE / unknown buckets.

nemo_automodel.components.models.bagel.state_dict_adapter.load_bagel_checkpoint_state_dict(
checkpoint_dir: str | pathlib.Path,
stage: typing.Any = 'stage1',
strict: bool = True,
config: typing.Any = None
) -> dict[str, 'torch.Tensor']

Load a BAGEL HF checkpoint directory into a NeMo-layout state dict.

Reads ema.safetensors from checkpoint_dir and passes the result through BagelStateDictAdapter. Stage 2 VAE weights live in ae.safetensors but are loaded separately by the recipe because the VAE is not an nn.Module child of BagelForUnifiedMultimodal.
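The file split can be illustrated with pathlib alone. The hypothetical helper below locates ema.safetensors, the only file this function reads. It also reports ae.safetensors separately, so a caller could hand that path to its own VAE loader. It does not read tensors, and raising FileNotFoundError for a missing ema.safetensors is an assumption about error handling.

```python
from pathlib import Path


def resolve_bagel_files(checkpoint_dir):
    """Return (ema_path, ae_path_or_None) for a BAGEL checkpoint directory."""
    root = Path(checkpoint_dir)
    ema = root / "ema.safetensors"
    if not ema.is_file():
        raise FileNotFoundError(f"expected {ema}")
    # ae.safetensors holds the VAE, which is not part of the model module tree.
    ae = root / "ae.safetensors"
    return ema, (ae if ae.is_file() else None)
```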

Parameters:

checkpoint_dir
str | pathlib.Path

Path to a directory containing ema.safetensors.

stage
Any, defaults to 'stage1'

"stage1" or "stage2".

strict
bool, defaults to True

Forwarded to from_hf; if True, raise KeyError on unmatched keys.

config
Any, defaults to None

Optional BagelConfig for the adapter’s log context.

Returns: dict[str, 'torch.Tensor']

A flat {key: Tensor} dict in NeMo layout, ready for

nemo_automodel.components.models.bagel.state_dict_adapter.GEN_PATTERNS = ['^language_model\\.model\\.layers\\.\\d+\\.self_attn\\.q_proj_moe_gen\\.(weight...
nemo_automodel.components.models.bagel.state_dict_adapter.SHARED_PATTERNS: list[str] = []
nemo_automodel.components.models.bagel.state_dict_adapter.UND_PATTERNS = ['^language_model\\.model\\.embed_tokens\\.weight$', '^language_model\\.model\\....
nemo_automodel.components.models.bagel.state_dict_adapter.VAE_PATTERNS = ['^encoder\\.', '^decoder\\.']
nemo_automodel.components.models.bagel.state_dict_adapter._GEN_RES = _compile(GEN_PATTERNS)
nemo_automodel.components.models.bagel.state_dict_adapter._UND_RES = _compile(UND_PATTERNS)
nemo_automodel.components.models.bagel.state_dict_adapter._VAE_RES = _compile(VAE_PATTERNS)
nemo_automodel.components.models.bagel.state_dict_adapter.logger = logging.getLogger(__name__)