nemo_automodel.components.models.bagel.model
Top-level BAGEL model.
The module supports Stage 1 understanding-only training with a
cross-entropy (CE) loss and Stage 2 joint understanding and visual
generation, which adds a flow-matching MSE loss. The VAE itself stays
outside this module tree; the recipe passes VAE-encoded latents to
forward.
The top-level module is a plain nn.Module (not a PreTrainedModel) with
HFCheckpointingMixin mixed in. This matches the llava-onevision pattern
and avoids the FSDP double-root issue that occurs with
PreTrainedModel-derived roots.
API
Bases: HFCheckpointingMixin, Module
BAGEL mixed-modal LLM wrapper for understanding and optional generation.
visual_gen=False gives the Stage 1 understanding-only path. Stage 2
sets visual_gen=True and uses the MoT *_moe_gen parameter siblings,
VAE latents prepared by the recipe, and the flow-matching MSE head.
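As a rough illustration of what a flow-matching MSE objective computes, here is a minimal pure-Python sketch of a generic rectified-flow formulation: a clean latent x0 and Gaussian noise are interpolated at timestep t, and the model is trained to predict the velocity noise - x0. The helper names and plain-list "tensors" are hypothetical; BAGEL's actual timestep sampling and shifting, and its masking via mse_loss_indexes, are not reproduced here.

```python
import random
from typing import List, Sequence


def noisy_latent(x0: Sequence[float], noise: Sequence[float], t: float) -> List[float]:
    """Linearly interpolate between the clean latent (t=0) and pure noise (t=1)."""
    return [(1.0 - t) * a + t * n for a, n in zip(x0, noise)]


def velocity_target(x0: Sequence[float], noise: Sequence[float]) -> List[float]:
    """Regression target d(x_t)/dt for the interpolation above: noise - x0."""
    return [n - a for a, n in zip(x0, noise)]


def flow_matching_mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between predicted and target velocity."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(target)


rng = random.Random(0)
x0 = [0.5, -1.0, 2.0]                      # stand-in for a VAE-encoded latent
noise = [rng.gauss(0.0, 1.0) for _ in x0]  # Gaussian noise of the same shape
t = 0.3                                    # hypothetical sampled timestep
x_t = noisy_latent(x0, noise, t)           # the input the generation branch would see
target = velocity_target(x0, noise)
print(flow_matching_mse(target, target))   # an ideal prediction gives 0.0
```

The velocity target follows directly from differentiating x_t = (1 - t) * x0 + t * noise with respect to t.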
Run the BAGEL mixed-modal forward pass.
Stage 1 (visual_gen=False) skips the flow-matching branch and the MSE
computation; Stage 2 activates both. ce_loss_weights is accepted
for data-pipeline compatibility but is not consumed here: CE is returned
per token (reduction="none"), and the trainer may apply the weights
downstream.
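Because CE comes back per token, any weighting happens after forward. The sketch below assumes ce_loss_weights holds one weight per supervised token and that the trainer takes a weight-normalized mean; the helper names and the normalization choice are hypothetical, not the recipe's confirmed behavior.

```python
import math
from typing import List, Sequence


def per_token_ce(logits_per_token: Sequence[Sequence[float]], labels: Sequence[int]) -> List[float]:
    """Cross-entropy for each token separately, i.e. reduction="none"."""
    losses = []
    for logits, label in zip(logits_per_token, labels):
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        losses.append(log_z - logits[label])
    return losses


def weighted_mean(losses: Sequence[float], weights: Sequence[float]) -> float:
    """Hypothetical downstream reduction: weight each token, normalize by total weight."""
    total = sum(weights)
    if total <= 0:
        return 0.0  # nothing supervised in this batch
    return sum(l * w for l, w in zip(losses, weights)) / total


logits = [[2.0, 0.0, -1.0], [0.5, 0.5, 0.5]]
labels = [0, 2]
losses = per_token_ce(logits, labels)
ce_loss_weights = [1.0, 0.5]  # hypothetical per-token weights from the data pipeline
loss = weighted_mean(losses, ce_loss_weights)
```

Keeping the reduction out of the model lets the same forward serve different weighting schemes without changing the module.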
Stage 2 inputs (padded_latent, patchified_vae_latent_shapes,
packed_latent_position_ids, packed_vae_token_indexes,
packed_timesteps, mse_loss_indexes) are produced by the
BAGEL collator when a pack contains t2i/edit samples. padded_latent
is the VAE-encoded latent tensor. Because this module does not own
the VAE, the recipe must call vae_model.encode on the raw
padded_images before calling forward.
Returns: Dict[str, Optional[torch.Tensor]]
dict(ce=Tensor|None, mse=Tensor|None); either or both values may be None.
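Since either entry of the returned dict may be None (presumably mse on the Stage 1 path, which skips the MSE computation), a caller has to combine them defensively. A minimal sketch, with hypothetical ce_weight/mse_weight coefficients and plain floats standing in for already-reduced scalar losses:

```python
from typing import Dict, Optional


def combine_losses(
    out: Dict[str, Optional[float]],
    ce_weight: float = 1.0,   # hypothetical coefficient
    mse_weight: float = 1.0,  # hypothetical coefficient
) -> Optional[float]:
    """Sum the weighted loss terms that are present; return None if neither is."""
    terms = []
    if out.get("ce") is not None:
        terms.append(ce_weight * out["ce"])
    if out.get("mse") is not None:
        terms.append(mse_weight * out["mse"])
    return sum(terms) if terms else None


stage1_total = combine_losses({"ce": 2.0, "mse": None})
stage2_total = combine_losses({"ce": 1.0, "mse": 0.5}, mse_weight=2.0)
```

Returning None rather than 0.0 when no term is present lets the caller skip the backward pass instead of backpropagating a constant.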
Load a BAGEL-7B-MoT checkpoint directory as an instance of this class.
Reads config.json via BagelConfig.from_pretrained, constructs
an empty model, and then loads ema.safetensors filtered by
BagelStateDictAdapter. Stage 2 VAE weights are loaded by the
recipe because the VAE is not owned by this module tree.
Parameters:
Directory containing the HF-layout BAGEL checkpoint.
1 (understanding only, UND) or 2 (understanding plus generation,
UND + GEN). The strings "stage1" and "stage2" are also accepted.
If True, raise on state-dict keys that don’t match the
adapter patterns. Defaults to False for compatibility with
checkpoint sidecar files.
Forwarded to BagelConfig.from_pretrained.
Returns: 'BagelForUnifiedMultimodal'
A fully-initialized BagelForUnifiedMultimodal with its weights loaded.
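The strict flag's behavior can be sketched as a key filter over the loaded state dict. This is a minimal illustration, not BagelStateDictAdapter itself: the filter_state_dict helper and the regex patterns are hypothetical (only the top-level attribute names are taken from this page), and real checkpoint keys will differ.

```python
import re
from typing import Dict, Iterable


def filter_state_dict(
    state_dict: Dict[str, object],
    patterns: Iterable[str],
    strict: bool = False,
) -> Dict[str, object]:
    """Keep keys matching any adapter pattern; raise on the rest only if strict."""
    compiled = [re.compile(p) for p in patterns]
    kept, unmatched = {}, []
    for key, value in state_dict.items():
        if any(rx.match(key) for rx in compiled):
            kept[key] = value
        else:
            unmatched.append(key)
    if strict and unmatched:
        raise KeyError(f"Unexpected checkpoint keys: {sorted(unmatched)}")
    return kept  # non-strict: unmatched keys (e.g. from sidecar files) are dropped


# Hypothetical patterns keyed on the documented top-level attribute names.
PATTERNS = [r"language_model\.", r"vit_model\.", r"connector\.", r"vit_pos_embed\."]
checkpoint = {"language_model.lm_head.weight": 1, "connector.fc1.weight": 2, "extra.meta": 3}
loaded = filter_state_dict(checkpoint, PATTERNS)
```

Defaulting to non-strict mirrors the parameter description: stray entries are tolerated, while strict=True surfaces them as an error.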
Initialize BAGEL weights after AutoModel (AM) materializes a model built via from_config.
BagelForUnifiedMultimodal is an nn.Module root, not an HF
PreTrainedModel root. AM's meta-device from_config path
materializes parameters after sharding and then calls this method,
which delegates the Qwen and SigLIP subtrees to their HF-style
initializers and then initializes the BAGEL-only connector and
generation modules.
Return True if this custom class supports config.
Bases: Module
Plain container for the three main BAGEL submodules (language_model,
vit_model, connector) and the ViT position embedding (vit_pos_embed).
These attribute names match the checkpoint layout, so the state-dict
adapter can use an identity key mapping. There is no forward logic here;
this class exists so that FSDP and state-dict tooling see the expected
tree structure without being confused by the HFCheckpointingMixin root.
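Because the attribute names match the checkpoint layout, the module's dotted parameter names line up one-to-one with checkpoint keys, so no renaming is needed. A pure-Python sketch of that check, using nested dicts as a hypothetical stand-in for the module tree (in a real model, named_parameters() yields the dotted names); the leaf names below are illustrative only.

```python
from typing import Dict, List, Set, Tuple


def dotted_names(tree: Dict[str, object], prefix: str = "") -> List[str]:
    """Flatten a nested attribute tree into dotted keys, like named_parameters()."""
    names = []
    for attr, child in tree.items():
        key = prefix + attr
        if isinstance(child, dict):
            names.extend(dotted_names(child, key + "."))
        else:
            names.append(key)
    return names


def missing_and_unexpected(model_keys: Set[str], ckpt_keys: Set[str]) -> Tuple[Set[str], Set[str]]:
    """Keys the checkpoint lacks, and keys the model does not expect."""
    return model_keys - ckpt_keys, ckpt_keys - model_keys


# Only the four top-level attribute names come from the docs; leaves are hypothetical.
tree = {
    "language_model": {"lm_head": {"weight": None}},
    "vit_model": {"embeddings": {"patch_embedding": {"weight": None}}},
    "connector": {"fc1": {"weight": None}},
    "vit_pos_embed": {"pos_embed": None},
}
ckpt_keys = {
    "language_model.lm_head.weight",
    "vit_model.embeddings.patch_embedding.weight",
    "connector.fc1.weight",
    "vit_pos_embed.pos_embed",
}
missing, unexpected = missing_and_unexpected(set(dotted_names(tree)), ckpt_keys)
print(missing, unexpected)  # set() set()
```

If the container wrapped these attributes under an extra level, every key would gain a prefix and the adapter would need a rename rule instead of an identity mapping.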
When config.visual_gen=True (Stage 2), we additionally attach the
generation-side siblings (time_embedder, vae2llm, llm2vae,
latent_pos_embed) so the flow-matching head is ready to run. The VAE
model itself is NOT owned here; the recipe keeps it separate
(frozen, inference-only) and passes already-encoded latents into
BagelForUnifiedMultimodal.forward.
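The stage-dependent construction described above can be sketched as a function that lists which submodules the container would own. The submodule names are the ones given in this docstring; the function itself is hypothetical and returns names rather than building real modules.

```python
from typing import List

UND_SUBMODULES = ["language_model", "vit_model", "connector", "vit_pos_embed"]
GEN_SUBMODULES = ["time_embedder", "vae2llm", "llm2vae", "latent_pos_embed"]


def expected_submodules(visual_gen: bool) -> List[str]:
    """Names the container would own; the VAE itself is never among them."""
    names = list(UND_SUBMODULES)
    if visual_gen:  # Stage 2: attach the flow-matching head's siblings
        names += GEN_SUBMODULES
    return names


stage1 = expected_submodules(False)
stage2 = expected_submodules(True)
```

Keeping the VAE out of this list is what lets the recipe hold it frozen and separate while the trainable tree stays limited to the modules above.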
Swap the SigLIP patch embedding (a Conv2d in the HF SigLIP vision model) for a Linear layer so it accepts BAGEL's packed pixel inputs.
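Why a Linear layer can stand in for the patch embedding: when kernel size equals stride equals the patch size p, a Conv2d applies one weight tensor to each non-overlapping p x p patch independently, which is exactly a matrix multiply on the flattened patch. The pure-Python sketch below checks this for one output channel, assuming patches are flattened in (row, column, channel) order; the helper names are hypothetical, and whatever flattening order the data pipeline actually uses, the conv weight must be permuted to that same order before reshaping.

```python
from typing import List, Sequence

Tensor3D = Sequence[Sequence[Sequence[float]]]  # indexed [channel][row][col]


def conv_at_patch(conv_weight: Tensor3D, image: Tensor3D, py: int, px: int, p: int) -> float:
    """Conv2d output (kernel = stride = p, one output channel, no bias) at patch (py, px)."""
    return sum(
        conv_weight[c][i][j] * image[c][py * p + i][px * p + j]
        for c in range(len(image))
        for i in range(p)
        for j in range(p)
    )


def flatten_patch(image: Tensor3D, py: int, px: int, p: int) -> List[float]:
    """Flatten one p x p patch in (row, column, channel) order."""
    return [
        image[c][py * p + i][px * p + j]
        for i in range(p)
        for j in range(p)
        for c in range(len(image))
    ]


def conv_weight_to_linear_row(conv_weight: Tensor3D, p: int) -> List[float]:
    """Permute the kernel to the same (row, column, channel) order, then flatten."""
    return [
        conv_weight[c][i][j]
        for i in range(p)
        for j in range(p)
        for c in range(len(conv_weight))
    ]


def linear(row: Sequence[float], x: Sequence[float]) -> float:
    """One output unit of a bias-free Linear layer."""
    return sum(w * v for w, v in zip(row, x))


p = 2
# Integer values keep the comparison exact (no floating-point reordering error).
image = [[[c * 16 + y * 4 + x for x in range(4)] for y in range(4)] for c in range(3)]
conv_weight = [[[c - 2 * i + 3 * j for j in range(p)] for i in range(p)] for c in range(3)]
row = conv_weight_to_linear_row(conv_weight, p)
for py in range(2):
    for px in range(2):
        assert conv_at_patch(conv_weight, image, py, px, p) == linear(row, flatten_patch(image, py, px, p))
```

With packed inputs, patches from many images of different sizes arrive as one flat list of vectors, which a Linear layer consumes directly while a Conv2d would need a regular image grid.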
Apply BAGEL stage/checkpoint config fixes before module construction.
AutoModel instantiates custom models as model_cls(config) and lets the
common checkpointer load weights later. BAGEL therefore has to apply the
same stage-dependent config mutations that its direct from_pretrained
path previously performed, before BagelModel is built.
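A minimal sketch of that ordering: fix the config first, then call model_cls(config), and leave weight loading to the checkpointer. Only the visual_gen field is taken from this page (Stage 2 sets it to True); the ToyBagelConfig dataclass, the prepare_config helper, and ToyModel are hypothetical, and the real hook may apply further checkpoint-specific fixes.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyBagelConfig:
    """Hypothetical stand-in for BagelConfig with only the field used here."""

    visual_gen: bool = False


def prepare_config(config: ToyBagelConfig, stage: int) -> ToyBagelConfig:
    """Apply stage-dependent fixes before the model is constructed."""
    if stage not in (1, 2):
        raise ValueError(f"stage must be 1 or 2, got {stage!r}")
    return replace(config, visual_gen=(stage == 2))


class ToyModel:
    """Hypothetical model whose module tree is decided by the config at build time."""

    def __init__(self, config: ToyBagelConfig):
        self.has_gen_head = config.visual_gen


model = ToyModel(prepare_config(ToyBagelConfig(), stage=2))
print(model.has_gen_head)  # True
```

The fix must precede construction because the generation submodules are created in __init__; mutating the config afterwards would leave the module tree unchanged.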
Normalize a BAGEL training stage value to 1 or 2.
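A minimal sketch of such normalization, accepting the inputs documented for the from_pretrained stage argument (1, 2, "stage1", "stage2"). The function name and the extra tolerance for case, surrounding whitespace, and bare digit strings are assumptions, not the module's confirmed behavior.

```python
from typing import Union


def normalize_stage(stage: Union[int, str]) -> int:
    """Map 1, 2, "stage1", "stage2" (and "1"/"2") to the integer stage."""
    if isinstance(stage, bool):  # bool subclasses int; reject True/False explicitly
        raise ValueError(f"Unsupported BAGEL stage: {stage!r}")
    if isinstance(stage, int):
        value = stage
    elif isinstance(stage, str):
        text = stage.strip().lower()
        if text.startswith("stage"):
            text = text[len("stage"):]
        if not text.isdigit():
            raise ValueError(f"Unsupported BAGEL stage: {stage!r}")
        value = int(text)
    else:
        raise ValueError(f"Unsupported BAGEL stage: {stage!r}")
    if value not in (1, 2):
        raise ValueError(f"Unsupported BAGEL stage: {stage!r}")
    return value
```

Normalizing once at the boundary lets the rest of the code compare against plain integers.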