bridge.models.stepfun.step35_provider

Step-3.5-Flash Model Provider for Megatron-Core.

Step-3.5-Flash uses a hybrid attention pattern: full-attention layers (num_attention_heads=64) interleave with sliding-attention layers (num_attention_heads=96). The HF config carries the per-layer attention type in layer_types and the sliding-layer shape overrides in attention_other_setting.

This provider exposes layer_types (the per-layer attention type) and attention_other_setting as dataclass fields. attention_other_setting serves as the enable flag for the sliding-attention path. The actual sliding-layer shape values are forwarded through the sliding_attention_setting field, which Step35Bridge.provider_bridge populates. The custom Step35DecoderLayer reads all three fields at construction time and decides, per layer, whether to build its sub-modules from the global config or from the sliding-attention overrides.

Module Contents

Classes

Step35DecoderLayer

Hybrid full/sliding attention decoder layer for Step-3.5-Flash.

Step35ModelProvider

Model provider for Step-3.5-Flash.

API

class bridge.models.stepfun.step35_provider.Step35DecoderLayer(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules,
layer_number: int = 1,
hidden_dropout: Optional[float] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
vp_stage: Optional[int] = None,
is_mtp_layer: bool = False,
add_layer_offset: bool = True,
pp_layer_offset: Optional[int] = None,
)

Bases: megatron.core.transformer.transformer_layer.TransformerLayer

Hybrid full/sliding attention decoder layer for Step-3.5-Flash.

On construction, the layer resolves a global, 0-indexed layer_idx:

  • For MTP layers, layer_idx is placed after the main decoder layers, so per-layer RoPE and attention-type lists can include MTP entries.

  • When add_layer_offset=False, layer_idx = layer_number - 1.

  • Otherwise, layer_idx = layer_number + get_transformer_layer_offset(config, vp_stage, pp_rank) - 1, so layers on later pipeline stages (PP > 1) still map to their global index.
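The three resolution rules can be sketched in plain Python. This is a minimal sketch, not the layer's actual code: resolve_layer_idx is a hypothetical helper, pp_offset stands in for whatever get_transformer_layer_offset returns, and the MTP branch assumes MTP layers are numbered consecutively right after num_main_layers main decoder layers, a detail the description does not spell out.

```python
from typing import Optional


def resolve_layer_idx(
    layer_number: int,
    add_layer_offset: bool,
    pp_offset: int,
    is_mtp_layer: bool = False,
    num_main_layers: Optional[int] = None,
) -> int:
    """Map a 1-indexed local layer_number to a global 0-indexed layer_idx.

    Hypothetical helper: pp_offset stands in for the value of
    get_transformer_layer_offset(config, vp_stage, pp_rank).
    """
    if is_mtp_layer:
        # Assumption: MTP layers are numbered right after the main decoder layers.
        if num_main_layers is None:
            raise ValueError("num_main_layers is required for MTP layers")
        return num_main_layers + layer_number - 1
    if not add_layer_offset:
        # Local numbering is already global.
        return layer_number - 1
    # Shift by the pipeline offset so every stage indexes the global lists.
    return layer_number + pp_offset - 1
```

For example, with a pipeline offset of 24, resolve_layer_idx(1, True, 24) returns 24, so the first layer on that stage reads layer_types[24] rather than layer_types[0].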

It then looks up config.layer_types[layer_idx]. If the entry is "sliding_attention" and config.attention_other_setting is set (it acts as the enable flag), the config is deep-copied and its shape-related fields are overridden from config.sliding_attention_setting before delegating to TransformerLayer.__init__. Every downstream sub-module reads this overridden config: self_attention, linear_qkv (with the per-head g_proj gate expanded into Megatron-Core’s gated-attention layout), and linear_proj. Each layer is therefore sized correctly without modifying Megatron-LM core.

Fields read from config.sliding_attention_setting (HF key on the left, TransformerConfig attribute on the right):

  • rotary_percent -> rotary_percent

  • num_attention_heads -> num_attention_heads

  • num_query_groups -> num_query_groups

  • head_dim -> kv_channels
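The key mapping and the deep-copy-then-override step can be illustrated with a toy config. This is a sketch under stated assumptions: ToyConfig and apply_sliding_overrides are hypothetical stand-ins for TransformerConfig and the layer's internal logic, keys missing from the setting are assumed to keep their global value, and only the 64/96 head counts come from the description; the other shape numbers are placeholders.

```python
import copy
from dataclasses import dataclass
from typing import Any

# HF key in sliding_attention_setting -> TransformerConfig attribute.
SLIDING_KEY_MAP = {
    "rotary_percent": "rotary_percent",
    "num_attention_heads": "num_attention_heads",
    "num_query_groups": "num_query_groups",
    "head_dim": "kv_channels",
}


@dataclass
class ToyConfig:
    """Hypothetical stand-in for TransformerConfig, reduced to the overridden fields."""

    num_attention_heads: int
    num_query_groups: int
    kv_channels: int
    rotary_percent: float = 1.0


def apply_sliding_overrides(config: ToyConfig, sliding_setting: dict[str, Any]) -> ToyConfig:
    """Return a deep copy of config with the sliding-layer shape fields overridden."""
    layer_config = copy.deepcopy(config)  # leave the shared global config untouched
    for hf_key, attr in SLIDING_KEY_MAP.items():
        # Assumption: keys absent from the setting keep the global value.
        if hf_key in sliding_setting:
            setattr(layer_config, attr, sliding_setting[hf_key])
    return layer_config


# Placeholder shapes; only the 64/96 head counts come from the text.
global_cfg = ToyConfig(num_attention_heads=64, num_query_groups=8, kv_channels=128)
sliding_cfg = apply_sliding_overrides(global_cfg, {"num_attention_heads": 96, "head_dim": 64})
```

Note that head_dim lands on kv_channels, while the unchanged global_cfg continues to serve full-attention layers.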

Implementation notes:

  • The spec-builder must keep layer_types indexed by the global 0-indexed layer id (same constraint as rotary_base_per_layer).

  • Layers whose resolved layer_idx falls outside layer_types fall through to the global config.
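Putting the lookup, the enable flag, and the out-of-range fall-through together gives a per-layer selection routine. The sketch is illustrative only: ToyConfig and config_for_layer are hypothetical names, and treating attention_other_setting as "set" when it is not None is an assumption about the enable-flag check.

```python
import copy
from dataclasses import dataclass
from typing import Any, Optional

# HF key in sliding_attention_setting -> config attribute, as listed in the docs.
SLIDING_KEY_MAP = {
    "rotary_percent": "rotary_percent",
    "num_attention_heads": "num_attention_heads",
    "num_query_groups": "num_query_groups",
    "head_dim": "kv_channels",
}


@dataclass
class ToyConfig:
    """Hypothetical stand-in for the config a Step35DecoderLayer receives."""

    num_attention_heads: int
    num_query_groups: int
    kv_channels: int
    rotary_percent: float = 1.0
    layer_types: Optional[list[str]] = None
    attention_other_setting: Optional[dict[str, Any]] = None
    sliding_attention_setting: Optional[dict[str, Any]] = None


def config_for_layer(config: ToyConfig, layer_idx: int) -> ToyConfig:
    """Pick the global config or a sliding-attention copy for one global layer index."""
    layer_types = config.layer_types or []
    is_sliding = (
        0 <= layer_idx < len(layer_types)  # out-of-range indices fall through
        and layer_types[layer_idx] == "sliding_attention"
        and config.attention_other_setting is not None  # assumed enable-flag check
    )
    if not is_sliding:
        return config  # the global config object, shared unchanged
    layer_config = copy.deepcopy(config)
    overrides = config.sliding_attention_setting or {}
    for hf_key, attr in SLIDING_KEY_MAP.items():
        if hf_key in overrides:
            setattr(layer_config, attr, overrides[hf_key])
    return layer_config
```

Returning the original object for full-attention and out-of-range layers means only sliding layers pay for a copy, and a missing enable flag turns every layer back into a full-attention layer.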


class bridge.models.stepfun.step35_provider.Step35ModelProvider

Bases: megatron.bridge.models.gpt_provider.GPTModelProvider

Model provider for Step-3.5-Flash.

Adds Step-3.5-specific fields on top of GPTModelProvider:

  • layer_types: 0-indexed list of attention types (e.g. "full_attention" / "sliding_attention"), one entry per main decoder layer. Read by Step35DecoderLayer to decide whether the current layer is a sliding-attention layer.

  • attention_other_setting: HF dict that enables and describes the sliding-attention override.

  • sliding_attention_setting: normalized Megatron-facing shape overrides derived from attention_other_setting.

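  • head_wise_attn_gate: whether to map HF’s per-head g_proj gate through Megatron-Core’s attention_output_gate path.

  • rotary_base_per_layer: per-layer RoPE base values, indexed by the same global 0-indexed layer id as layer_types.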

These fields are populated from the HF config inside Step35Bridge.provider_bridge.
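How these extra fields sit on top of the base provider can be mirrored with plain dataclasses. ToyGPTProvider and ToyStep35Provider are hypothetical stand-ins (the real classes carry many more fields), the defaults copy the field listing for Step35ModelProvider, and sliding_layer_indices is an illustrative helper, not part of the documented API.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyGPTProvider:
    """Hypothetical stand-in for GPTModelProvider; only num_layers is modeled."""

    num_layers: int = 0


@dataclass
class ToyStep35Provider(ToyGPTProvider):
    # Defaults mirror the documented Step35ModelProvider fields.
    layer_types: Optional[list[str]] = None
    attention_other_setting: Optional[dict[str, Any]] = None
    sliding_attention_setting: Optional[dict[str, Any]] = None
    rotary_base_per_layer: Optional[list[float]] = None
    head_wise_attn_gate: Optional[bool] = False

    def sliding_layer_indices(self) -> list[int]:
        """Hypothetical helper: 0-indexed layers that would take the sliding path."""
        if self.attention_other_setting is None or not self.layer_types:
            return []  # sliding path disabled or no per-layer types
        return [i for i, t in enumerate(self.layer_types) if t == "sliding_attention"]
```

With every field defaulting to None (or False), a provider built without the HF-derived values behaves like a plain GPT provider whose layers all use the global config.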

layer_types: list[str] | None = None

attention_other_setting: dict[str, Any] | None = None

sliding_attention_setting: dict[str, Any] | None = None

rotary_base_per_layer: list[float] | None = None

head_wise_attn_gate: Optional[bool] = False
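Because layers whose layer_idx falls outside layer_types silently fall back to the global config, a list that is too short would quietly turn intended sliding layers into full-attention ones. A pre-flight length check can catch that. check_per_layer_lists is a hypothetical helper, and the accepted lengths (main layers only, or main plus MTP layers) are an assumption based on the note that per-layer lists can include MTP entries.

```python
from typing import Any, Optional, Sequence


def check_per_layer_lists(
    num_layers: int,
    mtp_num_layers: int = 0,
    **per_layer_lists: Optional[Sequence[Any]],
) -> list[str]:
    """Report per-layer lists whose length does not match the global layer count."""
    # Assumption: a list covers either the main layers only, or main + MTP layers.
    allowed = sorted({num_layers, num_layers + mtp_num_layers})
    problems = []
    for name, values in per_layer_lists.items():
        if values is None:  # unset fields are skipped
            continue
        if len(values) not in allowed:
            problems.append(f"{name}: {len(values)} entries, expected one of {allowed}")
    return problems
```

For example, check_per_layer_lists(4, layer_types=["full_attention"] * 3, rotary_base_per_layer=[1e4] * 4) returns ["layer_types: 3 entries, expected one of [4]"]; the RoPE values here are placeholders.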