nemo_automodel.components.models.bagel.modeling_siglip_navit#

SigLIP + NaViT vision tower for BAGEL.

NaViT-flavored differences from stock HF SigLIP:

  • Variable-resolution packing: forward takes packed_pixel_values (already patchified to shape (total_patches, C*P*P) after the conv->linear conversion) together with cu_seqlens / max_seqlen so that multiple images with different grids can share one forward call.

  • 2D rotary position embedding (RotaryEmbedding2D) applied on the first/second halves of the head dim, replacing the learnt absolute positional embedding table when config.rope=True.

  • Packed flash-attention (flash_attn_varlen_func) in place of the dense SiglipAttention.forward.

  • Conv2d -> Linear patch embedding swap via SiglipVisionEmbeddings.convert_conv2d_to_linear. Upstream calls this after loading the separately materialized ViT; Automodel calls it before loading the released BAGEL checkpoint because that checkpoint already stores the linear layout.
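As an illustration of the packed input described in the first bullet, the following minimal sketch patchifies images of different sizes and builds cu_seqlens and max_seqlen. The helper names patchify and pack_images are hypothetical and not part of this module, and the (C, P, P) flattening order is an assumption chosen to match a Conv2d weight layout.

```python
import torch
import torch.nn.functional as F


def patchify(image: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical helper: (C, H, W) -> (num_patches, C*P*P), row-major over the grid."""
    c, h, w = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    x = image.reshape(c, gh, patch_size, gw, patch_size)
    # Reorder to (gh, gw, C, P, P) so each patch flattens in (C, P, P) order.
    x = x.permute(1, 3, 0, 2, 4)
    return x.reshape(gh * gw, c * patch_size * patch_size)


def pack_images(images, patch_size):
    """Hypothetical helper: concatenate patches and record sequence boundaries."""
    patches = [patchify(img, patch_size) for img in images]
    seqlens = torch.tensor([p.shape[0] for p in patches], dtype=torch.int32)
    # cu_seqlens holds cumulative offsets with a leading zero: [0, n0, n0+n1, ...]
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return torch.cat(patches, dim=0), cu_seqlens, int(seqlens.max())


# Two images with different grids (2x3 and 1x1 patches at P=16) share one pack.
images = [torch.randn(3, 32, 48), torch.randn(3, 16, 16)]
packed_pixel_values, cu_seqlens, max_seqlen = pack_images(images, patch_size=16)
```

Image i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of packed_pixel_values, which is how one forward call can serve several grids at once.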

Class names and parameter attribute names preserve the BAGEL checkpoint layout so that ema.safetensors keys prefixed with vit_model.vision_model. load via the state-dict adapter without key surgery.
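To make the key layout concrete, here is a small sketch that selects the vision-tower entries from a checkpoint-like dict by the vit_model.vision_model. prefix. The function select_vit_keys is hypothetical, the values are placeholder strings instead of tensors, and the non-ViT key is invented for illustration; the actual state-dict adapter is not shown on this page.

```python
PREFIX = "vit_model.vision_model."


def select_vit_keys(state_dict: dict, prefix: str = PREFIX) -> dict:
    """Hypothetical helper: keep keys under `prefix`, expressed relative to it."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Placeholder checkpoint: strings stand in for tensors.
checkpoint = {
    "vit_model.vision_model.embeddings.patch_embedding.weight": "tensor-0",
    "vit_model.vision_model.embeddings.patch_embedding.bias": "tensor-1",
    "language_model.placeholder.weight": "tensor-2",  # illustrative non-ViT key
}

vit_keys = select_vit_keys(checkpoint)
```

The relative keys line up with attribute paths such as embeddings.patch_embedding, which is why the class and parameter names have to stay unchanged.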

Module Contents#

Classes#

SiglipVisionConfig

SigLIP vision config with the NaViT rope flag added.

RotaryEmbedding2D

2D RoPE with separate height/width frequency tables.

SiglipVisionEmbeddings

NaViT patch embedder.

SiglipAttention

SigLIP attention projection container for the NaViT subclass.

SiglipFlashAttention2

Packed-sequence flash-attention variant with optional 2D RoPE.

SiglipMLP

SigLIP vision MLP block used inside the BAGEL NaViT encoder.

SiglipEncoderLayer

SigLIP NaViT encoder layer with packed flash attention.

SiglipEncoder

Stack of SigLIP NaViT encoder layers.

SiglipVisionTransformer

BAGEL SigLIP vision transformer over packed patch embeddings.

SiglipPreTrainedModel

Abstract weight-init base for SigLIP vision modules.

SiglipVisionModel

Top-level vision model. Stored at bagel_model.vit_model per BAGEL’s checkpoint layout.

Functions#

_flash_attn_varlen

rotate_half

apply_rotary_pos_emb

convert_conv2d_to_linear

Module-level helper mirroring BAGEL’s pretrain_unified_navit.py:525-526.

Data#

API#

nemo_automodel.components.models.bagel.modeling_siglip_navit.__all__#

[‘SiglipVisionConfig’, ‘RotaryEmbedding2D’, ‘SiglipVisionEmbeddings’, ‘SiglipAttention’, ‘SiglipFlas…

nemo_automodel.components.models.bagel.modeling_siglip_navit._flash_attn_varlen(*args, **kwargs)#
class nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig(
hidden_size: int = 768,
intermediate_size: int = 3072,
num_hidden_layers: int = 12,
num_attention_heads: int = 12,
num_channels: int = 3,
image_size: int = 224,
patch_size: int = 16,
hidden_act: str = 'gelu_pytorch_tanh',
layer_norm_eps: float = 1e-06,
attention_dropout: float = 0.0,
rope: bool = True,
**kwargs,
)#

Bases: transformers.SiglipVisionConfig

SigLIP vision config with the NaViT rope flag added.

Mirrors upstream modeling/bagel/siglip_navit.py::SiglipVisionConfig.

Initialization

model_type#

‘siglip_vision_model’

class nemo_automodel.components.models.bagel.modeling_siglip_navit.RotaryEmbedding2D(dim: int, max_h: int, max_w: int, base: int = 10000)#

Bases: torch.nn.Module

2D RoPE with separate height/width frequency tables.

Initialization

static _forward_one_side(grid: torch.Tensor, inv_freq: torch.Tensor)#
nemo_automodel.components.models.bagel.modeling_siglip_navit.rotate_half(x: torch.Tensor) -> torch.Tensor#
nemo_automodel.components.models.bagel.modeling_siglip_navit.apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
)#
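The following is a minimal sketch of how a 2D rotary embedding can act on the two halves of the head dimension, with row positions rotating the first half and column positions the second. It is not the module's implementation: the names rotate_half_sketch, rope_tables, and apply_rope_2d are hypothetical, the half-to-axis assignment is an assumption, and the rotate-half convention (negated second half concatenated before the first) is the one commonly used for RoPE.

```python
import torch


def rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    """Map (x1, x2) -> (-x2, x1) along the last dimension."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rope_tables(positions: torch.Tensor, dim: int, base: float = 10000.0):
    """cos/sin tables of shape (N, dim) for integer positions of shape (N,)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = positions.float()[:, None] * inv_freq[None, :]  # (N, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                 # (N, dim)
    return emb.cos(), emb.sin()


def apply_rope_2d(x, rows, cols, base: float = 10000.0):
    """x: (N, heads, head_dim) with head_dim divisible by 4; rows/cols: (N,)."""
    half = x.shape[-1] // 2
    cos_h, sin_h = rope_tables(rows, half, base)
    cos_w, sin_w = rope_tables(cols, half, base)
    x_h, x_w = x[..., :half], x[..., half:]
    # Broadcast the per-token tables over the heads dimension.
    x_h = x_h * cos_h[:, None, :] + rotate_half_sketch(x_h) * sin_h[:, None, :]
    x_w = x_w * cos_w[:, None, :] + rotate_half_sketch(x_w) * sin_w[:, None, :]
    return torch.cat((x_h, x_w), dim=-1)


# A 2x3 patch grid, 2 heads, head_dim 8.
torch.manual_seed(0)
x = torch.randn(6, 2, 8)
rows = torch.tensor([0, 0, 0, 1, 1, 1])
cols = torch.tensor([0, 1, 2, 0, 1, 2])
rotated = apply_rope_2d(x, rows, cols)
```

Because each pair of channels is only rotated, vector norms are unchanged and the patch at grid position (0, 0) passes through unmodified.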
class nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionEmbeddings(
config: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig,
)#

Bases: torch.nn.Module

NaViT patch embedder.

At construction time patch_embedding is an nn.Conv2d. The BAGEL model wrapper calls convert_conv2d_to_linear to swap it for an equivalent nn.Linear so the forward path can consume pre-patchified packed_pixel_values of shape (total_patches, C*P*P). For the released BAGEL-7B-MoT checkpoint, this conversion must happen before the checkpoint is loaded because the checkpoint already stores the linear tensor shape.

Initialization

convert_conv2d_to_linear(
config: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig,
meta: bool = False,
) -> None#

Swaps the Conv2d patch embedding in place for the mathematically equivalent Linear layer.

Called once, before checkpoint load, by the BAGEL model wrapper. After this runs, patch_embedding expects a 2-D (total_patches, C*P*P) input instead of 4-D (N, C, H, W).
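The equivalence can be checked numerically. The sketch below is a standalone demonstration with small, arbitrary sizes, not the method's code: it copies a Conv2d weight (kernel size equal to stride) into a Linear layer and compares the two on the same image, once as a 4-D tensor and once pre-patchified.

```python
import torch
from torch import nn

torch.manual_seed(0)
C, P, D = 3, 4, 8  # channels, patch size, embedding dim (arbitrary demo values)

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

# A Linear layer whose weight is the conv kernel flattened in (C, P, P) order.
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, C * P * P))
    linear.bias.copy_(conv.bias)

image = torch.randn(1, C, 2 * P, 3 * P)  # a 2x3 grid of patches

# Conv path: (1, D, 2, 3) -> (6, D), row-major over the patch grid.
conv_tokens = conv(image).flatten(2).transpose(1, 2).squeeze(0)

# Linear path: patchify to (6, C*P*P) in the same order, then project.
patches = image[0].reshape(C, 2, P, 3, P).permute(1, 3, 0, 2, 4).reshape(6, C * P * P)
linear_tokens = linear(patches)

max_abs_diff = (conv_tokens - linear_tokens).abs().max().item()
```

The two paths agree up to floating-point rounding, provided the patches are flattened in the same (C, P, P) order as the conv kernel.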

forward(
packed_pixel_values: torch.FloatTensor,
packed_flattened_position_ids: torch.LongTensor,
) -> torch.Tensor#
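packed_flattened_position_ids carries one integer per patch. One plausible convention, assumed here and not confirmed by this page, is a row-major index over a fixed maximum grid width, so that each id can be split back into a (row, column) pair. The function name and the max_patches_per_side parameter are hypothetical.

```python
import torch


def flattened_position_ids(grid_h: int, grid_w: int, max_patches_per_side: int) -> torch.Tensor:
    """Hypothetical helper: id = row * max_patches_per_side + col for each patch."""
    rows = torch.arange(grid_h)
    cols = torch.arange(grid_w)
    return (rows[:, None] * max_patches_per_side + cols[None, :]).flatten()


# A 2x3 grid inside a maximum 4x4 grid.
ids = flattened_position_ids(2, 3, max_patches_per_side=4)
rows, cols = ids // 4, ids % 4
```

For a packed batch, the per-image id tensors would be concatenated in the same order as the patches.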
nemo_automodel.components.models.bagel.modeling_siglip_navit.convert_conv2d_to_linear(
vit_model: SiglipVisionModel,
config: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig,
meta: bool = False,
) -> None#

Module-level helper mirroring BAGEL’s pretrain_unified_navit.py:525-526.

Equivalent to vit_model.vision_model.embeddings.convert_conv2d_to_linear(config, meta).

class nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipAttention(
config: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig,
)#

Bases: torch.nn.Module

SigLIP attention projection container for the NaViT subclass.

Keeping the projections directly on this module preserves the expected parameter names (q_proj, k_proj, v_proj, out_proj) for checkpoint loading. Forward is intentionally omitted because the packed NaViT variant is the only runtime path in this tree.

Initialization

class nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipFlashAttention2(
config: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig,
)#

Bases: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipAttention

Packed-sequence flash-attention variant with optional 2D RoPE.

Initialization

forward(
hidden_states: torch.Tensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
cos_h: Optional[torch.Tensor] = None,
sin_h: Optional[torch.Tensor] = None,
cos_w: Optional[torch.Tensor] = None,
sin_w: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor#
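To clarify what cu_seqlens means for attention, the sketch below is a plain PyTorch reference for packed (variable-length) attention: tokens attend only to tokens from the same sequence. It uses a dense block-diagonal mask instead of flash_attn_varlen_func, so it illustrates the semantics only and says nothing about the kernel; the function name is hypothetical.

```python
import math

import torch


def varlen_attention_reference(q, k, v, cu_seqlens):
    """q, k, v: (total_tokens, heads, head_dim); cu_seqlens: (num_seqs + 1,)."""
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
    # Sequence index for every token, e.g. [0, 0, 0, 1, 1] for lengths (3, 2).
    seg = torch.repeat_interleave(
        torch.arange(len(seqlens), device=cu_seqlens.device), seqlens
    )
    mask = seg[:, None] == seg[None, :]  # (total, total), block-diagonal
    scores = torch.einsum("qhd,khd->hqk", q, k) / math.sqrt(q.shape[-1])
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)


torch.manual_seed(0)
q, k, v = (torch.randn(7, 2, 4) for _ in range(3))
cu_seqlens = torch.tensor([0, 6, 7], dtype=torch.int32)  # lengths 6 and 1
out = varlen_attention_reference(q, k, v, cu_seqlens)

# Running the first sequence alone gives the same rows as in the packed call.
alone = varlen_attention_reference(
    q[:6], k[:6], v[:6], torch.tensor([0, 6], dtype=torch.int32)
)
```

The dense mask costs memory quadratic in the total token count, which is the overhead a packed flash-attention kernel avoids.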
class nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipMLP(
config: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig,
)#

Bases: torch.nn.Module

SigLIP vision MLP block used inside the BAGEL NaViT encoder.

Initialization

forward(hidden_states: torch.Tensor) -> torch.Tensor#
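A minimal sketch of a two-layer MLP of this kind, assuming the fc1/fc2 attribute naming of stock Hugging Face SigLIP and the tanh-approximated GELU that the default hidden_act='gelu_pytorch_tanh' refers to. The class name MLPSketch is hypothetical and the sizes are reduced for the demo.

```python
import torch
from torch import nn


class MLPSketch(nn.Module):
    """Hypothetical stand-in: Linear -> GELU(tanh) -> Linear."""

    def __init__(self, hidden_size: int = 768, intermediate_size: int = 3072):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.act = nn.GELU(approximate="tanh")
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(hidden_states)))


# Packed input: 7 patch tokens of width 16, no batch dimension.
mlp = MLPSketch(hidden_size=16, intermediate_size=64)
mlp_out = mlp(torch.randn(7, 16))
```

Because nn.Linear acts on the last dimension only, the same block works on packed (total_patches, hidden_size) input without any reshaping.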
class nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipEncoderLayer(
config: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig,
)#

Bases: torch.nn.Module

SigLIP NaViT encoder layer with packed flash attention.

Initialization

forward(
hidden_states: torch.Tensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
cos_h: Optional[torch.Tensor] = None,
sin_h: Optional[torch.Tensor] = None,
cos_w: Optional[torch.Tensor] = None,
sin_w: Optional[torch.Tensor] = None,
) -> torch.Tensor#
class nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipEncoder(
config: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig,
)#

Bases: torch.nn.Module

Stack of SigLIP NaViT encoder layers.

Initialization

forward(
inputs_embeds: torch.Tensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
cos_h: Optional[torch.Tensor] = None,
sin_h: Optional[torch.Tensor] = None,
cos_w: Optional[torch.Tensor] = None,
sin_w: Optional[torch.Tensor] = None,
) -> torch.Tensor#
class nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionTransformer(
config: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig,
)#

Bases: torch.nn.Module

BAGEL SigLIP vision transformer over packed patch embeddings.

Initialization

forward(
packed_pixel_values: torch.Tensor,
packed_flattened_position_ids: torch.LongTensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
) -> torch.Tensor#
class nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipPreTrainedModel#

Bases: transformers.modeling_utils.PreTrainedModel

Abstract weight-init base for SigLIP vision modules.

config_class#

None

base_model_prefix#

‘siglip’

supports_gradient_checkpointing#

True

_no_split_modules#

[‘SiglipVisionEmbeddings’, ‘SiglipEncoderLayer’]

_supports_flash_attn_2#

True

_supports_sdpa#

False

_init_weights(module: torch.nn.Module) -> None#
class nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionModel(
config: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipVisionConfig,
)#

Bases: nemo_automodel.components.models.bagel.modeling_siglip_navit.SiglipPreTrainedModel

Top-level vision model. Stored at bagel_model.vit_model per BAGEL’s checkpoint layout.

Initialization

config_class#

None

main_input_name#

‘packed_pixel_values’

get_input_embeddings() -> torch.nn.Module#
forward(
packed_pixel_values: torch.Tensor,
packed_flattened_position_ids: torch.LongTensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
) -> torch.Tensor#