nemo_automodel.components.models.bagel.model


Top-level BAGEL model.

The module supports Stage 1 (understanding only, trained with cross-entropy) and Stage 2 (joint understanding and visual generation, which adds a flow-matching MSE loss). The VAE itself stays outside this module tree: the recipe passes VAE-encoded latents to forward.

The top-level module is an nn.Module (not a PreTrainedModel) with HFCheckpointingMixin mixed in. This matches the llava-onevision pattern and avoids the FSDP double-root issue that occurs when the root derives from PreTrainedModel.

Module Contents

Classes

Name | Description
--- | ---
BagelForUnifiedMultimodal | BAGEL mixed-modal LLM wrapper for understanding and optional generation.
BagelModel | Plain container for the three BAGEL submodules.

Functions

Name | Description
--- | ---
_convert_patch_embedding_for_packed_vit | Swap SigLIP patch embedding to Linear for BAGEL packed pixel inputs.
_prepare_config_for_stage | Apply BAGEL stage/checkpoint config fixes before module construction.
_stage_to_int | Normalize a BAGEL training stage value to 1 or 2.

Data

logger

API

class nemo_automodel.components.models.bagel.model.BagelForUnifiedMultimodal(
config: nemo_automodel.components.models.bagel.configuration.BagelConfig
)

Bases: HFCheckpointingMixin, Module

BAGEL mixed-modal LLM wrapper for understanding and optional generation.

With visual_gen=False the model runs the Stage 1 understanding-only path. Stage 2 sets visual_gen=True and additionally uses the MoT *_moe_gen sibling parameters, VAE latents prepared by the recipe, and the flow-matching MSE head.
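The MoT idea behind the *_moe_gen siblings can be sketched as a projection that keeps two parameter copies and picks one per token. This is a hypothetical illustration, not the module's code: the class name MoTLinear, the gen_token_mask argument, and restricting the idea to a single projection are all assumptions; only the *_moe_gen naming convention comes from the text above.

```python
import torch
from torch import nn


class MoTLinear(nn.Module):
    """Hypothetical MoT-style projection: understanding tokens use `proj`,
    generation tokens use the same-shaped sibling `proj_moe_gen`."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_features, out_features)
        # Sibling parameters named after the *_moe_gen convention.
        self.proj_moe_gen = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor, gen_token_mask: torch.Tensor) -> torch.Tensor:
        # x: (num_tokens, in_features); gen_token_mask: (num_tokens,) bool.
        out = x.new_empty(x.shape[0], self.proj.out_features)
        und_mask = ~gen_token_mask
        out[und_mask] = self.proj(x[und_mask])
        out[gen_token_mask] = self.proj_moe_gen(x[gen_token_mask])
        return out


torch.manual_seed(0)
layer = MoTLinear(8, 4)
x = torch.randn(5, 8)
mask = torch.tensor([False, False, True, True, False])
y = layer(x, mask)  # rows 2 and 3 went through proj_moe_gen
```

Both branches see the same hidden size, so the packed sequence stays a single tensor; only the weights differ per token.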

hidden_size = config.text_config.hidden_size
model = BagelModel(config)
num_heads = config.text_config.num_attention_heads
state_dict_adapter
use_moe
nemo_automodel.components.models.bagel.model.BagelForUnifiedMultimodal.forward(
sequence_length: int,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
sample_lens: typing.List[int],
packed_position_ids: torch.LongTensor,
nested_attention_masks: typing.Optional[typing.List[torch.Tensor]] = None,
split_lens: typing.Optional[typing.List[int]] = None,
attn_modes: typing.Optional[typing.List[str]] = None,
packed_vit_tokens: typing.Optional[torch.Tensor] = None,
packed_vit_token_indexes: typing.Optional[torch.LongTensor] = None,
packed_vit_position_ids: typing.Optional[torch.LongTensor] = None,
vit_token_seqlens: typing.Optional[torch.Tensor] = None,
padded_latent: typing.Optional[torch.Tensor] = None,
patchified_vae_latent_shapes: typing.Optional[typing.List[typing.Tuple[int, int]]] = None,
packed_latent_position_ids: typing.Optional[torch.LongTensor] = None,
packed_vae_token_indexes: typing.Optional[torch.LongTensor] = None,
packed_timesteps: typing.Optional[torch.Tensor] = None,
mse_loss_indexes: typing.Optional[torch.Tensor] = None,
ce_loss_indexes: typing.Optional[torch.Tensor] = None,
packed_label_ids: typing.Optional[torch.Tensor] = None,
ce_loss_weights: typing.Optional[torch.Tensor] = None
) -> typing.Dict[str, typing.Optional[torch.Tensor]]

Run the BAGEL mixed-modal forward pass.

Stage 1 (visual_gen=False) skips the flow-matching branch and the MSE computation; Stage 2 activates both. ce_loss_weights is accepted for data-pipeline compatibility but is not consumed here: CE is returned per token (reduction="none"), and the trainer may apply the weights downstream.
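A trainer-side weighting step could look like the sketch below. It assumes the per-token CE tensor and ce_loss_weights are aligned one-to-one and that a normalized weighted mean is the desired reduction; both are assumptions, since this module leaves the reduction to the trainer. The helper weighted_token_mean is hypothetical.

```python
import torch
import torch.nn.functional as F


def weighted_token_mean(per_token_loss: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # Normalized weighted mean: each token contributes in proportion to its weight.
    return (per_token_loss * weights).sum() / weights.sum().clamp_min(1e-12)


torch.manual_seed(0)
logits = torch.randn(6, 10)                 # 6 supervised tokens, vocabulary of 10
labels = torch.randint(0, 10, (6,))
ce = F.cross_entropy(logits, labels, reduction="none")  # per-token CE, shape (6,)

# Hypothetical weights, e.g. halving the contribution of the last four tokens.
ce_loss_weights = torch.tensor([1.0, 1.0, 0.5, 0.5, 0.5, 0.5])
loss = weighted_token_mean(ce, ce_loss_weights)
```

With all weights equal to 1 this reduces to a plain mean, which is why returning unreduced CE keeps both options open.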

Stage 2 inputs (padded_latent, patchified_vae_latent_shapes, packed_latent_position_ids, packed_vae_token_indexes, packed_timesteps, mse_loss_indexes) are produced by the BAGEL collator when a pack contains t2i/edit samples. padded_latent is the VAE-encoded latent tensor: the recipe must call vae_model.encode on the raw padded_images before calling forward, because this module does not own the VAE.
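For intuition, the Stage 2 objective on those clean latents can be sketched as a standard rectified-flow setup. The interpolation direction, the velocity target noise - x0, and the shift formula are assumptions modeled on common flow-matching practice (the shift plays the role of a timestep_shift parameter); the module's actual conventions may differ. model_fn is a hypothetical stand-in for the network.

```python
import torch


def shift_timesteps(t: torch.Tensor, shift: float) -> torch.Tensor:
    # Monotone remap of t in [0, 1]; shift > 1 moves samples toward t = 1 (the noisy end).
    return shift * t / (1 + (shift - 1) * t)


def flow_matching_mse(model_fn, x0: torch.Tensor, t: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
    """Unreduced MSE for a velocity-prediction flow-matching objective.

    x0: clean VAE latent tokens, shape (num_tokens, dim).
    t:  per-token timesteps in [0, 1], shape (num_tokens,).
    """
    t = shift_timesteps(t, shift)[:, None]
    noise = torch.randn_like(x0)
    x_t = (1 - t) * x0 + t * noise  # linear path from data (t=0) to noise (t=1)
    target = noise - x0             # constant velocity along that path
    pred = model_fn(x_t, t)
    return (pred - target) ** 2


torch.manual_seed(0)
x0 = torch.randn(4, 16)
t = torch.rand(4)
mse = flow_matching_mse(lambda x_t, t: torch.zeros_like(x_t), x0, t, shift=3.0)
```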

Returns: Dict[str, Optional[torch.Tensor]]

dict(ce=Tensor|None, mse=Tensor|None); either value, or both, may be None (mse is None whenever the flow-matching branch is skipped, as in Stage 1).
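Because either entry may be None, a caller has to reduce only the terms that are present. The sketch below assumes mean reduction and hypothetical scalar weights ce_weight and mse_weight; a real recipe may instead apply ce_loss_weights or a different normalization. combine_losses is not part of this module.

```python
from typing import Dict, Optional

import torch


def combine_losses(
    out: Dict[str, Optional[torch.Tensor]],
    ce_weight: float = 1.0,
    mse_weight: float = 1.0,
) -> torch.Tensor:
    """Hypothetical trainer helper: reduce and sum whichever loss terms are present."""
    terms = []
    if out.get("ce") is not None:
        terms.append(ce_weight * out["ce"].mean())
    if out.get("mse") is not None:
        terms.append(mse_weight * out["mse"].mean())
    if not terms:
        raise ValueError("forward returned neither a CE nor an MSE loss")
    return torch.stack(terms).sum()
```

For example, a Stage 1 output {"ce": per_token_ce, "mse": None} yields just the mean CE.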

nemo_automodel.components.models.bagel.model.BagelForUnifiedMultimodal.from_pretrained(
pretrained_model_name_or_path: typing.Union[str, os.PathLike],
stage: typing.Union[int, str] = 1,
strict: bool = False,
**kwargs: typing.Any
) -> 'BagelForUnifiedMultimodal'
classmethod

Load a BAGEL-7B-MoT checkpoint directory into this class.

Reads config.json via BagelConfig.from_pretrained, constructs an empty model, and then loads ema.safetensors filtered by BagelStateDictAdapter. Stage 2 VAE weights are loaded by the recipe because the VAE is not owned by this module tree.

Parameters:

pretrained_model_name_or_path
Union[str, os.PathLike]

Directory containing the HF-layout BAGEL checkpoint.

stage
Union[int, str], defaults to 1

1 (UND only) or 2 (UND + GEN). Strings "stage1" / "stage2" are also accepted.

strict
bool, defaults to False

If True, raise on state-dict keys that don’t match the adapter patterns. Defaults to False for compatibility with checkpoint sidecar files.

**kwargs
Any

Forwarded to BagelConfig.from_pretrained.
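The strict flag can be illustrated with a small key filter. The regex patterns, the helper name filter_state_dict, and the example keys are hypothetical; the real BagelStateDictAdapter patterns are not documented here. The sketch keeps keys that match an allowed pattern and, when strict=True, raises on the rest instead of silently dropping them.

```python
import re
from typing import Dict, Iterable


def filter_state_dict(
    state_dict: Dict[str, object],
    patterns: Iterable[str],
    strict: bool = False,
) -> Dict[str, object]:
    compiled = [re.compile(p) for p in patterns]
    kept, unmatched = {}, []
    for key, value in state_dict.items():
        # re.match anchors at the start, so each pattern acts as a key prefix.
        if any(rx.match(key) for rx in compiled):
            kept[key] = value
        else:
            unmatched.append(key)
    if strict and unmatched:
        raise KeyError(f"unexpected checkpoint keys: {unmatched}")
    return kept


# Hypothetical Stage 1 patterns: understanding-side submodules only.
STAGE1_PATTERNS = [r"language_model\.", r"vit_model\.", r"connector\.", r"vit_pos_embed\."]
ckpt = {
    "language_model.model.embed_tokens.weight": 0,
    "connector.fc1.weight": 1,
    "llm2vae.weight": 2,  # generation-side key, not wanted in this Stage 1 sketch
}
```

With strict=False the generation-side key is dropped quietly; with strict=True the same call raises, which is why a permissive default tolerates extra entries in a checkpoint.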

Returns: 'BagelForUnifiedMultimodal'

A fully initialized BagelForUnifiedMultimodal with the checkpoint weights loaded.

nemo_automodel.components.models.bagel.model.BagelForUnifiedMultimodal.get_input_embeddings() -> torch.nn.Module
nemo_automodel.components.models.bagel.model.BagelForUnifiedMultimodal.get_output_embeddings() -> torch.nn.Module
nemo_automodel.components.models.bagel.model.BagelForUnifiedMultimodal.initialize_weights() -> None

Initialize BAGEL weights after AutoModel materializes a model built with from_config.

BagelForUnifiedMultimodal is an nn.Module root, not a HF PreTrainedModel root. AutoModel's meta-device from_config path materializes parameters after sharding and then calls this method. The method delegates the Qwen and SigLIP subtrees to their HF-style initializers, then initializes the BAGEL-only connector and generation modules.
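The delegate-then-fill order can be sketched on a toy tree. Everything here is hypothetical: the ToyBagel class, the pretend_hf_init helper standing in for the HF-style subtree initializers, and the normal(0, 0.02) with zero-bias scheme used for the wrapper-only connector.

```python
import torch
from torch import nn


def pretend_hf_init(module: nn.Module) -> None:
    # Stand-in for a HF-style initializer of a pretrained-architecture subtree.
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)


class ToyBagel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.language_model = nn.Linear(8, 8)  # pretend Qwen2 subtree
        self.vit_model = nn.Linear(8, 8)       # pretend SigLIP subtree
        self.connector = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))

    @torch.no_grad()
    def initialize_weights(self) -> None:
        # 1) Delegate pretrained-architecture subtrees to their own initializers.
        for subtree in (self.language_model, self.vit_model):
            pretend_hf_init(subtree)
        # 2) Initialize modules that exist only in this wrapper.
        for m in self.connector.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                nn.init.zeros_(m.bias)


model = ToyBagel()
model.initialize_weights()
```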

nemo_automodel.components.models.bagel.model.BagelForUnifiedMultimodal.supports_config(
config: typing.Any
) -> bool
classmethod

Return True if this custom class supports the given config.

class nemo_automodel.components.models.bagel.model.BagelModel(
config: nemo_automodel.components.models.bagel.configuration.BagelConfig
)

Bases: Module

Plain container for the three BAGEL submodules.

Attribute names (language_model, vit_model, connector, vit_pos_embed) match the checkpoint layout, so the state-dict adapter can use an identity mapping. This class has no forward logic; it exists so that FSDP and state-dict tooling see the expected tree structure without being confused by the HFCheckpointingMixin root.

When config.visual_gen=True (Stage 2), we additionally attach the generation-side siblings (time_embedder, vae2llm, llm2vae, latent_pos_embed) so the flow-matching head is ready to run. The VAE model itself is NOT owned here; the recipe keeps it separate (frozen, inference-only) and passes already-encoded latents into BagelForUnifiedMultimodal.forward.
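Both points above can be seen on a toy container: PyTorch builds state-dict keys from attribute names, so keys such as connector.weight line up with a checkpoint that uses the same prefixes, and attaching modules only when visual_gen is set changes which keys exist. The class below is a hypothetical stand-in with tiny placeholder layers, not BagelModel itself.

```python
import torch
from torch import nn


class ToyBagelModel(nn.Module):
    """Hypothetical container: attribute names become state-dict key prefixes."""

    def __init__(self, visual_gen: bool) -> None:
        super().__init__()
        self.language_model = nn.Linear(4, 4)
        self.vit_model = nn.Linear(4, 4)
        self.connector = nn.Linear(4, 4)
        self.vit_pos_embed = nn.Parameter(torch.zeros(2, 4))
        if visual_gen:  # Stage 2 only: generation-side siblings
            self.time_embedder = nn.Linear(1, 4)
            self.vae2llm = nn.Linear(4, 4)
            self.llm2vae = nn.Linear(4, 4)
            self.latent_pos_embed = nn.Parameter(torch.zeros(2, 4))


def top_level_names(module: nn.Module) -> set:
    # First component of every state-dict key, e.g. "connector.weight" -> "connector".
    return {key.split(".", 1)[0] for key in module.state_dict()}
```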

connector
language_model = Qwen2ForCausalLM(config.text_config)
latent_downsample = downsample * config.latent_patch_size
latent_patch_size = config.latent_patch_size
latent_pos_embed
llm2vae
max_latent_size = config.max_latent_size
patch_latent_dim = config.latent_patch_size ** 2 * latent_channel
time_embedder
timestep_shift = config.timestep_shift
vae2llm
vit_model = SiglipVisionModel(config.vision_config)
vit_pos_embed
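The derived attributes latent_downsample and patch_latent_dim determine how many generation tokens an image yields and how wide each token is. The sketch below assumes example values (VAE downsample factor 8, 16 latent channels, latent_patch_size 2) chosen only for illustration; the actual values come from the checkpoint's config. The helper latent_token_geometry is hypothetical.

```python
from typing import Tuple


def latent_token_geometry(
    image_hw: Tuple[int, int],
    vae_downsample: int,
    latent_patch_size: int,
    latent_channel: int,
) -> Tuple[int, int, int]:
    """Return (tokens_h, tokens_w, patch_latent_dim) for one image."""
    latent_downsample = vae_downsample * latent_patch_size  # pixels per token side
    patch_latent_dim = latent_patch_size ** 2 * latent_channel
    h, w = image_hw
    if h % latent_downsample or w % latent_downsample:
        raise ValueError("image size must be a multiple of latent_downsample")
    return h // latent_downsample, w // latent_downsample, patch_latent_dim


# Assumed example values, for illustration only.
print(latent_token_geometry((512, 768), vae_downsample=8, latent_patch_size=2, latent_channel=16))
# (32, 48, 64)
```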
nemo_automodel.components.models.bagel.model._convert_patch_embedding_for_packed_vit(
model: 'BagelModel',
config: nemo_automodel.components.models.bagel.configuration.BagelConfig
) -> None

Swap SigLIP patch embedding to Linear for BAGEL packed pixel inputs.
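A convolution whose kernel size equals its stride is a linear map applied to each non-overlapping patch, which is why the swap is possible. The sketch below assumes patches are flattened in (row, column, channel) order, so the Conv2d weight is permuted to match; the function names are hypothetical, and the flattening order used by the real collator should be checked rather than taken from here.

```python
import torch
from torch import nn


def conv_patch_embed_to_linear(conv: nn.Conv2d) -> nn.Linear:
    out_ch, in_ch, p, _ = conv.weight.shape
    linear = nn.Linear(in_ch * p * p, out_ch, bias=conv.bias is not None)
    with torch.no_grad():
        # (out, C, p, p) -> (out, p, p, C) -> (out, p*p*C): channel-last flattening.
        linear.weight.copy_(conv.weight.permute(0, 2, 3, 1).reshape(out_ch, -1))
        if conv.bias is not None:
            linear.bias.copy_(conv.bias)
    return linear


def patchify(image: torch.Tensor, p: int) -> torch.Tensor:
    # image: (C, H, W) -> (num_patches, p*p*C), patches in row-major order.
    c, h, w = image.shape
    x = image.reshape(c, h // p, p, w // p, p)
    x = x.permute(1, 3, 2, 4, 0)  # (H/p, W/p, p, p, C)
    return x.reshape(-1, p * p * c)


torch.manual_seed(0)
p, c, dim = 4, 3, 8
conv = nn.Conv2d(c, dim, kernel_size=p, stride=p)
linear = conv_patch_embed_to_linear(conv)
image = torch.randn(c, 16, 12)

conv_out = conv(image.unsqueeze(0)).flatten(2).squeeze(0).T  # (num_patches, dim)
linear_out = linear(patchify(image, p))                       # same values via Linear
```

The Linear form accepts already-flattened patches from images of different sizes packed into one sequence, which a Conv2d over a fixed image grid cannot do directly.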

nemo_automodel.components.models.bagel.model._prepare_config_for_stage(
config: nemo_automodel.components.models.bagel.configuration.BagelConfig
) -> None

Apply BAGEL stage/checkpoint config fixes before module construction.

AutoModel instantiates custom models as model_cls(config) and lets the common checkpointer load weights later. BAGEL therefore has to apply the same stage-dependent config mutations that its direct from_pretrained path used to perform, and it has to apply them before BagelModel is built.
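Because AutoModel calls model_cls(config) directly, any stage-dependent fix has to be applied to the config object before construction. The sketch below is hypothetical: ToyBagelConfig, its stage and visual_gen fields, and the single rule it applies are illustrative assumptions, not the actual list of mutations.

```python
from dataclasses import dataclass


@dataclass
class ToyBagelConfig:
    stage: int = 1
    visual_gen: bool = True  # e.g. a checkpoint config that ships with generation enabled


def prepare_config_for_stage(config: ToyBagelConfig) -> None:
    # Mutates in place and returns None, mirroring a helper run before model_cls(config).
    config.visual_gen = config.stage == 2


def build_model(model_cls, config: ToyBagelConfig):
    prepare_config_for_stage(config)  # must run first: the constructor reads visual_gen
    return model_cls(config)
```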

nemo_automodel.components.models.bagel.model._stage_to_int(
stage: typing.Union[int, str]
) -> int

Normalize a BAGEL training stage value to 1 or 2.
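A minimal sketch of the normalization, assuming the inputs documented for from_pretrained (1, 2, "stage1", "stage2"). Also accepting the bare strings "1"/"2" and ignoring case are assumptions of this sketch, and stage_to_int is a hypothetical stand-in for the private helper.

```python
from typing import Union


def stage_to_int(stage: Union[int, str]) -> int:
    """Hypothetical re-implementation: map 1, 2, "1", "2", "stage1", "stage2" to 1 or 2."""
    if isinstance(stage, bool):  # bool is a subclass of int; reject it explicitly
        raise ValueError(f"invalid BAGEL stage: {stage!r}")
    if isinstance(stage, str):
        text = stage.strip().lower()
        if text.startswith("stage"):
            text = text[len("stage"):]
        try:
            stage = int(text)
        except ValueError:
            raise ValueError(f"invalid BAGEL stage: {stage!r}") from None
    if stage not in (1, 2):
        raise ValueError(f"invalid BAGEL stage: {stage!r}")
    return stage
```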

nemo_automodel.components.models.bagel.model.logger = logging.getLogger(__name__)