nemo_automodel.components.models.bagel.state_dict_adapter#
State-dict adapter for BAGEL HF checkpoints.
BAGEL-7B-MoT ships two on-disk files with complementary key namespaces:
- ema.safetensors: everything under the top-level module tree, covering both the UND (understanding) path and the GEN (*_moe_gen) Mixture-of-Transformers siblings, plus the flow-matching scaffolding (time_embedder, vae2llm, llm2vae, latent_pos_embed).
- ae.safetensors: the VAE encoder/decoder weights. These live in a separate file and are loaded by the Stage 2 recipe via upstream load_ae; they are not parameters of BagelForUnifiedMultimodal.
Special cases:
- vit_pos_embed.pos_embed is present in the checkpoint. Upstream BAGEL declares it as a frozen nn.Parameter(requires_grad=False) (sinusoidal 2D position embedding), so it serializes like any other parameter. It is classified as UND since it feeds the understanding-side connector.
- The released BAGEL-7B-MoT checkpoint stores vit_model.vision_model.embeddings.patch_embedding.weight in the post-conversion linear layout (out_channels, in_channels * P * P). AM swaps the fresh Conv2d module to Linear before loading that checkpoint so the tensor shape matches directly.
- embed_tokens / lm_head / the final norm are UND-side tensors that are logically read by the GEN path (GEN tokens use text embeddings and the shared LM head). No physical tensor sharing is required in the checkpoint: the module tree uses references, not copies.
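The patch-embedding layout described above can be checked on a toy example. The sketch below is pure Python on nested lists and assumes a row-major flattening over (in_channels, P, P) with a stride equal to the patch size; the source does not state the exact flattening order used by the upstream conversion, and all helper names here are hypothetical.

```python
def flatten_conv_weight(weight):
    """Flatten (out_channels, in_channels, P, P) nested lists into
    (out_channels, in_channels * P * P) rows, row-major over (in, P, P)."""
    return [
        [v for channel in kernel for row in channel for v in row]
        for kernel in weight
    ]


def conv_on_patch(weight, patch):
    """Apply each conv kernel to one non-overlapping (in_channels, P, P) patch."""
    return [
        sum(
            w * x
            for k_chan, p_chan in zip(kernel, patch)
            for k_row, p_row in zip(k_chan, p_chan)
            for w, x in zip(k_row, p_row)
        )
        for kernel in weight
    ]


def linear_on_patch(flat_weight, patch):
    """Apply the flattened weight to the same patch, flattened the same way."""
    flat_patch = [v for channel in patch for row in channel for v in row]
    return [sum(w * x for w, x in zip(row, flat_patch)) for row in flat_weight]


# Toy sizes (assumed for illustration): out_channels=2, in_channels=1, P=2.
weight = [[[[1, 2], [3, 4]]], [[[0, 1], [1, 0]]]]
patch = [[[5, 6], [7, 8]]]
flat = flatten_conv_weight(weight)
```

With matching flattening orders, conv_on_patch(weight, patch) and linear_on_patch(flat, patch) give the same vector, which is why a Linear module can consume the stored tensor directly.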
Stage 1 loads the UND subset only (UND_PATTERNS). Stage 2 additionally
loads the GEN subset (GEN_PATTERNS). VAE keys are recognized only so that
accidentally merged checkpoints can be reported cleanly instead of being
treated as unknown model weights.
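The stage-dependent loading described above amounts to bucketing keys by regex and then choosing buckets per stage. The following is a minimal sketch of that idea, not the module's implementation: the pattern lists are heavily abbreviated and modeled on the truncated values shown for UND_PATTERNS, GEN_PATTERNS and VAE_PATTERNS on this page (the exact escaping is an assumption), and partition_keys / keys_for_stage are hypothetical names.

```python
import re

# Abbreviated, illustrative pattern lists; the module's real lists are longer.
UND_RES = [re.compile(p) for p in (
    r"^language_model\.model\.embed_tokens\.weight$",
    r"^language_model\.model\.norm\.weight$",
)]
GEN_RES = [re.compile(
    r"^language_model\.model\.layers\.\d+\.self_attn\.q_proj_moe_gen\.(weight|bias)$"
)]
VAE_RES = [re.compile(r"^encoder\."), re.compile(r"^decoder\.")]


def partition_keys(keys):
    """Sort checkpoint keys into und / gen / vae / unknown buckets."""
    buckets = {"und": [], "gen": [], "vae": [], "unknown": []}
    for key in keys:
        for name, patterns in (("und", UND_RES), ("gen", GEN_RES), ("vae", VAE_RES)):
            if any(p.search(key) for p in patterns):
                buckets[name].append(key)
                break
        else:
            # No pattern list claimed this key.
            buckets["unknown"].append(key)
    return buckets


def keys_for_stage(buckets, stage):
    """Stage 1 keeps UND only; stage 2 keeps UND plus GEN. VAE is never kept."""
    return buckets["und"] + (buckets["gen"] if stage == "stage2" else [])
```

Keeping a separate VAE bucket is what lets a merged checkpoint be reported as "recognized but not loaded here" rather than as unknown weights.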
Module Contents#
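Classes#
BagelStateDictAdapter | HF <-> NeMo state-dict converter for BAGEL.
Functions#
_normalize_stage | Normalize stage to one of "stage1" or "stage2".
_partition | Partition a flat checkpoint dict into UND / GEN / VAE / unknown buckets.
load_bagel_checkpoint_state_dict | Load a BAGEL HF checkpoint directory into a NeMo-layout state dict.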
Data#
API#
- nemo_automodel.components.models.bagel.state_dict_adapter.logger#
'getLogger(…)'
- nemo_automodel.components.models.bagel.state_dict_adapter.UND_PATTERNS#
['^language_model\\.model\\.embed_tokens\\.weight$', '^language_model\\.model\\.norm\\.weight$', '^l…
- nemo_automodel.components.models.bagel.state_dict_adapter.GEN_PATTERNS#
['^language_model\\.model\\.layers\\.\\d+\\.self_attn\\.q_proj_moe_gen\\.(weight|bias)$', '^language…
- nemo_automodel.components.models.bagel.state_dict_adapter.VAE_PATTERNS#
['^encoder\\.', '^decoder\\.']
- nemo_automodel.components.models.bagel.state_dict_adapter.SHARED_PATTERNS: list[str]#
[]
- nemo_automodel.components.models.bagel.state_dict_adapter._compile(patterns: list[str]) -> list[re.Pattern]#
- nemo_automodel.components.models.bagel.state_dict_adapter._UND_RES#
'_compile(…)'
- nemo_automodel.components.models.bagel.state_dict_adapter._GEN_RES#
'_compile(…)'
- nemo_automodel.components.models.bagel.state_dict_adapter._VAE_RES#
'_compile(…)'
- nemo_automodel.components.models.bagel.state_dict_adapter._matches_any(key: str, patterns: list[re.Pattern]) -> bool#
- nemo_automodel.components.models.bagel.state_dict_adapter._normalize_stage(stage: Any) -> str#
Normalize stage to one of "stage1" or "stage2".
Accepts either strings ("stage1" / "stage2") or integers (1 / 2) for backward compatibility with earlier callers.
- nemo_automodel.components.models.bagel.state_dict_adapter._partition(
- state_dict: dict[str, Any],
Partition a flat checkpoint dict into UND / GEN / VAE / unknown buckets.
- class nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter(
- config: Any = None,
- *,
- stage: Any = 'stage1',
Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter
HF <-> NeMo state-dict converter for BAGEL.
Stage 1 returns only the UND subset from from_hf. Stage 2 additionally returns the GEN subset. The VAE remains outside this module and is loaded separately by the training recipe.
Because BagelForUnifiedMultimodal wraps the upstream BAGEL module tree under self.model, HF checkpoint keys are unrooted (language_model...) while native AM keys are rooted (model.language_model...). The adapter filters upstream UND/GEN keys and handles this root mapping.
- Parameters:
config: BagelConfig (or None); currently only used for log context, with no shape sanity checks yet.
stage: Default stage used when from_hf is called without an explicit stage kwarg. Accepts "stage1" / "stage2" or 1 / 2.
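Since stage may arrive as either a string or an integer, some normalization step is needed before it is compared. The sketch below shows one way such a helper could look; it is not the module's _normalize_stage. The name normalize_stage, the ValueError on bad input, and the rejection of booleans are assumptions, because the source does not document the failure behavior.

```python
def normalize_stage(stage):
    """Map "stage1"/"stage2" or 1/2 to the canonical string form."""
    # bool is a subclass of int, so exclude it explicitly (assumed policy).
    if isinstance(stage, int) and not isinstance(stage, bool) and stage in (1, 2):
        return f"stage{stage}"
    if isinstance(stage, str) and stage in ("stage1", "stage2"):
        return stage
    # Assumed error type; the real helper may behave differently.
    raise ValueError(f"unsupported stage: {stage!r}")
```

Normalizing once at the boundary means the rest of the code can compare against the two canonical strings only.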
Initialization
- _hf_to_nemo_key(hf_key: str) -> str#
- _nemo_to_hf_key(nemo_key: str) -> str#
- _strip_nemo_root(key: str) -> str#
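The three helpers above carry no docstrings, but the class description says HF keys are unrooted (language_model...) while native keys are rooted under model. A minimal sketch of that mapping follows, assuming the root is exactly the prefix "model." and nothing else is renamed; the function names are stand-ins and the real private methods may differ.

```python
NEMO_ROOT = "model."  # assumed from the rooted/unrooted key description


def hf_to_nemo_key(hf_key):
    """Root an unrooted HF key under the wrapper's self.model attribute."""
    return NEMO_ROOT + hf_key


def strip_nemo_root(key):
    """Drop the root prefix if present; leave already-unrooted keys alone."""
    return key[len(NEMO_ROOT):] if key.startswith(NEMO_ROOT) else key


def nemo_to_hf_key(nemo_key):
    """Inverse of hf_to_nemo_key under the assumptions above."""
    return strip_nemo_root(nemo_key)
```

Making the strip step tolerant of unrooted input lets the same code accept keys from either layout without double-stripping.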
- from_hf(
- hf_state_dict: dict[str, torch.Tensor],
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- *,
- stage: Optional[Any] = None,
- strict: bool = True,
- **kwargs: Any,
Convert an HF-layout BAGEL state dict to the NeMo module-tree layout.
- Parameters:
hf_state_dict: Flat dict loaded from ema.safetensors. If a caller accidentally passes merged VAE keys, they are recognized and excluded from the model state dict.
device_mesh: Unused for BAGEL (no expert parallelism yet); kept for base-class signature compatibility.
stage: "stage1" keeps only UND keys. "stage2" keeps UND and GEN keys. Defaults to self.stage.
strict: If True (default), raise KeyError on any key that matches no known pattern. If False, log and drop.
- Returns:
Filtered state dict in NeMo layout, ready to feed into BagelForUnifiedMultimodal.load_state_dict(...).
- Raises:
KeyError: When strict=True and one or more input keys match no UND/GEN/VAE pattern.
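The strict / non-strict contract can be illustrated in isolation. This is a sketch under stated assumptions rather than the adapter's code: KEEP and KNOWN_BUT_EXCLUDED are simplified stand-in patterns, filter_state_dict is a hypothetical name, and the log level and message are made up for illustration.

```python
import logging
import re

logger = logging.getLogger(__name__)

# Stand-ins: keys the current stage keeps, and keys that are recognized
# (for example VAE weights) but never returned as model parameters.
KEEP = [re.compile(r"^language_model\.")]
KNOWN_BUT_EXCLUDED = [re.compile(r"^encoder\."), re.compile(r"^decoder\.")]


def filter_state_dict(state_dict, strict=True):
    kept, unknown = {}, []
    for key, value in state_dict.items():
        if any(p.search(key) for p in KEEP):
            kept[key] = value
        elif any(p.search(key) for p in KNOWN_BUT_EXCLUDED):
            continue  # recognized, silently excluded
        else:
            unknown.append(key)
    if unknown:
        if strict:
            raise KeyError(f"unrecognized checkpoint keys: {sorted(unknown)}")
        logger.warning("dropping %d unrecognized keys: %s", len(unknown), sorted(unknown))
    return kept
```

Collecting all unknown keys before raising reports every offending key in a single error instead of failing on the first one.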
- to_hf(
- state_dict: dict[str, torch.Tensor],
- **kwargs: Any,
Convert a NeMo-layout state dict back to the HF BAGEL layout.
The VAE is not part of this module tree and should be saved/loaded separately.
- convert_single_tensor_to_hf(
- fqn: str,
- tensor: torch.Tensor,
- **kwargs: Any,
Return [(hf_fqn, tensor)] for a single NeMo tensor.
Identity for BAGEL checkpoint keys.
- nemo_automodel.components.models.bagel.state_dict_adapter.load_bagel_checkpoint_state_dict(
- checkpoint_dir: str | pathlib.Path,
- *,
- stage: Any = 'stage1',
- strict: bool = True,
- config: Any = None,
Load a BAGEL HF checkpoint directory into a NeMo-layout state dict.
Reads ema.safetensors from checkpoint_dir and passes the result through BagelStateDictAdapter. Stage 2 VAE weights live in ae.safetensors but are loaded separately by the recipe because the VAE is not an nn.Module child of BagelForUnifiedMultimodal.
- Parameters:
checkpoint_dir: Path to a directory containing ema.safetensors.
stage: "stage1" or "stage2".
strict: Forwarded to from_hf; raise on unmatched keys.
config: Optional BagelConfig for the adapter's log context.
- Returns:
A flat {key: Tensor} dict in NeMo layout, ready for BagelForUnifiedMultimodal.load_state_dict(...).
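When a strict load raises on unmatched keys, it can help to list the keys in ema.safetensors without loading any tensors. The sketch below is independent of this module and uses only the standard library; it relies on the safetensors file layout (an 8-byte little-endian header length followed by a JSON header), and list_safetensors_keys is a hypothetical helper name.

```python
import json
import struct


def list_safetensors_keys(path):
    """Return the sorted tensor names stored in a .safetensors file."""
    with open(path, "rb") as f:
        # First 8 bytes: unsigned little-endian length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" is an optional free-form entry, not a tensor.
    return sorted(k for k in header if k != "__metadata__")
```

The resulting key list can then be compared against the UND / GEN / VAE patterns to see which names would be rejected.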