bridge.models.stepfun.modelling_step37.utils
Vision-encoder building blocks for Step3.7.
The module mirrors the upstream HuggingFace vision_encoder.py shipped
inside stepfun-ai/step3p7_flash_bf16: a Perception Encoder G/14 ViT with
2D RoPE and LayerScale-gated residuals. All attribute names match the
reference implementation, so the safetensors weights load by name without
any key renaming.
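Module Contents
Classes
| Class | Description |
|---|---|
| EncoderRope2D | Cacheable 2D rotary positional embedding (matches HF reference). |
| EncoderLayerScale | Per-channel residual scaling (γ stored as ls_{1,2}.gamma in the HF checkpoint). |
| EncoderMLP | Feed-forward network used inside each transformer block. |
| EncoderVisionAttention | Multi-head self-attention with optional 2D RoPE (matches HF reference). |
| EncoderVisionBlock | A single PE-G/14 Vision Transformer block. |
| EncoderVisionTransformer | Stack of PE-G/14 encoder blocks (vision_model.transformer.resblocks). |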
Functions
| Function | Description |
|---|---|
| rotate_half | Rotate last dimension halves (used by RoPE). |
| apply_rotary_emb | Apply 2D rotary embeddings to queries / keys. |
API
- bridge.models.stepfun.modelling_step37.utils.rotate_half(x: torch.Tensor) → torch.Tensor
Rotate last dimension halves (used by RoPE).
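A minimal sketch of what rotate_half computes, assuming the interleaved-pair convention of lucidrains' rotary-embedding-torch, whose apply_rotary_emb(freqs, t, start_index, scale, seq_dim) signature matches the one documented in this module. Under that assumption each adjacent channel pair (x1, x2) becomes (-x2, x1); the reference code may instead split the last dimension into two contiguous halves, so treat the layout as unverified.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map each adjacent channel pair (x1, x2) to (-x2, x1).

    Assumes the interleaved-pair layout; a split-halves variant would
    instead pair channel i with channel i + d/2.
    """
    # (..., d) -> (..., d/2, 2): every channel pair gets its own trailing axis
    x1, x2 = x.reshape(*x.shape[:-1], -1, 2).unbind(dim=-1)
    # Rotate each pair by 90 degrees, then flatten back to (..., d)
    return torch.stack((-x2, x1), dim=-1).reshape(x.shape)


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x).tolist())  # [-2.0, 1.0, -4.0, 3.0]
```

Applying the rotation twice negates the input, which is the 90-degree-rotation property the RoPE update relies on.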
- bridge.models.stepfun.modelling_step37.utils.apply_rotary_emb(freqs: torch.Tensor, t: torch.Tensor, start_index: int = 0, scale: float = 1.0, seq_dim: int = -2)
Apply 2D rotary embeddings to queries / keys.
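The sketch below shows the standard RoPE update, t * cos(freqs) + rotate_half(t) * sin(freqs), applied only to the channel window [start_index, start_index + freqs.shape[-1]), with the result multiplied by scale. It is modelled on the lucidrains function that shares these parameter names and is not the module's verified code; the bundled interleaved rotate_half is an assumed layout, and the toy 1D frequencies in the usage lines are illustrative only.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Interleaved-pair rotation (x1, x2) -> (-x2, x1); an assumed layout.
    x1, x2 = x.reshape(*x.shape[:-1], -1, 2).unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).reshape(x.shape)


def apply_rotary_emb(
    freqs: torch.Tensor,
    t: torch.Tensor,
    start_index: int = 0,
    scale: float = 1.0,
    seq_dim: int = -2,
) -> torch.Tensor:
    """Rotate channels [start_index, start_index + freqs.shape[-1]) of t."""
    if t.ndim == 3:
        # For 3-D t, trim freqs to the sequence length found along seq_dim
        seq_len = t.shape[seq_dim]
        freqs = freqs[-seq_len:]
    freqs = freqs.to(t)  # match dtype and device of t
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert rot_dim <= t.shape[-1], "freqs cover more channels than t has"
    t_left = t[..., :start_index]
    t_mid = t[..., start_index:end_index]
    t_right = t[..., end_index:]
    # RoPE update on the window; channels outside it pass through untouched
    t_mid = (t_mid * freqs.cos() + rotate_half(t_mid) * freqs.sin()) * scale
    return torch.cat((t_left, t_mid, t_right), dim=-1)


# Toy 1D positions 0..4 with 2 frequency pairs (4 rotated channels)
pos = torch.arange(5).float()
inv_freq = 1.0 / (10000 ** (torch.arange(0, 4, 2).float() / 4))
freqs = torch.outer(pos, inv_freq).repeat_interleave(2, dim=-1)  # (5, 4)
q = torch.randn(1, 2, 5, 8)  # (batch, heads, seq, head_dim)
q_rot = apply_rotary_emb(freqs, q)  # channels 4..7 are left unchanged
```

Because both channels of a pair share one angle, the update is a pure 2D rotation: it changes direction but preserves the norm of every rotated pair.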
- class bridge.models.stepfun.modelling_step37.utils.EncoderRope2D(
    dim: int,
    max_grid_height: int,
    max_grid_width: int,
    use_cls_token: bool = False,
    theta: Union[int, float] = 10000,
    max_freq: int = 10,
    num_freqs: int = 1,
    theta_rescale_factor: float = 1.0,
  )
Bases: torch.nn.Module
Cacheable 2D rotary positional embedding (matches HF reference).
- _compute_inv_freq(base: Union[int, float], dim: int)
- _compute_freqs(t: torch.Tensor, inv_freq: torch.Tensor)
- _compute_2d_freqs() → torch.Tensor
- forward(q: torch.Tensor, k: torch.Tensor, grid_hw: Tuple[int, int])
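A hedged sketch of how a 2D (axial) RoPE frequency table can be built for a row-major patch grid. It assumes the layout used by Perception Encoder-style Rope2D: half of the rotary channels carry angles for the column index, the other half for the row index, each half using ordinary theta^(-2i/d) frequencies, and a leading CLS token (when use_cls_token is set) receives zero angles so it is not rotated. The helper names axis_freqs and compute_2d_freqs are hypothetical, and the column-then-row ordering is an assumption.

```python
import torch


def axis_freqs(positions: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # 1D RoPE angles: one frequency theta^(-2i/dim) per channel pair
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(positions.float(), inv_freq)  # (n, dim/2)
    # Duplicate each angle for both channels of its (interleaved) pair
    return angles.repeat_interleave(2, dim=-1)  # (n, dim)


def compute_2d_freqs(
    dim: int, grid_h: int, grid_w: int, use_cls_token: bool = False, theta: float = 10000.0
) -> torch.Tensor:
    """Angles of shape (grid_h * grid_w [+ 1], dim) for a row-major patch grid."""
    half = dim // 2
    fy = axis_freqs(torch.arange(grid_h), half, theta)  # (H, half)
    fx = axis_freqs(torch.arange(grid_w), half, theta)  # (W, half)
    # Broadcast to the full grid: row angles vary along H, column angles along W
    fy = fy[:, None, :].expand(grid_h, grid_w, half)
    fx = fx[None, :, :].expand(grid_h, grid_w, half)
    freqs = torch.cat((fx, fy), dim=-1).reshape(grid_h * grid_w, dim)
    if use_cls_token:
        # The leading CLS token gets zero angles, i.e. it is left unrotated
        freqs = torch.cat((torch.zeros(1, dim), freqs), dim=0)
    return freqs


freqs = compute_2d_freqs(dim=8, grid_h=2, grid_w=3, use_cls_token=True)
print(freqs.shape)  # torch.Size([7, 8])
```

The max_grid_height and max_grid_width arguments, together with the "Cacheable" description, suggest the class precomputes such a table once for the largest grid; how it then selects entries for a smaller grid_hw is not documented here.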
- class bridge.models.stepfun.modelling_step37.utils.EncoderLayerScale(dim: int, init_values: float)
Bases: torch.nn.Module
Per-channel residual scaling (γ stored as ls_{1,2}.gamma in the HF checkpoint).
- forward(hidden_states: torch.Tensor) → torch.Tensor
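LayerScale multiplies each channel of a residual branch by a learned factor γ before it is added back. A minimal sketch, assuming γ is a single length-dim parameter initialised to init_values; the class name LayerScaleSketch is hypothetical, while the attribute name gamma follows the ls_{1,2}.gamma checkpoint keys quoted above.

```python
import torch
from torch import nn


class LayerScaleSketch(nn.Module):
    """Hypothetical stand-in for EncoderLayerScale: y = gamma * x per channel."""

    def __init__(self, dim: int, init_values: float):
        super().__init__()
        # Named `gamma` so the key reads e.g. "ls_1.gamma" once the module
        # is registered as `ls_1` on its parent block
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Broadcasts over all leading (batch, sequence) dimensions
        return hidden_states * self.gamma


ls = LayerScaleSketch(dim=4, init_values=0.1)
y = ls(torch.ones(2, 3, 4))
print(list(ls.state_dict()))  # ['gamma']
```

A small init_values keeps each residual branch close to zero early in training, so a deep stack starts near the identity.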
- class bridge.models.stepfun.modelling_step37.utils.EncoderMLP(hidden_size: int, intermediate_size: int, hidden_act: str)
Bases: torch.nn.Module
Feed-forward network used inside each transformer block.
- forward(hidden_states: torch.Tensor) → torch.Tensor
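EncoderMLP is documented only as the feed-forward network of each block. The sketch assumes the common Linear, activation, Linear shape; hidden_act is resolved through a small hypothetical table (the real module may use a different lookup, for example transformers' ACT2FN), and fc1, fc2 and MLPSketch are illustrative names, not necessarily the checkpoint's.

```python
import torch
from torch import nn

# Hypothetical activation table; the reference may resolve names differently
_ACTIVATIONS = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}


class MLPSketch(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear."""

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
        super().__init__()
        # fc1 / fc2 are illustrative attribute names
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.act = _ACTIVATIONS[hidden_act]()
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand to intermediate_size, apply the nonlinearity, project back
        return self.fc2(self.act(self.fc1(hidden_states)))


mlp = MLPSketch(hidden_size=16, intermediate_size=64, hidden_act="gelu")
y = mlp(torch.randn(2, 5, 16))
print(y.shape)  # torch.Size([2, 5, 16])
```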
- class bridge.models.stepfun.modelling_step37.utils.EncoderVisionAttention(
    hidden_size: int,
    num_heads: int,
    max_grid_height: int,
    max_grid_width: int,
    use_cls_token: bool = False,
    use_rope2d: bool = True,
    rope_theta: Union[int, float] = 10000,
    rope_max_freq: int = 10,
    rope_num_freqs: int = 1,
    rope_theta_rescale_factor: float = 1.0,
    rope_freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang',
  )
Bases: torch.nn.Module
Multi-head self-attention with optional 2D RoPE (matches HF reference).
- forward(hidden_states: torch.Tensor, grid_hw: Tuple[int, int])
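A compact sketch of multi-head self-attention with 2D RoPE, to show where grid_hw enters: per-head queries and keys (not values) are rotated by angles derived from each token's (row, column) position before scaled dot-product attention. Assumptions: no CLS token, one token per patch in row-major order, 'lang'-style frequencies only, interleaved channel pairs, column angles in the first half of each head's channels and row angles in the second. AttentionSketch, _grid_angles, qkv and proj are hypothetical names, not the reference implementation's.

```python
import torch
import torch.nn.functional as F
from torch import nn


def _rotate_pairs(x: torch.Tensor) -> torch.Tensor:
    # Interleaved-pair rotation (x1, x2) -> (-x2, x1); an assumed layout
    x1, x2 = x.reshape(*x.shape[:-1], -1, 2).unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).reshape(x.shape)


def _grid_angles(head_dim: int, h: int, w: int, theta: float = 10000.0) -> torch.Tensor:
    # Axial 2D angles, shape (h * w, head_dim): column half, then row half
    half = head_dim // 2
    inv_freq = 1.0 / (theta ** (torch.arange(0, half, 2).float() / half))
    rows, cols = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    ax = torch.outer(cols.flatten().float(), inv_freq).repeat_interleave(2, dim=-1)
    ay = torch.outer(rows.flatten().float(), inv_freq).repeat_interleave(2, dim=-1)
    return torch.cat((ax, ay), dim=-1)


class AttentionSketch(nn.Module):
    """Hypothetical stand-in for EncoderVisionAttention (no CLS token)."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        assert self.head_dim % 4 == 0, "2D RoPE needs head_dim divisible by 4"
        # Fused QKV and output projections; attribute names are illustrative
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, grid_hw: tuple) -> torch.Tensor:
        b, n, c = hidden_states.shape
        h, w = grid_hw
        assert n == h * w, "expects exactly one token per patch"
        # (b, n, 3c) -> (3, b, heads, n, head_dim)
        qkv = self.qkv(hidden_states).reshape(b, n, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        # Rotate queries and keys only; values carry no positional rotation
        angles = _grid_angles(self.head_dim, h, w).to(hidden_states)
        cos, sin = angles.cos(), angles.sin()
        q = q * cos + _rotate_pairs(q) * sin
        k = k * cos + _rotate_pairs(k) * sin
        out = F.scaled_dot_product_attention(q, k, v)  # (b, heads, n, head_dim)
        return self.proj(out.transpose(1, 2).reshape(b, n, c))


torch.manual_seed(0)
attn = AttentionSketch(hidden_size=64, num_heads=4)
x = torch.randn(2, 6, 64)  # 2 images, a 2x3 patch grid each
y = attn(x, (2, 3))  # same shape as x
```

Rotating q and k (rather than adding a position embedding to the input) makes each attention score depend on the relative row and column offset between two patches, which is what lets the encoder handle varying grid_hw.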
- class bridge.models.stepfun.modelling_step37.utils.EncoderVisionBlock(
    hidden_size: int,
    num_heads: int,
    mlp_ratio: float,
    hidden_act: str,
    layer_norm_eps: float,
    ls_init_value: Optional[float] = None,
    max_grid_height: Optional[int] = None,
    max_grid_width: Optional[int] = None,
    use_cls_token: bool = False,
    use_rope2d: bool = True,
    rope_kwargs: Optional[dict] = None,
  )
Bases: torch.nn.Module
A single PE-G/14 Vision Transformer block.
- forward(hidden_states: torch.Tensor, grid_hw: Tuple[int, int])
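A sketch of the LayerScale-gated, pre-norm residual wiring a ViT block of this kind typically uses: x + ls_1(attn(ln_1(x))), then x + ls_2(mlp(ln_2(x))). Only the ls_1/ls_2 names are taken from this page; ln_1, ln_2, BlockSketch, the GELU MLP, int(hidden_size * mlp_ratio) as the hidden width, and "no ls_init_value means no scaling" are assumptions. To keep the focus on the residual path, torch's nn.MultiheadAttention stands in for the 2D RoPE attention, so grid_hw is accepted but ignored.

```python
import torch
from torch import nn


class LayerScale(nn.Module):
    def __init__(self, dim: int, init_values: float):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma


class BlockSketch(nn.Module):
    """Hypothetical pre-norm block with LayerScale-gated residual branches."""

    def __init__(self, hidden_size, num_heads, mlp_ratio, layer_norm_eps, ls_init_value=None):
        super().__init__()
        self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # Stand-in attention WITHOUT 2D RoPE, used only to show the wiring
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        inner = int(hidden_size * mlp_ratio)  # assumed meaning of mlp_ratio
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, inner), nn.GELU(), nn.Linear(inner, hidden_size)
        )
        # Assumption: without ls_init_value the residual branches are unscaled
        if ls_init_value is not None:
            self.ls_1 = LayerScale(hidden_size, ls_init_value)
            self.ls_2 = LayerScale(hidden_size, ls_init_value)
        else:
            self.ls_1 = nn.Identity()
            self.ls_2 = nn.Identity()

    def forward(self, hidden_states: torch.Tensor, grid_hw) -> torch.Tensor:
        # grid_hw is ignored here; the real block presumably passes it to attention
        h = self.ln_1(hidden_states)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        hidden_states = hidden_states + self.ls_1(attn_out)
        return hidden_states + self.ls_2(self.mlp(self.ln_2(hidden_states)))


block = BlockSketch(32, 4, mlp_ratio=4.0, layer_norm_eps=1e-5, ls_init_value=0.1)
x = torch.randn(2, 6, 32)
print(block(x, (2, 3)).shape)  # torch.Size([2, 6, 32])
print(sorted(k for k in block.state_dict() if k.startswith("ls_")))
# ['ls_1.gamma', 'ls_2.gamma']
```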
- class bridge.models.stepfun.modelling_step37.utils.EncoderVisionTransformer(
    embed_dim: int,
    depth: int,
    num_heads: int,
    mlp_ratio: float,
    hidden_act: str,
    layer_norm_eps: float,
    ls_init_value: Optional[float] = None,
    max_grid_height: Optional[int] = None,
    max_grid_width: Optional[int] = None,
    use_cls_token: bool = False,
    use_rope2d: bool = True,
    rope_kwargs: Optional[dict] = None,
  )
Bases: torch.nn.Module
Stack of PE-G/14 encoder blocks (vision_model.transformer.resblocks).
- forward(hidden_states: torch.Tensor, grid_hw: Tuple[int, int])
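A minimal sketch of the stacking pattern and why the attribute name matters: if the blocks live in an nn.ModuleList called resblocks, and the encoder is registered as vision_model.transformer, the parameter keys come out as vision_model.transformer.resblocks.<i>.<...>, matching the path quoted above with no renaming. TinyBlock and TransformerSketch are hypothetical placeholders; the documented blocks are full attention/MLP layers.

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    """Hypothetical placeholder for EncoderVisionBlock (pre-norm residual MLP)."""

    def __init__(self, dim: int):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor, grid_hw) -> torch.Tensor:
        # A real block would use grid_hw for 2D RoPE; this placeholder ignores it
        return hidden_states + self.fc(self.ln(hidden_states))


class TransformerSketch(nn.Module):
    def __init__(self, embed_dim: int, depth: int):
        super().__init__()
        # ModuleList named `resblocks` so keys read resblocks.<i>.<param>
        self.resblocks = nn.ModuleList([TinyBlock(embed_dim) for _ in range(depth)])

    def forward(self, hidden_states: torch.Tensor, grid_hw) -> torch.Tensor:
        # Every block sees the same grid_hw; blocks run strictly in order
        for block in self.resblocks:
            hidden_states = block(hidden_states, grid_hw)
        return hidden_states


root = nn.Module()
root.vision_model = nn.Module()
root.vision_model.transformer = TransformerSketch(embed_dim=8, depth=2)
print(next(iter(root.state_dict())))
# vision_model.transformer.resblocks.0.ln.weight
```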