bridge.models.stepfun.step35_provider
Step-3.5-Flash Model Provider for Megatron-Core.
Step-3.5-Flash uses a hybrid attention pattern: full-attention layers
(num_attention_heads=64) interleave with sliding-attention layers
(num_attention_heads=96). The HF config carries the per-layer attention type
in layer_types and the sliding-layer shape overrides in
attention_other_setting.
This provider exposes layer_types (the per-layer attention type) as a
dataclass field and uses attention_other_setting as the enable flag for the
sliding-attention path. The actual sliding-layer shape values are forwarded
through the sliding_attention_setting field, which Step35Bridge.provider_bridge
populates. The custom Step35DecoderLayer reads all three fields at construction
time to decide, per layer, whether to build its sub-modules from the global
config or from the sliding-attention overrides.
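A minimal sketch of how the three provider fields could be derived from an HF config dict. The function name build_step35_fields is hypothetical and is not the real Step35Bridge.provider_bridge. The sketch assumes that attention_other_setting already uses the same key names as sliding_attention_setting, which may not match the real HF keys. The head_dim value is illustrative; only the 64/96 head counts come from the text above.

```python
from typing import Any, Optional

# Keys the decoder layer is documented to read from sliding_attention_setting.
SLIDING_SHAPE_KEYS = ("rotary_percent", "num_attention_heads", "num_query_groups", "head_dim")


def build_step35_fields(hf_config: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical stand-in for the part of provider_bridge that fills
    layer_types, attention_other_setting and sliding_attention_setting.

    Assumption: attention_other_setting already uses the same key names as
    sliding_attention_setting; the real HF keys may differ.
    """
    other: Optional[dict[str, Any]] = hf_config.get("attention_other_setting")
    sliding = None
    if other:  # the HF dict doubles as the enable flag for the sliding path
        # Keep only the shape-related keys the decoder layer overrides.
        sliding = {k: other[k] for k in SLIDING_SHAPE_KEYS if k in other}
    return {
        "layer_types": list(hf_config.get("layer_types") or []) or None,
        "attention_other_setting": other,
        "sliding_attention_setting": sliding,
    }


hf_config = {
    "num_attention_heads": 64,  # full-attention layers
    "layer_types": ["full_attention", "sliding_attention", "sliding_attention"],
    "attention_other_setting": {"num_attention_heads": 96, "head_dim": 128},  # head_dim illustrative
}
fields = build_step35_fields(hf_config)
print(fields["sliding_attention_setting"])  # {'num_attention_heads': 96, 'head_dim': 128}
```

Without an attention_other_setting, the sketch leaves sliding_attention_setting as None, so every layer uses the global config.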
Module Contents
Classes
| Class | Description |
|---|---|
| Step35DecoderLayer | Hybrid full/sliding attention decoder layer for Step-3.5-Flash. |
| Step35ModelProvider | Model provider for Step-3.5-Flash. |
API
class bridge.models.stepfun.step35_provider.Step35DecoderLayer(
    config: megatron.core.transformer.transformer_config.TransformerConfig,
    submodules: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules,
    layer_number: int = 1,
    hidden_dropout: Optional[float] = None,
    pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
    vp_stage: Optional[int] = None,
    is_mtp_layer: bool = False,
    add_layer_offset: bool = True,
    pp_layer_offset: Optional[int] = None,
)
Bases: megatron.core.transformer.transformer_layer.TransformerLayer

Hybrid full/sliding attention decoder layer for Step-3.5-Flash.
On construction the layer resolves a global 0-indexed layer_idx:
- For MTP layers, layer_idx is offset past the main decoder layers, so the per-layer RoPE and attention-type lists can include MTP entries.
- When add_layer_offset=False, layer_idx = layer_number - 1.
- Otherwise, layer_idx = layer_number + get_transformer_layer_offset(config, vp_stage, pp_rank) - 1, so the index stays global when pipeline parallelism (PP > 1) splits the layers across ranks.
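The index rules above can be sketched as follows. This is a simplified illustration, not the real code. stage_layer_offset is a hypothetical stand-in for get_transformer_layer_offset that assumes an even split of layers across pipeline ranks with no virtual pipeline stages. The MTP numbering (main layer count plus layer_number - 1) is also an assumption, and the pp_layer_offset parameter is not modeled.

```python
def stage_layer_offset(num_layers: int, pp_size: int, pp_rank: int) -> int:
    """Hypothetical stand-in for get_transformer_layer_offset.

    Assumes the main decoder layers are split evenly across pipeline ranks,
    with no virtual pipeline stages and no uneven first/last-stage layouts.
    """
    assert num_layers % pp_size == 0, "sketch assumes an even split"
    return pp_rank * (num_layers // pp_size)


def resolve_layer_idx(
    layer_number: int,  # 1-indexed, local to this pipeline stage
    *,
    num_layers: int,  # number of main decoder layers
    pp_size: int = 1,
    pp_rank: int = 0,
    add_layer_offset: bool = True,
    is_mtp_layer: bool = False,
) -> int:
    """Return the global 0-indexed layer id used to index layer_types."""
    if is_mtp_layer:
        # Assumption: MTP layers are numbered right after the main decoder layers.
        return num_layers + layer_number - 1
    if not add_layer_offset:
        return layer_number - 1
    return layer_number + stage_layer_offset(num_layers, pp_size, pp_rank) - 1


# 8 main layers on 2 pipeline ranks: local layer 1 on rank 1 is global layer 4.
print(resolve_layer_idx(1, num_layers=8, pp_size=2, pp_rank=1))  # 4
```

The point of the offset is that rank 1 never sees layer ids 0 to 3, so a purely local index would read the wrong layer_types entries.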
It then looks up config.layer_types[layer_idx]. If the entry is "sliding_attention" (and config.attention_other_setting is set as the enable flag), the config is deep-copied and its shape-related fields are overridden from config.sliding_attention_setting before delegating to TransformerLayer.__init__. Every downstream sub-module (self_attention, linear_qkv with the per-head g_proj gate expanded into Megatron-Core's gated-attention layout, and linear_proj) reads this overridden config, so each layer is sized correctly without changing Megatron-LM core.

Fields read from config.sliding_attention_setting (HF key on the left, TransformerConfig attribute on the right):
- rotary_percent -> rotary_percent
- num_attention_heads -> num_attention_heads
- num_query_groups -> num_query_groups
- head_dim -> kv_channels
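A minimal sketch of the per-layer override, assuming a small hypothetical MiniConfig dataclass in place of the real TransformerConfig. The query-group count, kv_channels, rotary_percent and head_dim values are illustrative. Only the 64/96 head counts and the key mapping come from the text above. The deep copy matters because the global config object is shared by every layer.

```python
import copy
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MiniConfig:
    """Hypothetical stand-in for TransformerConfig, reduced to the fields used here."""

    num_attention_heads: int = 64  # full-attention head count
    num_query_groups: int = 8  # illustrative value
    kv_channels: int = 128  # illustrative value
    rotary_percent: float = 1.0  # illustrative value
    layer_types: Optional[list[str]] = None
    attention_other_setting: Optional[dict[str, Any]] = None
    sliding_attention_setting: Optional[dict[str, Any]] = None


# HF key -> TransformerConfig attribute, per the mapping listed above.
SLIDING_FIELD_MAP = {
    "rotary_percent": "rotary_percent",
    "num_attention_heads": "num_attention_heads",
    "num_query_groups": "num_query_groups",
    "head_dim": "kv_channels",
}


def config_for_layer(config: MiniConfig, layer_idx: int) -> MiniConfig:
    """Pick the config a layer is built from (a sketch, not the real implementation)."""
    types = config.layer_types or []
    if not 0 <= layer_idx < len(types) or types[layer_idx] != "sliding_attention":
        return config  # full attention, or id outside layer_types: global config
    if not config.attention_other_setting:
        return config  # sliding-attention path not enabled
    layer_config = copy.deepcopy(config)  # leave the shared global config untouched
    setting = config.sliding_attention_setting or {}
    for hf_key, attr in SLIDING_FIELD_MAP.items():
        if hf_key in setting:
            setattr(layer_config, attr, setting[hf_key])
    return layer_config


cfg = MiniConfig(
    layer_types=["full_attention", "sliding_attention"],
    attention_other_setting={"num_attention_heads": 96},
    sliding_attention_setting={"num_attention_heads": 96, "head_dim": 64},
)
print(config_for_layer(cfg, 1).num_attention_heads)  # 96
print(config_for_layer(cfg, 0).num_attention_heads)  # 64
print(cfg.num_attention_heads)  # 64
```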
Implementation notes:
- The spec builder must keep layer_types indexed by the global 0-indexed layer id (the same constraint as rotary_base_per_layer).
- Layers whose resolved layer_idx falls outside layer_types fall through to the global config.
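Because out-of-range layers fall back to the global config silently, a spec builder might check coverage up front. The helper below is hypothetical. It assumes MTP layer ids directly follow the main decoder layers and simply lists the global ids that have no layer_types entry.

```python
from typing import Optional


def uncovered_layer_ids(
    layer_types: Optional[list[str]], num_layers: int, mtp_num_layers: int = 0
) -> list[int]:
    """Hypothetical spec-builder check.

    Returns the global 0-indexed layer ids (main layers first, then MTP layers,
    by assumption) that have no layer_types entry and would therefore fall
    through to the global config.
    """
    total = num_layers + mtp_num_layers
    covered = len(layer_types or [])
    return list(range(covered, total))


layer_types = ["full_attention", "sliding_attention", "sliding_attention", "sliding_attention"]
print(uncovered_layer_ids(layer_types, num_layers=4, mtp_num_layers=1))  # [4]
```

Whether falling back is acceptable for a given id (for example an MTP layer) is a model decision. The check only makes it visible.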
class bridge.models.stepfun.step35_provider.Step35ModelProvider
Bases: megatron.bridge.models.gpt_provider.GPTModelProvider

Model provider for Step-3.5-Flash.
Adds Step-3.5-specific fields on top of GPTModelProvider:
- layer_types: 0-indexed list of attention types (e.g. "full_attention" / "sliding_attention"), one entry per main decoder layer. Step35DecoderLayer reads it to decide whether the current layer is a sliding-attention layer.
- attention_other_setting: HF dict that enables and describes the sliding-attention override.
- sliding_attention_setting: normalized, Megatron-facing shape overrides derived from attention_other_setting.
- head_wise_attn_gate: whether to map HF's per-head g_proj gate through Megatron-Core's attention_output_gate path.
These fields are populated from the HF config inside Step35Bridge.provider_bridge.

- layer_types: list[str] | None = None
- attention_other_setting: dict[str, Any] | None = None
- sliding_attention_setting: dict[str, Any] | None = None
- rotary_base_per_layer: list[float] | None = None
- head_wise_attn_gate: Optional[bool] = False
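When head_wise_attn_gate is set, HF's per-head g_proj gate (one scalar per head) has to fit a gate path that, by assumption here, produces one gate value per output channel. One way to make the two equivalent is to repeat each head's gate row head_dim times. The sketch below checks that equivalence numerically in plain Python. It is an illustration only: expand_headwise_gate is a hypothetical helper, and the real Megatron-Core layout also interleaves the gate with the q/k/v projections per query group, which is not modeled.

```python
import math
import random


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    """Row-by-row dot products: w has shape [rows, hidden], x has shape [hidden]."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def expand_headwise_gate(w_gate: list[list[float]], head_dim: int) -> list[list[float]]:
    """Repeat each head's gate row head_dim times: [heads, hidden] -> [heads*head_dim, hidden]."""
    return [list(row) for row in w_gate for _ in range(head_dim)]


num_heads, head_dim, hidden = 3, 4, 5
rng = random.Random(0)
w_gate = [[rng.uniform(-1, 1) for _ in range(hidden)] for _ in range(num_heads)]
x = [rng.uniform(-1, 1) for _ in range(hidden)]
attn_out = [rng.uniform(-1, 1) for _ in range(num_heads * head_dim)]  # heads concatenated

# Head-wise gate: one sigmoid per head, broadcast over that head's channels.
head_gates = [sigmoid(v) for v in matvec(w_gate, x)]
headwise = [attn_out[i] * head_gates[i // head_dim] for i in range(len(attn_out))]

# Element-wise gate: one sigmoid per channel, computed from the expanded weight.
elem_gates = [sigmoid(v) for v in matvec(expand_headwise_gate(w_gate, head_dim), x)]
elementwise = [a * g for a, g in zip(attn_out, elem_gates)]

max_diff = max(abs(a - b) for a, b in zip(headwise, elementwise))
print(max_diff)  # 0.0
```

Because every repeated row holds identical weights, each channel in a head receives exactly that head's gate value, so the two outputs match bit for bit.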