nemo_automodel.components.models.bagel.state_dict_adapter

State-dict adapter for BAGEL HF checkpoints.

BAGEL-7B-MoT ships two on-disk files with complementary key namespaces:

  • ema.safetensors: everything under the top-level module tree, both the UND (understanding) path and the GEN (*_moe_gen) Mixture-of-Transformers siblings, plus the flow-matching scaffolding (time_embedder, vae2llm, llm2vae, latent_pos_embed).

  • ae.safetensors: the VAE encoder/decoder weights. These live in a separate file and are loaded by the Stage 2 recipe via upstream load_ae; they are not parameters of BagelForUnifiedMultimodal.

Special cases:

  • vit_pos_embed.pos_embed is present in the checkpoint. Upstream BAGEL declares it as a frozen nn.Parameter(requires_grad=False) (a sinusoidal 2D position embedding), so it serializes like any other parameter. It is classified as UND because it feeds the understanding-side connector.

  • The released BAGEL-7B-MoT checkpoint stores vit_model.vision_model.embeddings.patch_embedding.weight in the post-conversion linear layout (out_channels, in_channels * P * P). AutoModel (AM) swaps the freshly constructed Conv2d module for a Linear module before loading that checkpoint, so the tensor shape matches directly.

  • embed_tokens / lm_head / the final norm are UND-side tensors that are logically read by the GEN path (GEN tokens use text embeddings and the shared LM head). No physical tensor sharing is required in the checkpoint; the module tree uses references, not copies.

Stage 1 loads the UND subset only (UND_PATTERNS). Stage 2 additionally loads the GEN subset (GEN_PATTERNS). VAE keys are recognized only so that accidentally merged checkpoints can be reported cleanly instead of being treated as unknown model weights.

Module Contents

Classes

BagelStateDictAdapter

HF <-> NeMo state-dict converter for BAGEL.

Functions

_compile

_matches_any

_normalize_stage

Normalize stage to one of "stage1" or "stage2".

_partition

Partition a flat checkpoint dict into UND / GEN / VAE / unknown buckets.

load_bagel_checkpoint_state_dict

Load a BAGEL HF checkpoint directory into a NeMo-layout state dict.

Data

API

nemo_automodel.components.models.bagel.state_dict_adapter.logger

'getLogger(…)'

nemo_automodel.components.models.bagel.state_dict_adapter.UND_PATTERNS

['^language_model\\.model\\.embed_tokens\\.weight$', '^language_model\\.model\\.norm\\.weight$', '^l…

nemo_automodel.components.models.bagel.state_dict_adapter.GEN_PATTERNS

['^language_model\\.model\\.layers\\.\\d+\\.self_attn\\.q_proj_moe_gen\\.(weight|bias)$', '^language…

nemo_automodel.components.models.bagel.state_dict_adapter.VAE_PATTERNS

['^encoder\\.', '^decoder\\.']

nemo_automodel.components.models.bagel.state_dict_adapter.SHARED_PATTERNS: list[str]

[]

nemo_automodel.components.models.bagel.state_dict_adapter._compile(patterns: list[str]) -> list[re.Pattern]

nemo_automodel.components.models.bagel.state_dict_adapter._UND_RES

'_compile(…)'

nemo_automodel.components.models.bagel.state_dict_adapter._GEN_RES

'_compile(…)'

nemo_automodel.components.models.bagel.state_dict_adapter._VAE_RES

'_compile(…)'

nemo_automodel.components.models.bagel.state_dict_adapter._matches_any(key: str, patterns: list[re.Pattern]) -> bool
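
Neither _compile nor _matches_any carries a docstring, so the following is only a minimal sketch of what helpers with these signatures plausibly do. The names compile_patterns, matches_any, and EXAMPLE_PATTERNS are hypothetical, the pattern list is abbreviated for illustration, and the use of re.search (rather than re.match or re.fullmatch) is an assumption.

```python
import re

# Hypothetical, abbreviated pattern list; the real UND/GEN/VAE lists are longer.
EXAMPLE_PATTERNS = [
    r"^language_model\.model\.layers\.\d+\.self_attn\.q_proj_moe_gen\.(weight|bias)$",
    r"^encoder\.",
]


def compile_patterns(patterns: list[str]) -> list[re.Pattern]:
    """Compile each regex once so key classification does not re-parse patterns."""
    return [re.compile(p) for p in patterns]


def matches_any(key: str, patterns: list[re.Pattern]) -> bool:
    """Return True if at least one compiled pattern matches the checkpoint key."""
    return any(p.search(key) for p in patterns)


COMPILED = compile_patterns(EXAMPLE_PATTERNS)

gen_key = "language_model.model.layers.3.self_attn.q_proj_moe_gen.weight"
und_key = "language_model.model.layers.3.self_attn.q_proj.weight"
print(matches_any(gen_key, COMPILED))
print(matches_any(und_key, COMPILED))
```

Because every pattern is anchored with ^, search and match behave the same here; the trailing $ on per-tensor patterns keeps a GEN pattern from matching a longer, unrelated key.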
nemo_automodel.components.models.bagel.state_dict_adapter._normalize_stage(stage: Any) -> str

Normalize stage to one of "stage1" or "stage2".

Accepts either strings ("stage1" / "stage2") or integers (1 / 2) for backward compatibility with earlier callers.
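
The accepted inputs described above can be sketched as follows. This is an illustration only: the function name normalize_stage is hypothetical, and the choice of ValueError for unsupported input is an assumption, since the documentation does not state which exception the real helper raises.

```python
from typing import Any


def normalize_stage(stage: Any) -> str:
    """Map 1 / 2 or "stage1" / "stage2" to the canonical string form."""
    # bool is a subclass of int in Python; reject it so True is not read as stage 1.
    if isinstance(stage, bool):
        raise ValueError(f"Unsupported stage: {stage!r}")
    if isinstance(stage, int) and stage in (1, 2):
        return f"stage{stage}"
    if isinstance(stage, str) and stage in ("stage1", "stage2"):
        return stage
    # Assumption: the real helper may raise a different exception type.
    raise ValueError(f"Unsupported stage: {stage!r}")


print(normalize_stage(1))
print(normalize_stage("stage2"))
```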

nemo_automodel.components.models.bagel.state_dict_adapter._partition(
state_dict: dict[str, Any],
) -> dict[str, dict[str, Any]]

Partition a flat checkpoint dict into UND / GEN / VAE / unknown buckets.
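
A rough sketch of such a partition step is shown below. The bucket names ("und", "gen", "vae", "unknown"), the order in which the pattern groups are tested, and the abbreviated pattern lists are all assumptions made for illustration; only the four-way split itself comes from the description above.

```python
import re
from typing import Any

# Hypothetical, abbreviated stand-ins for the compiled UND / GEN / VAE pattern lists.
UND_RES = [
    re.compile(r"^language_model\.model\.embed_tokens\.weight$"),
    re.compile(r"^language_model\.model\.norm\.weight$"),
]
GEN_RES = [
    re.compile(r"^language_model\.model\.layers\.\d+\.self_attn\.q_proj_moe_gen\.(weight|bias)$"),
]
VAE_RES = [re.compile(r"^encoder\."), re.compile(r"^decoder\.")]


def partition(state_dict: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Split a flat checkpoint dict into four buckets keyed by component."""
    groups = (("und", UND_RES), ("gen", GEN_RES), ("vae", VAE_RES))
    buckets: dict[str, dict[str, Any]] = {"und": {}, "gen": {}, "vae": {}, "unknown": {}}
    for key, value in state_dict.items():
        for name, patterns in groups:
            if any(p.search(key) for p in patterns):
                buckets[name][key] = value
                break
        else:
            # No pattern group claimed this key.
            buckets["unknown"][key] = value
    return buckets


example = {
    "language_model.model.norm.weight": "und-tensor",
    "language_model.model.layers.0.self_attn.q_proj_moe_gen.bias": "gen-tensor",
    "decoder.conv_out.weight": "vae-tensor",
    "optimizer.step": "stray-entry",
}
for name, bucket in partition(example).items():
    print(name, sorted(bucket))
```

Keeping a separate "vae" bucket is what lets an accidentally merged checkpoint be reported as such instead of landing in "unknown".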

class nemo_automodel.components.models.bagel.state_dict_adapter.BagelStateDictAdapter(
config: Any = None,
*,
stage: Any = 'stage1',
)

Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

HF <-> NeMo state-dict converter for BAGEL.

Stage 1 returns only the UND subset from from_hf. Stage 2 additionally returns the GEN subset. The VAE remains outside this module and is loaded separately by the training recipe.

Because BagelForUnifiedMultimodal wraps the upstream BAGEL module tree under self.model, HF checkpoint keys are unrooted (language_model...) while native AM keys are rooted (model.language_model...). The adapter filters upstream UND/GEN keys and handles this root mapping.
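
The root mapping described above amounts to adding or stripping a "model." prefix. The sketch below illustrates that idea with hypothetical free functions; the real adapter exposes this through its private _hf_to_nemo_key, _nemo_to_hf_key, and _strip_nemo_root methods, whose exact behavior (for example on keys that are already rooted) is not documented here and is assumed.

```python
# Assumption: the wrapper attribute is self.model, so the NeMo root prefix is "model.".
NEMO_ROOT = "model."


def hf_to_nemo_key(hf_key: str) -> str:
    """Root an unrooted HF key under the wrapper's self.model attribute."""
    return NEMO_ROOT + hf_key


def nemo_to_hf_key(nemo_key: str) -> str:
    """Strip exactly one leading root prefix; leave unrooted keys unchanged."""
    if nemo_key.startswith(NEMO_ROOT):
        return nemo_key[len(NEMO_ROOT):]
    return nemo_key


hf_key = "language_model.model.embed_tokens.weight"
nemo_key = hf_to_nemo_key(hf_key)
print(nemo_key)
print(nemo_to_hf_key(nemo_key))
```

Only the first "model." is removed on the way back, so the inner language_model.model... segment of the upstream key survives the round trip.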

Parameters:
  • config – BagelConfig (or None; currently only used for log context, with no shape sanity checks yet).

  • stage – Default stage used when from_hf is called without an explicit stage kwarg. Accepts "stage1" / "stage2" or 1 / 2.

Initialization

_hf_to_nemo_key(hf_key: str) -> str

_nemo_to_hf_key(nemo_key: str) -> str

_strip_nemo_root(key: str) -> str
from_hf(
hf_state_dict: dict[str, torch.Tensor],
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
*,
stage: Optional[Any] = None,
strict: bool = True,
**kwargs: Any,
) -> dict[str, torch.Tensor]

Convert an HF-layout BAGEL state dict to the NeMo module-tree layout.

Parameters:
  • hf_state_dict – Flat dict loaded from ema.safetensors. If a caller accidentally passes merged VAE keys, they are recognized and excluded from the model state dict.

  • device_mesh – Unused for BAGEL (no expert parallelism yet); kept for base-class signature compatibility.

  • stage – "stage1" keeps only UND keys. "stage2" keeps UND and GEN keys. Defaults to self.stage.

  • strict – If True (default), raise KeyError on any key that matches no known pattern. If False, log and drop.

Returns:

Filtered state dict in NeMo layout, ready to feed into BagelForUnifiedMultimodal.load_state_dict(...).

Raises:

KeyError – When strict=True and one or more input keys match no UND/GEN/VAE pattern.
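
To make the interaction between stage and strict concrete, here is a simplified stand-in for the filtering logic. It is not the adapter's code: from_hf_sketch and toy_classify are hypothetical, the classifier uses crude prefix checks instead of the real regex lists, and plain Python objects stand in for tensors.

```python
import logging
from typing import Any, Callable

logger = logging.getLogger(__name__)


def toy_classify(key: str) -> str:
    """Hypothetical classifier; the real module uses the UND/GEN/VAE regex lists."""
    if key.startswith(("encoder.", "decoder.")):
        return "vae"
    if "_moe_gen" in key:
        return "gen"
    if key.startswith(("language_model.", "vit_model.")):
        return "und"
    return "unknown"


def from_hf_sketch(
    hf_state_dict: dict[str, Any],
    *,
    stage: str = "stage1",
    strict: bool = True,
    classify: Callable[[str], str] = toy_classify,
) -> dict[str, Any]:
    """Keep the buckets the stage asks for and root the surviving keys under 'model.'."""
    wanted = {"und"} if stage == "stage1" else {"und", "gen"}
    kept: dict[str, Any] = {}
    unknown: list[str] = []
    for key, tensor in hf_state_dict.items():
        bucket = classify(key)
        if bucket == "unknown":
            unknown.append(key)
        elif bucket in wanted:
            kept["model." + key] = tensor
        # Otherwise the key is recognized (VAE, or GEN during stage 1) and excluded.
    if unknown:
        if strict:
            raise KeyError(f"Unrecognized checkpoint keys: {sorted(unknown)}")
        logger.warning("Dropping %d unrecognized key(s)", len(unknown))
    return kept


checkpoint = {
    "language_model.model.norm.weight": 0,
    "language_model.model.layers.0.self_attn.q_proj_moe_gen.weight": 1,
    "encoder.conv_in.weight": 2,
}
print(sorted(from_hf_sketch(checkpoint, stage="stage1")))
print(sorted(from_hf_sketch(checkpoint, stage="stage2")))
```

Note that the merged VAE key never raises, even under strict=True: it matches a known pattern, so it is excluded from the result rather than treated as an error.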

to_hf(
state_dict: dict[str, torch.Tensor],
**kwargs: Any,
) -> dict[str, torch.Tensor]

Convert a NeMo-layout state dict back to the HF BAGEL layout.

The VAE is not part of this module tree and should be saved/loaded separately.

convert_single_tensor_to_hf(
fqn: str,
tensor: torch.Tensor,
**kwargs: Any,
) -> list[tuple[str, torch.Tensor]]

Return [(hf_fqn, tensor)] for a single NeMo tensor.

Identity for BAGEL checkpoint keys.

nemo_automodel.components.models.bagel.state_dict_adapter.load_bagel_checkpoint_state_dict(
checkpoint_dir: str | pathlib.Path,
*,
stage: Any = 'stage1',
strict: bool = True,
config: Any = None,
) -> dict[str, torch.Tensor]

Load a BAGEL HF checkpoint directory into a NeMo-layout state dict.

Reads ema.safetensors from checkpoint_dir and passes the result through BagelStateDictAdapter. Stage 2 VAE weights live in ae.safetensors but are loaded separately by the recipe because the VAE is not an nn.Module child of BagelForUnifiedMultimodal.

Parameters:
  • checkpoint_dir – Path to a directory containing ema.safetensors.

  • stage – "stage1" or "stage2".

  • strict – Forwarded to from_hf; raise on unmatched keys.

  • config – Optional BagelConfig for the adapter's log context.

Returns:

A flat {key: Tensor} dict in NeMo layout, ready for BagelForUnifiedMultimodal.load_state_dict(...).
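
A possible calling pattern, based only on the signature documented above, might look like the following. The wrapper name load_for_stage and the checkpoint path are hypothetical, and the up-front file check is an addition of this sketch rather than documented behavior of the library function.

```python
from pathlib import Path


def load_for_stage(checkpoint_dir: str, stage: str = "stage2"):
    """Check that ema.safetensors exists, then delegate to the documented loader."""
    ckpt_dir = Path(checkpoint_dir)
    # This check belongs to the sketch; the library call may report a missing file differently.
    if not (ckpt_dir / "ema.safetensors").is_file():
        raise FileNotFoundError(f"ema.safetensors not found in {ckpt_dir}")

    # Imported lazily so the helper can be defined without NeMo AutoModel installed.
    from nemo_automodel.components.models.bagel.state_dict_adapter import (
        load_bagel_checkpoint_state_dict,
    )

    return load_bagel_checkpoint_state_dict(ckpt_dir, stage=stage, strict=True)


# Example call (hypothetical path):
# state_dict = load_for_stage("/path/to/BAGEL-7B-MoT", stage="stage2")
# print(len(state_dict), "tensors in NeMo layout")
```

The VAE weights in ae.safetensors are not touched by this call; per the description above, the Stage 2 recipe loads them separately.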