nemo_automodel.components.models.bagel.state_dict_adapter
State-dict adapter for BAGEL HF checkpoints.
BAGEL-7B-MoT ships two on-disk files with complementary key namespaces:
- ema.safetensors: everything under the top-level module tree, both the UND (understanding) path and the GEN (*_moe_gen) Mixture-of-Transformers siblings, plus the flow-matching scaffolding (time_embedder, vae2llm, llm2vae, latent_pos_embed).
- ae.safetensors: the VAE encoder/decoder weights. These live in a separate file and are loaded by the Stage 2 recipe via upstream load_ae; they are not parameters of BagelForUnifiedMultimodal.
Special cases:
- vit_pos_embed.pos_embed is present in the checkpoint. Upstream BAGEL declares it as a frozen nn.Parameter(requires_grad=False) (sinusoidal 2D position embedding), so it serializes like any other parameter. It is classified as UND since it feeds the understanding-side connector.
- The released BAGEL-7B-MoT checkpoint stores vit_model.vision_model.embeddings.patch_embedding.weight in the post-conversion linear layout (out_channels, in_channels * P * P). AM swaps the fresh Conv2d module to Linear before loading that checkpoint so the tensor shape matches directly.
- embed_tokens / lm_head / the final norm are UND-side tensors that are logically read by the GEN path (GEN tokens use text embeddings and the shared LM head). No physical tensor sharing is required in the checkpoint; the module tree uses references, not copies.
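The linear layout in the patch-embedding case can be illustrated without any tensor library. The sketch below is a toy, not BAGEL code: it assumes a patch embedding whose kernel size equals its stride (P), assumes channel-major flattening of each patch, and uses nested lists with made-up sizes. Under those assumptions, flattening a Conv2d-style weight of shape (out_channels, in_channels, P, P) to (out_channels, in_channels * P * P) gives a Linear weight that produces the same output for a patch.

```python
from itertools import product

P = 2                  # patch size (tiny, for illustration)
IN_CH, OUT_CH = 1, 2   # illustrative channel counts

# Conv2d-style weight indexed as [out][in][kh][kw].
conv_w = [
    [[[(o + 1) * (kh * P + kw + 1) for kw in range(P)] for kh in range(P)]
     for _ in range(IN_CH)]
    for o in range(OUT_CH)
]


def flatten_weight(w):
    """Reshape [out][in][P][P] to [out][in*P*P] in row-major order."""
    return [[v for chan in filt for row in chan for v in row] for filt in w]


def conv_on_patch(w, patch):
    """Conv2d with kernel == stride == P applied to a single P x P patch."""
    return [
        sum(w[o][c][i][j] * patch[c][i][j]
            for c, i, j in product(range(IN_CH), range(P), range(P)))
        for o in range(OUT_CH)
    ]


def linear_on_patch(w_flat, patch):
    """Linear layer applied to the same patch, flattened channel-major."""
    x = [v for chan in patch for row in chan for v in row]
    return [sum(a * b for a, b in zip(row, x)) for row in w_flat]


patch = [[[1, 2], [3, 4]]]          # indexed as [in][P][P]
linear_w = flatten_weight(conv_w)   # shape (OUT_CH, IN_CH * P * P)
```

Both `conv_on_patch(conv_w, patch)` and `linear_on_patch(linear_w, patch)` give the same per-channel values, which is why a checkpoint stored in the linear layout can be loaded directly once the module is a Linear.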
Stage 1 loads the UND subset only (UND_PATTERNS). Stage 2 additionally
loads the GEN subset (GEN_PATTERNS). VAE keys are recognized only so that
accidentally merged checkpoints can be reported cleanly instead of being
treated as unknown model weights.
Module Contents
Classes
Functions
Data
API
Bases: StateDictAdapter
HF <-> NeMo state-dict converter for BAGEL.
Stage 1 returns only the UND subset from from_hf. Stage 2 additionally
returns the GEN subset. The VAE remains outside this module and is loaded
separately by the training recipe.
Because BagelForUnifiedMultimodal wraps the upstream BAGEL module
tree under self.model, HF checkpoint keys are unrooted
(language_model...) while native AM keys are rooted
(model.language_model...). The adapter filters upstream UND/GEN keys
and handles this root mapping.
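A minimal sketch of the root mapping described above, assuming the only difference between the two layouts is the leading `model.` prefix. The helper names `add_root` and `strip_root` are hypothetical and are not part of the adapter's API.

```python
ROOT = "model."  # upstream tree lives under self.model in the wrapper


def add_root(hf_key: str) -> str:
    """Map an unrooted HF key to the rooted native key."""
    return ROOT + hf_key


def strip_root(native_key: str) -> str:
    """Map a rooted native key back to the unrooted HF key."""
    if not native_key.startswith(ROOT):
        raise KeyError(f"expected key rooted at {ROOT!r}, got {native_key!r}")
    return native_key[len(ROOT):]


# Illustrative key; the exact sub-path is an assumption.
hf_key = "language_model.model.embed_tokens.weight"
native_key = add_root(hf_key)
```

The two helpers are inverses on rooted keys, which is what lets the HF-to-native and native-to-HF directions round-trip.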
Parameters:
BagelConfig (or None; currently only used for log
context — no shape sanity checks yet).
Default stage used when from_hf is called without an
explicit stage kwarg. Accepts "stage1" / "stage2" or
1 / 2.
Return [(hf_fqn, tensor)] for a single NeMo tensor.
Identity for BAGEL checkpoint keys.
Convert an HF-layout BAGEL state dict to the NeMo module-tree layout.
Parameters:
Flat dict loaded from ema.safetensors. If a
caller accidentally passes merged VAE keys, they are recognized
and excluded from the model state dict.
Unused for BAGEL (no expert parallelism yet); kept for base-class signature compatibility.
"stage1" keeps only UND keys. "stage2" keeps UND
and GEN keys. Defaults to self.stage.
If True (default), raise KeyError on any key that
matches no known pattern. If False, log and drop.
Returns: dict[str, 'torch.Tensor']
Filtered state dict in NeMo layout, ready to feed into
Raises:
KeyError: When strict=True and one or more input keys match no UND/GEN/VAE pattern.
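The stage and strict behavior documented above can be sketched as follows. This is an illustration under stated assumptions, not the adapter's implementation: it assumes the input keys have already been sorted into "und", "gen", "vae", and "unknown" buckets, and the function name `select_for_stage`, the bucket names, and the error message are all hypothetical.

```python
import logging

logger = logging.getLogger(__name__)

ROOT = "model."  # native keys are rooted under the wrapper's self.model
STAGE_BUCKETS = {"stage1": ("und",), "stage2": ("und", "gen")}


def select_for_stage(buckets, stage="stage1", strict=True):
    """Keep the buckets the stage asks for; VAE keys are never returned."""
    unknown = sorted(buckets.get("unknown", {}))
    if unknown:
        if strict:
            # strict=True: refuse to continue with unrecognized keys.
            raise KeyError(f"{len(unknown)} unrecognized key(s): {unknown[:5]}")
        # strict=False: log and drop.
        logger.warning("Dropping %d unrecognized key(s)", len(unknown))
    selected = {}
    for name in STAGE_BUCKETS[stage]:
        for key, tensor in buckets.get(name, {}).items():
            selected[ROOT + key] = tensor
    return selected
```

With a bucket dict holding one UND key and one GEN key, `select_for_stage(buckets, "stage1")` returns only the UND entry, while `"stage2"` returns both. Keys in the "vae" bucket are recognized but left out in either stage.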
Convert a NeMo-layout state dict back to the HF BAGEL layout.
The VAE is not part of this module tree and should be saved/loaded separately.
Normalize stage to one of "stage1" or "stage2".
Accepts either strings ("stage1" / "stage2") or integers
(1 / 2) for backward compatibility with earlier callers.
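One way such a normalizer could look is sketched below. The function name `normalize_stage` and the choice of `ValueError` for bad input are assumptions; the source only states which inputs are accepted and what the normalized values are.

```python
_ALIASES = {"stage1": "stage1", "stage2": "stage2", 1: "stage1", 2: "stage2"}


def normalize_stage(stage):
    """Return "stage1" or "stage2" for a string or integer stage value."""
    # bool is a subclass of int, so reject it explicitly: True would
    # otherwise be treated as 1.
    if isinstance(stage, bool) or not isinstance(stage, (str, int)):
        raise ValueError(f"unsupported stage: {stage!r}")
    if stage not in _ALIASES:
        raise ValueError(f"unsupported stage: {stage!r}")
    return _ALIASES[stage]
```

Callers can then pass either form, for example `normalize_stage(2)` and `normalize_stage("stage2")` both yield `"stage2"`.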
Partition a flat checkpoint dict into UND / GEN / VAE / unknown buckets.
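A rough sketch of the partition step follows. The regular expressions are assumptions made for illustration and are not the module's actual UND_PATTERNS or GEN_PATTERNS: the GEN rules come from the names mentioned in this page (`_moe_gen`, `time_embedder`, `vae2llm`, `llm2vae`, `latent_pos_embed`), while the UND prefixes and the VAE prefixes (`encoder.`, `decoder.`) are guesses. The function name `partition_keys` is also hypothetical.

```python
import re

# Assumed patterns; checked in this order, first match wins.
ASSUMED_RULES = (
    ("vae", re.compile(r"^(encoder|decoder)\.")),
    ("gen", re.compile(r"_moe_gen|^(time_embedder|vae2llm|llm2vae|latent_pos_embed)\.")),
    ("und", re.compile(r"^(language_model|vit_model|connector|vit_pos_embed)\.")),
)


def partition_keys(state_dict):
    """Split a flat {key: tensor} dict into und / gen / vae / unknown buckets."""
    buckets = {"und": {}, "gen": {}, "vae": {}, "unknown": {}}
    for key, tensor in state_dict.items():
        for name, pattern in ASSUMED_RULES:
            if pattern.search(key):
                buckets[name][key] = tensor
                break
        else:
            # No rule matched this key.
            buckets["unknown"][key] = tensor
    return buckets
```

GEN is tested before UND because a GEN sibling such as a `*_moe_gen` projection sits under the same `language_model.` prefix as its UND counterpart.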
Load a BAGEL HF checkpoint directory into a NeMo-layout state dict.
Reads ema.safetensors from checkpoint_dir and passes the result
through BagelStateDictAdapter. Stage 2 VAE weights live in
ae.safetensors but are loaded separately by the recipe because the VAE
is not an nn.Module child of BagelForUnifiedMultimodal.
Parameters:
Path to a directory containing ema.safetensors.
"stage1" or "stage2".
Forwarded to from_hf; raise on unmatched keys.
Optional BagelConfig for the adapter’s log context.
Returns: dict[str, 'torch.Tensor']
A flat {key: Tensor} dict in NeMo layout, ready for