bridge.models.stepfun.modelling_step37.utils

Vision-encoder building blocks for Step3.7.

The module mirrors the upstream HuggingFace vision_encoder.py shipped inside stepfun-ai/step3p7_flash_bf16: a 2D-RoPE Perception-Encoder G/14 ViT with LayerScale-gated residuals. All attribute names match the reference implementation so the safetensors weights can be loaded by name with no renaming required.
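Because attribute names double as state-dict keys, loading the checkpoint reduces to a strict `load_state_dict` call. Below is a minimal sketch with hypothetical stand-in modules (`TinyLayerScale`, `TinyBlock`). The `ln_1` name is an assumption; only `ls_1.gamma` comes from this page. An in-memory dict stands in for the tensors that real code would read with `safetensors.torch.load_file`.

```python
import torch
from torch import nn


class TinyLayerScale(nn.Module):
    """Hypothetical stand-in exposing a `gamma` parameter, like EncoderLayerScale."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))


class TinyBlock(nn.Module):
    """Hypothetical block: attribute names become state-dict key prefixes."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.ln_1 = nn.LayerNorm(dim)  # assumed name, for illustration only
        self.ls_1 = TinyLayerScale(dim)


block = TinyBlock(8)

# Stand-in for tensors read from a checkpoint; keys already match the module tree.
checkpoint = {
    "ln_1.weight": torch.full((8,), 2.0),
    "ln_1.bias": torch.zeros(8),
    "ls_1.gamma": torch.full((8,), 0.1),
}

# strict=True raises on any missing or unexpected key, so success means no renaming was needed.
result = block.load_state_dict(checkpoint, strict=True)
print(result.missing_keys, result.unexpected_keys)  # [] []
```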

Module Contents

Classes

EncoderRope2D

Cacheable 2D rotary positional embedding (matches HF reference).

EncoderLayerScale

Per-channel residual scaling (γ stored as ls_{1,2}.gamma in the HF checkpoint).

EncoderMLP

Feed-forward network used inside each transformer block.

EncoderVisionAttention

Multi-head self-attention with optional 2D RoPE (matches HF reference).

EncoderVisionBlock

A single PE-G/14 Vision Transformer block.

EncoderVisionTransformer

Stack of PE-G/14 encoder blocks (vision_model.transformer.resblocks).

Functions

rotate_half

Rotate the halves of the last dimension (used by RoPE).

apply_rotary_emb

Apply 2D rotary embeddings to queries / keys.

API

bridge.models.stepfun.modelling_step37.utils.rotate_half(x: torch.Tensor) → torch.Tensor

Rotate the halves of the last dimension (used by RoPE).
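Here is a minimal sketch of the common split-halves form. The last dimension is cut into halves `(x1, x2)` and returned as `(-x2, x1)`. This layout is an assumption. Code in the rotary-embedding-torch style rotates interleaved pairs instead, so check the reference before relying on either convention.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into two equal halves.
    x1, x2 = x.chunk(2, dim=-1)
    # (x1, x2) -> (-x2, x1): a 90-degree rotation of each (x1[i], x2[i]) pair.
    return torch.cat((-x2, x1), dim=-1)


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x).tolist())  # [-3.0, -4.0, 1.0, 2.0]
```

Applying the function twice negates the input, as expected for two successive 90-degree rotations.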

bridge.models.stepfun.modelling_step37.utils.apply_rotary_emb(
freqs: torch.Tensor,
t: torch.Tensor,
start_index: int = 0,
scale: float = 1.0,
seq_dim: int = -2,
) → torch.Tensor

Apply 2D rotary embeddings to queries / keys.
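This signature appears to mirror the `apply_rotary_emb` helper from the rotary-embedding-torch package. The sketch below shows what such a helper typically does, assuming `freqs` already holds per-position angles laid out to match a split-halves `rotate_half`. Only the slice `[start_index, start_index + freqs.shape[-1])` of the last dimension is rotated and multiplied by `scale`, and the remaining features pass through unchanged. `seq_dim` is used only to trim `freqs` to the sequence length when `t` is 3-D.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb(
    freqs: torch.Tensor,
    t: torch.Tensor,
    start_index: int = 0,
    scale: float = 1.0,
    seq_dim: int = -2,
) -> torch.Tensor:
    dtype = t.dtype
    if t.ndim == 3:
        # Trim freqs to the sequence length of t.
        seq_len = t.shape[seq_dim]
        freqs = freqs[-seq_len:]
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert rot_dim <= t.shape[-1], "rotary dim exceeds feature dim"
    # Only the middle slice is rotated; left/right slices pass through unchanged.
    t_left, t_mid, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    t_mid = t_mid * freqs.cos() * scale + rotate_half(t_mid) * freqs.sin() * scale
    return torch.cat((t_left, t_mid, t_right), dim=-1).to(dtype)


# Usage: 1-D positions with 'lang'-style inverse frequencies over a 4-dim rotary slice.
torch.manual_seed(0)
seq_len, dim, rot_dim = 5, 6, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, rot_dim // 2)
freqs = torch.cat((angles, angles), dim=-1)  # split-halves layout: (seq_len, rot_dim)
q = torch.randn(2, seq_len, dim)
q_rot = apply_rotary_emb(freqs, q)
```

With `scale=1.0` this is a pure rotation. Vector norms are preserved, and position 0 (all angles zero) is left untouched.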

class bridge.models.stepfun.modelling_step37.utils.EncoderRope2D(
dim: int,
max_grid_height: int,
max_grid_width: int,
use_cls_token: bool = False,
theta: Union[int, float] = 10000,
max_freq: int = 10,
num_freqs: int = 1,
theta_rescale_factor: float = 1.0,
)

Bases: torch.nn.Module

Cacheable 2D rotary positional embedding (matches HF reference).
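The sketch below shows one way to build a cacheable 2D RoPE table, modelled on the Perception Encoder approach as we understand it. Each axis gets a 1-D 'lang' RoPE over half the rotary dimension, and x and y angles are concatenated per patch. The table is computed once for `max_grid_height × max_grid_width` and then sliced to the actual `grid_hw`. Several details are assumptions:

- the `theta_rescale_factor` rule, which follows rotary-embedding-torch;
- the CLS handling, where the CLS token gets zero angles and patches are shifted by one.

`build_2d_freqs` and `slice_2d_freqs` are hypothetical helpers, not methods of this class.

```python
from typing import Tuple

import torch


def lang_inv_freq(base: float, dim: int, theta_rescale_factor: float = 1.0) -> torch.Tensor:
    # 'lang' frequencies: base ** (-2i / dim) for i = 0 .. dim/2 - 1.
    # Rescale rule borrowed from rotary-embedding-torch (assumed, not confirmed for Step3.7).
    base = base * theta_rescale_factor ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))


def build_2d_freqs(dim: int, max_h: int, max_w: int, theta: float = 10000.0) -> torch.Tensor:
    """Angle table of shape (max_h, max_w, dim // 2): x angles first, then y angles."""
    inv_freq = lang_inv_freq(theta, dim // 2)  # each axis uses half the rotary dim
    ys = torch.outer(torch.arange(max_h).float(), inv_freq)  # (max_h, dim // 4)
    xs = torch.outer(torch.arange(max_w).float(), inv_freq)  # (max_w, dim // 4)
    ys = ys[:, None, :].expand(max_h, max_w, -1)
    xs = xs[None, :, :].expand(max_h, max_w, -1)
    return torch.cat((xs, ys), dim=-1)


def slice_2d_freqs(
    table: torch.Tensor, grid_hw: Tuple[int, int], use_cls_token: bool = False
) -> torch.Tensor:
    """Flatten the cached table for an (h, w) grid into (h * w [+ 1], dim) split-halves freqs."""
    h, w = grid_hw
    if use_cls_token:
        # Assumed: patches start at (1, 1), so the table needs one extra row and column;
        # the CLS token is prepended with zero angles (an identity rotation).
        angles = table[1 : h + 1, 1 : w + 1].reshape(h * w, -1)
        angles = torch.cat((torch.zeros(1, angles.shape[-1]), angles), dim=0)
    else:
        angles = table[:h, :w].reshape(h * w, -1)  # row-major patch order
    return torch.cat((angles, angles), dim=-1)  # duplicate to match split-halves rotate_half


table = build_2d_freqs(dim=16, max_h=4, max_w=6)
freqs = slice_2d_freqs(table, (2, 3))  # shape (6, 16)
```

Caching the full table means that a new image size costs only a slice and a reshape, never a recomputation. The resulting `(h * w, dim)` tensor would then be applied to both queries and keys.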

Initialization

_compute_inv_freq(
base: Union[int, float],
dim: int,
) → torch.Tensor
_compute_freqs(t: torch.Tensor, inv_freq: torch.Tensor)
_compute_2d_freqs() → torch.Tensor
forward(
q: torch.Tensor,
k: torch.Tensor,
grid_hw: Tuple[int, int],
)
class bridge.models.stepfun.modelling_step37.utils.EncoderLayerScale(dim: int, init_values: float)

Bases: torch.nn.Module

Per-channel residual scaling (γ stored as ls_{1,2}.gamma in the HF checkpoint).
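Here is a sketch of per-channel LayerScale. It assumes `gamma` is initialised to `init_values` for every channel, which is the common LayerScale recipe. Only the parameter name `gamma` is taken from this page, and `LayerScaleSketch` is a hypothetical name.

```python
import torch
from torch import nn


class LayerScaleSketch(nn.Module):
    def __init__(self, dim: int, init_values: float) -> None:
        super().__init__()
        # One learnable scale per channel (ls_1.gamma / ls_2.gamma in the checkpoint).
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Broadcasts over all leading (batch, sequence) dimensions.
        return hidden_states * self.gamma


ls = LayerScaleSketch(4, init_values=0.1)
x = torch.ones(2, 3, 4)
y = ls(x)
```

A small `init_values` keeps each residual branch's contribution small at initialisation, so every block starts close to an identity mapping.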

Initialization

forward(hidden_states: torch.Tensor) → torch.Tensor
class bridge.models.stepfun.modelling_step37.utils.EncoderMLP(hidden_size: int, intermediate_size: int, hidden_act: str)

Bases: torch.nn.Module

Feed-forward network used inside each transformer block.
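The sketch below shows the two-layer feed-forward network, assuming the usual Linear, activation, Linear shape. Several pieces are hypothetical: the submodule names `c_fc`/`c_proj` (open_clip style), the `MLPSketch` class name, and the small string-to-activation table. The real class may resolve `hidden_act` differently.

```python
import torch
from torch import nn

# Hypothetical mapping from hidden_act strings to activation modules.
ACTIVATIONS = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}


class MLPSketch(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str) -> None:
        super().__init__()
        self.c_fc = nn.Linear(hidden_size, intermediate_size)  # expand
        self.act = ACTIVATIONS[hidden_act]()
        self.c_proj = nn.Linear(intermediate_size, hidden_size)  # project back

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.c_proj(self.act(self.c_fc(hidden_states)))


mlp = MLPSketch(hidden_size=8, intermediate_size=32, hidden_act="gelu")
out = mlp(torch.randn(2, 5, 8))
print(out.shape)  # torch.Size([2, 5, 8])
```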

Initialization

forward(hidden_states: torch.Tensor) → torch.Tensor
class bridge.models.stepfun.modelling_step37.utils.EncoderVisionAttention(
hidden_size: int,
num_heads: int,
max_grid_height: int,
max_grid_width: int,
use_cls_token: bool = False,
use_rope2d: bool = True,
rope_theta: Union[int, float] = 10000,
rope_max_freq: int = 10,
rope_num_freqs: int = 1,
rope_theta_rescale_factor: float = 1.0,
rope_freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang',
)

Bases: torch.nn.Module

Multi-head self-attention with optional 2D RoPE (matches HF reference).

Initialization
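This is a sketch of an attention forward pass that applies RoPE to queries and keys, but not values, before scaled dot-product attention. Some parts are assumptions for illustration:

- the fused `qkv` projection and the `proj` output layer;
- the `AttentionSketch` name;
- passing precomputed split-halves `freqs` directly, whereas the real class derives its RoPE table from `grid_hw`.

```python
import torch
import torch.nn.functional as F
from torch import nn


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


class AttentionSketch(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int) -> None:
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)  # hypothetical fused projection
        self.proj = nn.Linear(hidden_size, hidden_size)  # hypothetical output projection

    def forward(self, hidden_states: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        b, n, c = hidden_states.shape
        qkv = self.qkv(hidden_states).reshape(b, n, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (b, heads, n, head_dim)
        # freqs holds (n, head_dim) angles, broadcast over batch and heads.
        cos, sin = freqs.cos(), freqs.sin()
        q = q * cos + rotate_half(q) * sin
        k = k * cos + rotate_half(k) * sin
        out = F.scaled_dot_product_attention(q, k, v)  # (b, heads, n, head_dim)
        return self.proj(out.transpose(1, 2).reshape(b, n, c))


torch.manual_seed(0)
attn = AttentionSketch(hidden_size=16, num_heads=2)
x = torch.randn(1, 6, 16)
freqs = torch.zeros(6, 8)  # zero angles make RoPE the identity
out = attn(x, freqs)
print(out.shape)  # torch.Size([1, 6, 16])
```

Rotating only queries and keys makes each attention logit depend on the relative offset between patches. The values themselves carry no positional rotation.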

forward(
hidden_states: torch.Tensor,
grid_hw: Tuple[int, int],
) → torch.Tensor
class bridge.models.stepfun.modelling_step37.utils.EncoderVisionBlock(
hidden_size: int,
num_heads: int,
mlp_ratio: float,
hidden_act: str,
layer_norm_eps: float,
ls_init_value: Optional[float] = None,
max_grid_height: Optional[int] = None,
max_grid_width: Optional[int] = None,
use_cls_token: bool = False,
use_rope2d: bool = True,
rope_kwargs: Optional[dict] = None,
)

Bases: torch.nn.Module

A single PE-G/14 Vision Transformer block.
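The sketch below shows the LayerScale-gated pre-norm residual layout mentioned in the module overview. It assumes the PE/open_clip ordering: first `x + ls_1(attn(ln_1(x)))`, then `x + ls_2(mlp(ln_2(x)))`. It also assumes that LayerScale becomes the identity when `ls_init_value` is None and that the hidden width is `int(hidden_size * mlp_ratio)`. `nn.MultiheadAttention` stands in for the RoPE attention, and all submodule names are assumptions.

```python
from typing import Optional

import torch
from torch import nn


class LayerScale(nn.Module):
    def __init__(self, dim: int, init_values: float) -> None:
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma


class BlockSketch(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float,
                 layer_norm_eps: float, ls_init_value: Optional[float] = None) -> None:
        super().__init__()
        self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # Stand-in attention without RoPE, for illustration only.
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        inner = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, inner), nn.GELU(), nn.Linear(inner, hidden_size))
        # Assumed: no LayerScale when ls_init_value is None.
        use_ls = ls_init_value is not None
        self.ls_1 = LayerScale(hidden_size, ls_init_value) if use_ls else nn.Identity()
        self.ls_2 = LayerScale(hidden_size, ls_init_value) if use_ls else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln_1(x)
        x = x + self.ls_1(self.attn(h, h, h, need_weights=False)[0])  # gated attention branch
        x = x + self.ls_2(self.mlp(self.ln_2(x)))  # gated MLP branch
        return x


torch.manual_seed(0)
x = torch.randn(2, 5, 16)
gated_off = BlockSketch(16, 2, mlp_ratio=4.0, layer_norm_eps=1e-5, ls_init_value=0.0)
y = gated_off(x)
```

With `ls_init_value=0.0` both branches are multiplied by zero, so the block returns its input unchanged. This shows how LayerScale controls each residual branch.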

Initialization

forward(
hidden_states: torch.Tensor,
grid_hw: Tuple[int, int],
) → torch.Tensor
class bridge.models.stepfun.modelling_step37.utils.EncoderVisionTransformer(
embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
hidden_act: str,
layer_norm_eps: float,
ls_init_value: Optional[float] = None,
max_grid_height: Optional[int] = None,
max_grid_width: Optional[int] = None,
use_cls_token: bool = False,
use_rope2d: bool = True,
rope_kwargs: Optional[dict] = None,
)

Bases: torch.nn.Module

Stack of PE-G/14 encoder blocks (vision_model.transformer.resblocks).
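This sketch of the stack makes two assumptions. First, the blocks are held in an `nn.ModuleList` named `resblocks`, so checkpoint keys take the form `vision_model.transformer.resblocks.<i>.…`. Second, `grid_hw` is simply forwarded to every block. `PassThroughBlock` and `TransformerSketch` are hypothetical placeholders.

```python
from typing import Tuple

import torch
from torch import nn


class PassThroughBlock(nn.Module):
    """Hypothetical placeholder block: one linear residual update."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, grid_hw: Tuple[int, int]) -> torch.Tensor:
        return x + self.fc(x)


class TransformerSketch(nn.Module):
    def __init__(self, embed_dim: int, depth: int) -> None:
        super().__init__()
        # Index i becomes the key segment "resblocks.<i>." in the state dict.
        self.resblocks = nn.ModuleList([PassThroughBlock(embed_dim) for _ in range(depth)])

    def forward(self, hidden_states: torch.Tensor, grid_hw: Tuple[int, int]) -> torch.Tensor:
        for block in self.resblocks:
            hidden_states = block(hidden_states, grid_hw)
        return hidden_states


# Nest it the way the checkpoint path suggests: vision_model.transformer.resblocks.
model = nn.Module()
model.vision_model = nn.Module()
model.vision_model.transformer = TransformerSketch(embed_dim=8, depth=2)
keys = list(model.state_dict())
out = model.vision_model.transformer(torch.randn(1, 4, 8), grid_hw=(2, 2))
```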

Initialization

forward(
hidden_states: torch.Tensor,
grid_hw: Tuple[int, int],
) → torch.Tensor