nemo_automodel.components.models.bagel.model#
Top-level BAGEL model.
The module supports Stage 1 understanding-only CE and Stage 2 joint
understanding and visual generation with flow-matching MSE. The VAE itself
stays outside this module tree; the recipe passes VAE-encoded latents to
forward.
The top-level module is an nn.Module (not a PreTrainedModel) with
HFCheckpointingMixin mixed in. This matches the llava-onevision pattern and
avoids the FSDP double-root issue that affects PreTrainedModel-derived roots.
Module Contents#
Classes#
BagelModel | Plain container for the three BAGEL submodules.
BagelForUnifiedMultimodal | BAGEL mixed-modal LLM wrapper for understanding and optional generation.
Functions#
_stage_to_int | Normalize a BAGEL training stage value to 1 or 2.
_prepare_config_for_stage | Apply BAGEL stage/checkpoint config fixes before module construction.
_convert_patch_embedding_for_packed_vit | Swap SigLIP patch embedding to Linear for BAGEL packed pixel inputs.
Data#
API#
- nemo_automodel.components.models.bagel.model.logger#
‘getLogger(…)’
- nemo_automodel.components.models.bagel.model._stage_to_int(stage: Union[int, str]) int#
Normalize a BAGEL training stage value to 1 or 2.
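The normalization described here can be illustrated with a minimal sketch. The function below, `stage_to_int`, is a hypothetical stand-in and not the library's implementation; it assumes only what this page states, namely that integers 1 and 2 and the strings "stage1" / "stage2" are accepted. Accepting bare digit strings and rejecting booleans are additional assumptions.

```python
from typing import Union


def stage_to_int(stage: Union[int, str]) -> int:
    """Hypothetical stand-in for _stage_to_int: map a stage value to 1 or 2."""
    if isinstance(stage, bool):
        # bool is a subclass of int; reject it so True is not read as stage 1.
        raise ValueError(f"Unsupported BAGEL stage: {stage!r}")
    if isinstance(stage, int):
        value = stage
    else:
        # Accept "stage1" / "stage2"; bare digit strings are an assumption.
        text = str(stage).strip().lower()
        if text.startswith("stage"):
            text = text[len("stage"):]
        if not text.isdigit():
            raise ValueError(f"Unsupported BAGEL stage: {stage!r}")
        value = int(text)
    if value not in (1, 2):
        raise ValueError(f"BAGEL stage must be 1 or 2, got {stage!r}")
    return value


print(stage_to_int(1), stage_to_int("stage2"))
```

Failing early on anything other than 1 or 2 keeps an invalid stage from silently selecting the wrong training path.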
- nemo_automodel.components.models.bagel.model._prepare_config_for_stage( ) None#
Apply BAGEL stage/checkpoint config fixes before module construction.
AutoModel instantiates custom models as model_cls(config) and lets the common checkpointer load weights later. BAGEL therefore needs the same stage-dependent config mutations that its direct from_pretrained path used to do before BagelModel is built.
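As a rough illustration of mutating the config before construction, the sketch below toggles `visual_gen` by stage. It assumes only what this page states: Stage 1 forces `visual_gen` off and Stage 2 sets it to True. The helper name `prepare_config_for_stage` and the `SimpleNamespace` stand-in for the config are hypothetical, and the real function may apply further checkpoint-related fixes that are not documented here.

```python
from types import SimpleNamespace


def prepare_config_for_stage(config, stage: int) -> None:
    """Hypothetical sketch: mutate the config in place before the model is built."""
    if stage == 1:
        # Stage 1 is understanding-only, so the generation path is disabled.
        config.visual_gen = False
    elif stage == 2:
        # Stage 2 trains understanding and generation jointly.
        config.visual_gen = True
    else:
        raise ValueError(f"BAGEL stage must be 1 or 2, got {stage!r}")


# Stand-in for a loaded config object; the real config class is not used here.
config = SimpleNamespace(visual_gen=True)
prepare_config_for_stage(config, stage=1)
print(config.visual_gen)
```

Because the mutation happens before `model_cls(config)` runs, the module tree is built with the right set of submodules and the checkpointer later finds matching parameter names.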
- nemo_automodel.components.models.bagel.model._convert_patch_embedding_for_packed_vit( ) None#
Swap SigLIP patch embedding to Linear for BAGEL packed pixel inputs.
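The swap is possible because a convolution whose kernel size equals its stride computes the same values as a linear layer applied to flattened patches. The sketch below demonstrates that equivalence in dependency-free Python on a single-channel 4x4 image. All names are hypothetical, the row-major flattening order is an assumption, and this is not the library's conversion code, which operates on the SigLIP module itself.

```python
def conv_patch_embed(image, kernel, patch):
    """Strided convolution with kernel size == stride == patch (one channel)."""
    rows, cols = len(image), len(image[0])
    out = []
    for top in range(0, rows, patch):
        for left in range(0, cols, patch):
            out.append([
                sum(k[i][j] * image[top + i][left + j]
                    for i in range(patch) for j in range(patch))
                for k in kernel
            ])
    return out  # one embedding vector per patch


def patchify(image, patch):
    """Flatten each non-overlapping patch into a row vector (row-major)."""
    rows, cols = len(image), len(image[0])
    return [
        [image[top + i][left + j] for i in range(patch) for j in range(patch)]
        for top in range(0, rows, patch)
        for left in range(0, cols, patch)
    ]


def linear_patch_embed(packed_patches, weight):
    """Bias-free linear layer: one dot product per output feature."""
    return [[sum(w * x for w, x in zip(row, p)) for row in weight]
            for p in packed_patches]


image = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
kernel = [  # two output features, each a 2x2 kernel
    [[1, 0], [0, 1]],
    [[1, 1], [1, 1]],
]
# Flattening each kernel to a row is the Conv -> Linear weight reshape.
linear_weight = [[v for row in k for v in row] for k in kernel]

conv_out = conv_patch_embed(image, kernel, patch=2)
linear_out = linear_patch_embed(patchify(image, 2), linear_weight)
print(conv_out == linear_out)
```

With a linear layer, patches from images of different sizes can be concatenated into one packed sequence, which a spatial convolution cannot consume directly.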
- class nemo_automodel.components.models.bagel.model.BagelModel( )#
Bases: torch.nn.Module
Plain container for the three BAGEL submodules.
Attribute names (language_model, vit_model, connector, vit_pos_embed) match the checkpoint layout, so the state-dict adapter is an identity mapping. There is no forward logic here; this class exists so that FSDP and state-dict tooling see the expected tree structure without being confused by the HFCheckpointingMixin root.
When config.visual_gen=True (Stage 2), the generation-side siblings (time_embedder, vae2llm, llm2vae, latent_pos_embed) are also attached so the flow-matching head is ready to run. The VAE model itself is not owned here; the recipe keeps it separate (frozen, inference-only) and passes already-encoded latents into BagelForUnifiedMultimodal.forward.
Initialization
- class nemo_automodel.components.models.bagel.model.BagelForUnifiedMultimodal( )#
Bases: nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin, torch.nn.Module
BAGEL mixed-modal LLM wrapper for understanding and optional generation.
visual_gen=False gives the Stage 1 understanding-only path. Stage 2 sets visual_gen=True and uses the MoT *_moe_gen parameter siblings, VAE latents prepared by the recipe, and the flow-matching MSE head.
Initialization
- config_class#
None
- class ModelCapabilities#
Declared parallelism capabilities for this model class.
- supports_tp: bool#
False
- supports_cp: bool#
False
- supports_pp: bool#
False
- supports_ep: bool#
False
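All four flags are False, so only configurations that leave tensor, context, pipeline, and expert parallelism at size 1 are declared as supported. The sketch below shows one way a launcher could check a requested layout against such flags. The `Capabilities` dataclass and `unsupported_parallelism` helper are hypothetical; how AutoModel actually consumes ModelCapabilities is not described on this page.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class Capabilities:
    """Hypothetical mirror of the four flags declared by ModelCapabilities."""
    supports_tp: bool = False
    supports_cp: bool = False
    supports_pp: bool = False
    supports_ep: bool = False


def unsupported_parallelism(caps: Capabilities, requested: Dict[str, int]) -> List[str]:
    """Return the parallelism kinds requested with size > 1 that the model lacks."""
    return [
        kind for kind, size in requested.items()
        if size > 1 and not getattr(caps, f"supports_{kind}")
    ]


bagel_caps = Capabilities()  # every flag is False, as declared for BAGEL
problems = unsupported_parallelism(bagel_caps, {"tp": 2, "cp": 1, "pp": 1, "ep": 1})
print(problems)
```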
- initialize_weights() None#
Initialize BAGEL weights after AM materializes a from_config model.
BagelForUnifiedMultimodal is an nn.Module root, not a HF PreTrainedModel root. AM's meta-device from_config path materializes parameters after sharding and then calls this method. Delegate Qwen/SigLIP subtrees to their HF-style initializers, then initialize BAGEL-only connector and generation modules.
- classmethod from_pretrained(
- pretrained_model_name_or_path: Union[str, os.PathLike],
- *,
- stage: Union[int, str] = 1,
- strict: bool = False,
- **kwargs: Any,
Load a BAGEL-7B-MoT checkpoint directory into this class.
Reads config.json via BagelConfig.from_pretrained, constructs an empty model, and then loads ema.safetensors filtered by BagelStateDictAdapter. Stage 2 VAE weights are loaded by the recipe because the VAE is not owned by this module tree.
- Parameters:
pretrained_model_name_or_path – Directory containing the HF-layout BAGEL checkpoint.
stage – 1 (UND only) or 2 (UND + GEN). Strings "stage1" / "stage2" are also accepted.
strict – If True, raise on state-dict keys that don't match the adapter patterns. Defaults to False for compatibility with checkpoint sidecar files.
**kwargs – Forwarded to BagelConfig.from_pretrained.
- Returns:
A fully initialized BagelForUnifiedMultimodal with weights populated from disk. For Stage 1, visual_gen is forced off on the loaded config so the MoT gen-side path is left untouched.
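The effect of strict on key filtering can be sketched as follows. The patterns are assumptions derived only from the BagelModel attribute names listed on this page, the checkpoint keys are made up for illustration, and raising KeyError is also an assumption. None of this is the BagelStateDictAdapter implementation.

```python
import re
from typing import Any, Dict

# Assumed patterns, based only on the BagelModel attribute names.
ASSUMED_PATTERNS = [
    re.compile(r"^language_model\."),
    re.compile(r"^vit_model\."),
    re.compile(r"^connector\."),
    re.compile(r"^vit_pos_embed\."),
]


def filter_state_dict(state_dict: Dict[str, Any], strict: bool = False) -> Dict[str, Any]:
    """Keep keys that match a known pattern; in strict mode, raise on the rest."""
    kept, unmatched = {}, []
    for key, value in state_dict.items():
        if any(pattern.match(key) for pattern in ASSUMED_PATTERNS):
            kept[key] = value
        else:
            unmatched.append(key)
    if strict and unmatched:
        raise KeyError(f"Unmatched checkpoint keys: {unmatched}")
    return kept


# Illustrative keys only; real checkpoint key names are not shown on this page.
checkpoint = {"language_model.embed.weight": 1, "sidecar.stats": 2}
loaded = filter_state_dict(checkpoint)
print(sorted(loaded))
```

With strict=False the unmatched sidecar entry is skipped silently, which fits the stated reason for the default.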
- classmethod supports_config(config: Any) bool#
Return True if this custom class supports config.
- get_input_embeddings() torch.nn.Module#
- get_output_embeddings() torch.nn.Module#
- forward(
- sequence_length: int,
- packed_text_ids: torch.LongTensor,
- packed_text_indexes: torch.LongTensor,
- sample_lens: List[int],
- packed_position_ids: torch.LongTensor,
- nested_attention_masks: Optional[List[torch.Tensor]] = None,
- split_lens: Optional[List[int]] = None,
- attn_modes: Optional[List[str]] = None,
- packed_vit_tokens: Optional[torch.Tensor] = None,
- packed_vit_token_indexes: Optional[torch.LongTensor] = None,
- packed_vit_position_ids: Optional[torch.LongTensor] = None,
- vit_token_seqlens: Optional[torch.Tensor] = None,
- padded_latent: Optional[torch.Tensor] = None,
- patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
- packed_latent_position_ids: Optional[torch.LongTensor] = None,
- packed_vae_token_indexes: Optional[torch.LongTensor] = None,
- packed_timesteps: Optional[torch.Tensor] = None,
- mse_loss_indexes: Optional[torch.Tensor] = None,
- ce_loss_indexes: Optional[torch.Tensor] = None,
- packed_label_ids: Optional[torch.Tensor] = None,
- ce_loss_weights: Optional[torch.Tensor] = None,
Run the BAGEL mixed-modal forward.
Stage 1 (visual_gen=False) skips the flow-matching branch and MSE computation; Stage 2 activates both. ce_loss_weights is accepted for data-pipeline compatibility but not consumed here: CE is returned per-token (reduction="none") and the trainer may apply weights downstream.
Stage 2 inputs (padded_latent, patchified_vae_latent_shapes, packed_latent_position_ids, packed_vae_token_indexes, packed_timesteps, mse_loss_indexes) are produced by the BAGEL collator when a pack contains t2i/edit samples. padded_latent is the VAE-encoded latent tensor; the recipe must call vae_model.encode on the raw padded_images before forward, because this module does not own the VAE.
- Returns:
dict(ce=Tensor|None, mse=Tensor|None). Either value can be None when the pack has no samples of that loss type. Shape-stable with the BAGEL packed-training path.
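Because CE comes back unreduced and either entry may be None, the caller has to reduce and combine the losses itself. The sketch below shows one possible downstream reduction using plain Python sequences in place of tensors. The `reduce_losses` helper, the weighted-mean reduction, the `mse_weight` factor, and the assumption that MSE is also a flat sequence of values are all hypothetical; the trainer's actual reduction is not described here.

```python
from typing import Dict, Optional, Sequence


def reduce_losses(
    losses: Dict[str, Optional[Sequence[float]]],
    ce_loss_weights: Optional[Sequence[float]] = None,
    mse_weight: float = 1.0,
) -> float:
    """Hypothetical reduction: weighted mean of CE plus scaled mean of MSE."""
    total = 0.0
    ce = losses.get("ce")
    if ce is not None:
        # Fall back to uniform weights when the pipeline supplies none.
        weights = ce_loss_weights if ce_loss_weights is not None else [1.0] * len(ce)
        total += sum(w * loss for w, loss in zip(weights, ce)) / sum(weights)
    mse = losses.get("mse")
    if mse is not None:
        # Assumed here to be a flat sequence of per-element values.
        total += mse_weight * sum(mse) / len(mse)
    return total


# A pack with understanding samples only: mse is None and is skipped.
print(reduce_losses({"ce": [1.0, 3.0], "mse": None}))
```

Checking each entry for None before reducing lets the same training step handle packs that contain only understanding samples, only generation samples, or both.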