bridge.models.ernie_vl.modeling_ernie45_vl.vision_model

Megatron-Core native Vision Transformer for ERNIE 4.5 VL.

This module implements the ERNIE 4.5 VL DFN-style ViT on Megatron-Core TransformerBlock infrastructure instead of the HuggingFace implementation. This makes the attention and MLP layers natively tensor-parallel (TP), which improves distributed-training performance.

Architecture (matching HF DFNRopeVisionTransformerPretrainedModel):
  • PatchEmbed: nn.Linear(C * P * P, embed_dim, bias=False)
  • 2D RoPE: non-interleaved rotate_half with spatial_merge_size reordering
  • 32x TransformerLayer (TE-backed), each consisting of:
      ◦ LayerNorm(1280, eps=1e-6) -> QKV(1280, 3*1280, bias=True) -> Attention -> Proj
      ◦ LayerNorm(1280, eps=1e-6) -> FC1(1280, 5120) -> quick_gelu -> FC2(5120, 1280)
  • Final LayerNorm(1280, eps=1e-6)
  • Per-image packed-sequence attention via PackedSeqParams (thd format)
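
The hyperparameters above determine the derived sizes used throughout this page. The following pure-Python sketch computes them and traces tensor shapes through one layer for a packed sequence of N tokens. It needs no PyTorch. The constant names and the `layer_shapes` helper are hypothetical illustrations. `quick_gelu` uses the standard definition x * sigmoid(1.702 * x), and we assume this is the variant the ViT uses.

```python
import math

# Hyperparameters from the architecture summary above (defaults assumed).
IN_CHANNELS = 3
PATCH_SIZE = 14
EMBED_DIM = 1280
NUM_HEADS = 16
FFN_DIM = 5120

PATCH_DIM = IN_CHANNELS * PATCH_SIZE * PATCH_SIZE  # 588 values per flattened patch
HEAD_DIM = EMBED_DIM // NUM_HEADS                  # 80
ROPE_DIM = HEAD_DIM // 2                           # 40, the `dim` of the rotary table


def quick_gelu(x: float) -> float:
    """quick_gelu(x) = x * sigmoid(1.702 * x), written as x / (1 + exp(-1.702 * x))."""
    return x / (1.0 + math.exp(-1.702 * x))


def layer_shapes(num_patches: int) -> dict:
    """Hypothetical helper: shapes through one ViT layer for `num_patches` packed tokens."""
    n = num_patches
    return {
        "patch_embed_in": (n, PATCH_DIM),
        "hidden": (n, EMBED_DIM),
        "qkv": (n, 3 * EMBED_DIM),
        "q_per_head": (n, NUM_HEADS, HEAD_DIM),
        "fc1_out": (n, FFN_DIM),
        "fc2_out": (n, EMBED_DIM),
    }


print(PATCH_DIM, HEAD_DIM, ROPE_DIM)  # 588 80 40
print(layer_shapes(4)["qkv"])         # (4, 3840)
```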

Key differences from the Qwen3VL MG ViT:
  • PatchEmbed uses nn.Linear (not Conv3d) on pre-flattened patches
  • No positional-embedding interpolation (ERNIE uses pure 2D RoPE)
  • No deepstack feature extraction
  • Non-interleaved RoPE (rotate_half style, rotary_interleaved=False)
  • No PatchMerger (merging is done by the resampler)
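
The rotate_half (non-interleaved) convention pairs dimension i with dimension i + d/2. The interleaved convention instead pairs dimension 2i with dimension 2i + 1. The pure-Python sketch below contrasts the two on a single head vector. The function names are hypothetical. The sketch assumes one rotation angle per dimension pair, which is what a RoPE frequency lookup would provide.

```python
import math
from typing import List, Sequence


def rotate_half(x: Sequence[float]) -> List[float]:
    """Split x into halves (x1, x2) and return (-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def rope_rotate_half(x: Sequence[float], angles: Sequence[float]) -> List[float]:
    """Non-interleaved RoPE: rotate each pair (x[i], x[i + d/2]) by angles[i]."""
    emb = list(angles) + list(angles)  # duplicate so both halves share angle i
    rot = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rot, emb)]


def rope_interleaved(x: Sequence[float], angles: Sequence[float]) -> List[float]:
    """Interleaved RoPE (for comparison): rotate each pair (x[2i], x[2i + 1])."""
    out = list(x)
    for i, a in enumerate(angles):
        x0, x1 = x[2 * i], x[2 * i + 1]
        out[2 * i] = x0 * math.cos(a) - x1 * math.sin(a)
        out[2 * i + 1] = x0 * math.sin(a) + x1 * math.cos(a)
    return out


x = [1.0, 0.0, 0.0, 0.0]
angles = [math.pi / 2, 0.0]
print([round(v, 6) for v in rope_rotate_half(x, angles)])  # value moves to index 2
print([round(v, 6) for v in rope_interleaved(x, angles)])  # value moves to index 1
```

Both variants are norm-preserving rotations. They differ only in which dimensions are paired. Weights trained with one convention therefore cannot be reused under the other without permuting the Q/K dimensions.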

Module Contents

Classes

  • ErnieVisionPatchEmbed – Patch embedding for ERNIE 4.5 VL ViT.
  • ErnieVisionRotaryEmbedding – 1D rotary embedding frequency table for ERNIE ViT 2D RoPE.
  • ErnieVLVisionModel – Megatron-Core native ERNIE 4.5 VL Vision Transformer.

API

class bridge.models.ernie_vl.modeling_ernie45_vl.vision_model.ErnieVisionPatchEmbed(
in_channels: int = 3,
patch_size: int = 14,
embed_dim: int = 1280,
)

Bases: torch.nn.Module

Patch embedding for ERNIE 4.5 VL ViT.

Unlike Qwen3VL, which applies Conv3d to raw image tensors, ERNIE’s processor pre-flattens each patch into a vector of length C * patch_size^2 (3 * 14^2 = 588 with the defaults). Patch embedding is therefore a single bias-free linear projection.

Parameters:
  • in_channels – Number of input channels (default 3).

  • patch_size – Patch size in pixels (default 14).

  • embed_dim – Embedding dimension (default 1280).

forward(hidden_states: torch.Tensor) → torch.Tensor

Project pre-flattened patches to embedding space.

Parameters:

hidden_states – [total_patches, C * patch_size^2] (e.g., [N, 588])

Returns:

[total_patches, embed_dim] (e.g., [N, 1280])
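
The patches arrive pre-flattened, so the forward pass reduces to one matrix-vector product per patch. The pure-Python sketch below shows this for a single patch. The flattening order (channel, then row, then column) is an assumption for illustration, because the processor's exact ordering is not documented here. The helper names are hypothetical.

```python
from typing import List, Sequence


def flatten_patch(patch: Sequence[Sequence[Sequence[float]]]) -> List[float]:
    """Flatten a [C][P][P] patch to length C * P * P (channel-major order assumed)."""
    return [v for channel in patch for row in channel for v in row]


def patch_embed(x: Sequence[float], weight: Sequence[Sequence[float]]) -> List[float]:
    """Bias-free linear projection y = W @ x, with W stored as [out_dim][in_dim]
    (the same layout nn.Linear uses for its weight)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


C, P = 3, 14
patch = [[[1.0] * P for _ in range(P)] for _ in range(C)]
vec = flatten_patch(patch)
print(len(vec))  # 588

# Toy 2-dim "embedding" so the result is easy to check by hand.
weight = [[1.0] * len(vec), [0.5] * len(vec)]
print(patch_embed(vec, weight))  # [588.0, 294.0]
```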

class bridge.models.ernie_vl.modeling_ernie45_vl.vision_model.ErnieVisionRotaryEmbedding(dim: int, theta: float = 10000.0)

Bases: torch.nn.Module

1D rotary embedding frequency table for ERNIE ViT 2D RoPE.

Computes a frequency table of shape [max_seqlen, dim // 2], which is then indexed by 2D (H, W) position IDs to produce per-token RoPE embeddings.

This matches HF’s VisionRotaryEmbedding in the ERNIE 4.5 VL model.

Parameters:
  • dim – Half of the per-head dimension (head_dim // 2). For ERNIE ViT: head_dim = 1280 / 16 = 80, so dim = 40.

  • theta – RoPE base frequency (default 10000.0).

forward(seqlen: int) → torch.Tensor

Compute frequency table for positions 0..seqlen-1.

Parameters:

seqlen – Maximum sequence length to compute frequencies for.

Returns:

Tensor of shape [seqlen, dim // 2] containing the outer product of the position indices and the inverse frequencies.
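
The table can be reproduced with the usual RoPE inverse-frequency formula inv_freq[k] = 1 / theta^(2k / dim). The pure-Python sketch below assumes that formula, which is the common form of HF's VisionRotaryEmbedding. It shows that each row has dim // 2 entries, matching the [max_seqlen, dim // 2] shape stated above. The function name is hypothetical.

```python
from typing import List


def frequency_table(seqlen: int, dim: int, theta: float = 10000.0) -> List[List[float]]:
    """Return table[pos][k] = pos * inv_freq[k], shape [seqlen][dim // 2]."""
    # inv_freq[k] = 1 / theta ** (2k / dim) for k = 0 .. dim // 2 - 1
    inv_freq = [1.0 / theta ** (i / dim) for i in range(0, dim, 2)]
    return [[pos * f for f in inv_freq] for pos in range(seqlen)]


table = frequency_table(seqlen=8, dim=40)  # dim = head_dim // 2 for ERNIE ViT
print(len(table), len(table[0]))  # 8 20
print(table[3][0])                # 3.0 (inv_freq[0] == 1.0)
```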

class bridge.models.ernie_vl.modeling_ernie45_vl.vision_model.ErnieVLVisionModel(
transformer_config: megatron.bridge.models.ernie_vl.modeling_ernie45_vl.vision_transformer_config.ErnieVisionTransformerConfig,
transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
)

Bases: megatron.core.models.common.vision_module.vision_module.VisionModule

Megatron-Core native ERNIE 4.5 VL Vision Transformer.

Implements the DFN-style ViT with 2D RoPE using MCore TransformerBlock for TP-native distributed training.

Architecture:
  1. PatchEmbed (nn.Linear, replicated)
  2. 2D RoPE computation (per-image H/W position lookup with spatial merge reordering)
  3. TransformerBlock (32 ViT layers with TE modules)
  4. Final LayerNorm

Unlike the HF-wrapped version, this implementation:
  • Uses TE-backed attention and MLP layers for TP support
  • Leverages MCore’s PackedSeqParams for per-image variable-length attention
  • Enables activation recomputation through TransformerBlock

Parameters:
  • transformer_config – ErnieVisionTransformerConfig with ViT hyperparameters.

  • transformer_layer_spec – ModuleSpec for each ViT transformer layer.

set_input_tensor(input_tensor: torch.Tensor) → None

Set the input tensor for pipeline parallelism (currently unused by the ViT).

rot_pos_emb(grid_thw: torch.Tensor) → torch.Tensor

Compute 2D RoPE positional embeddings for all tokens.

For each image/video frame, computes (H, W) position IDs with spatial_merge_size reordering (grouping merge_size x merge_size patches together), then looks up the frequency table.

The spatial merge reordering ensures that patches within each spatial merge unit (2x2 by default) are consecutive in the sequence, matching the resampler’s spatial pooling pattern.

Parameters:

grid_thw – [num_images, 3] tensor of (T, H, W) grid dimensions for each image/video.

Returns:

Tensor of shape [total_tokens, head_dim] containing the concatenated cos/sin frequencies for 2D RoPE.
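
The spatial-merge reordering and the table lookup can be sketched in pure Python as follows. The loop order is merge-unit row, merge-unit column, row within unit, column within unit. This order is assumed from the description above and from the Qwen2-VL-style reshape/permute commonly used for this step. The helper names are hypothetical, and `table` stands for the [max_seqlen, dim // 2] frequency table. Concatenating each token's H row and W row gives dim = head_dim // 2 angles. This sketch does not show how this implementation widens them to the head_dim-wide output described above; duplicating them for rotate_half-style application is one common approach.

```python
from typing import List, Sequence, Tuple


def merge_ordered_pos_ids(t: int, h: int, w: int, merge_size: int = 2) -> List[Tuple[int, int]]:
    """(h, w) position IDs for one image/video, with each merge_size x merge_size
    unit of patches made contiguous in the sequence."""
    if h % merge_size or w % merge_size:
        raise ValueError("grid H and W must be divisible by merge_size")
    frame = []
    for bh in range(h // merge_size):          # merge-unit row
        for bw in range(w // merge_size):      # merge-unit column
            for i in range(merge_size):        # row inside the unit
                for j in range(merge_size):    # column inside the unit
                    frame.append((bh * merge_size + i, bw * merge_size + j))
    return frame * t  # every temporal frame reuses the same spatial positions


def rot_pos_angles(grid_thw: Sequence[Tuple[int, int, int]],
                   table: Sequence[Sequence[float]],
                   merge_size: int = 2) -> List[List[float]]:
    """Per-token angles: table[h_pos] followed by table[w_pos]."""
    out = []
    for t, h, w in grid_thw:
        for hp, wp in merge_ordered_pos_ids(t, h, w, merge_size):
            out.append(list(table[hp]) + list(table[wp]))
    return out


print(merge_ordered_pos_ids(1, 4, 4)[:6])
# [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (0, 3)]
```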

build_packed_seq_params(
grid_thw: torch.Tensor,
) → megatron.core.packed_seq_params.PackedSeqParams

Build PackedSeqParams for per-image variable-length attention.

Each frame of each image/video is treated as a separate sequence for attention. Tokens therefore attend only to tokens in the same frame, never across images or frames.

Parameters:

grid_thw – [num_images, 3] tensor of (T, H, W) grid dimensions.

Returns:

PackedSeqParams with cu_seqlens for thd-format attention.
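
The cumulative sequence lengths follow directly from grid_thw, because each of the T frames contributes H * W tokens. The pure-Python sketch below computes them; the helper name is hypothetical. In Megatron-Core these values would typically go into PackedSeqParams fields such as cu_seqlens_q / cu_seqlens_kv and max_seqlen_q / max_seqlen_kv, with qkv_format="thd", as int32 tensors on the device. Treat that mapping as an assumption about this implementation.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple


def build_cu_seqlens(grid_thw: Sequence[Tuple[int, int, int]]) -> Tuple[List[int], int]:
    """Return (cu_seqlens, max_seqlen) with one packed sequence per frame."""
    # One entry of length H * W for each of the T frames of every image/video.
    seqlens = [h * w for t, h, w in grid_thw for _ in range(t)]
    cu_seqlens = [0] + list(accumulate(seqlens))
    max_seqlen = max(seqlens, default=0)
    return cu_seqlens, max_seqlen


# One image with a 4x6 patch grid, plus a 2-frame video with a 2x2 grid.
print(build_cu_seqlens([(1, 4, 6), (2, 2, 2)]))  # ([0, 24, 28, 32], 24)
```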

forward(
hidden_states: torch.Tensor,
grid_thw: torch.Tensor,
inference_params: Optional[megatron.core.InferenceParams] = None,
extra_block_kwargs: Optional[dict] = None,
) → torch.Tensor

Forward pass of the ERNIE ViT.

Parameters:
  • hidden_states – Pre-flattened pixel patches [total_patches, C * patch_size^2].

  • grid_thw – [num_images, 3] tensor of (T, H, W) grid dimensions.

  • inference_params – Inference parameters (currently unused for ViT).

  • extra_block_kwargs – Extra kwargs to pass to TransformerBlock.

Returns:

Vision features of shape [total_patches, hidden_size].
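
The ViT performs no patch merging, so the number of output rows equals the number of input patches, and that number must match grid_thw. A hypothetical pure-Python shape check illustrates this contract. It assumes the default dimensions: a 588-wide input and a 1280-wide output.

```python
from typing import Sequence, Tuple


def expected_forward_shapes(
    grid_thw: Sequence[Tuple[int, int, int]],
    patch_dim: int = 588,
    hidden_size: int = 1280,
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Return (input_shape, output_shape) implied by grid_thw."""
    total_patches = sum(t * h * w for t, h, w in grid_thw)
    return (total_patches, patch_dim), (total_patches, hidden_size)


def check_hidden_states(shape: Tuple[int, int],
                        grid_thw: Sequence[Tuple[int, int, int]],
                        patch_dim: int = 588) -> None:
    """Raise if hidden_states does not match the patch count implied by grid_thw."""
    expected_in, _ = expected_forward_shapes(grid_thw, patch_dim)
    if tuple(shape) != expected_in:
        raise ValueError(f"hidden_states shape {tuple(shape)} != expected {expected_in}")


print(expected_forward_shapes([(1, 4, 6), (2, 2, 2)]))
# ((32, 588), (32, 1280))
```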