bridge.models.stepfun.step35_provider#

Step-3.5-Flash Model Provider for Megatron-Core.

Step-3.5-Flash uses a hybrid attention pattern: full-attention layers (num_attention_heads=64) interleave with sliding-attention layers (num_attention_heads=96). The HF config carries the per-layer attention type in layer_types and the sliding-layer shape overrides in attention_other_setting.

This provider exposes layer_types (the per-layer attention type) as a dataclass field and treats attention_other_setting as the enable flag for the sliding-attention path. The sliding-layer shape values themselves are forwarded through the sliding_attention_setting field, which Step35Bridge.provider_bridge populates. The custom Step35DecoderLayer reads all three fields at construction time to decide, layer by layer, whether to build its sub-modules from the global config or from the sliding-attention overrides.
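The per-layer decision described above can be sketched roughly as follows. This is a minimal illustration, not the provider's actual code: ToyConfig and config_for_layer are hypothetical stand-ins, and every numeric value except the 64 / 96 head counts is a placeholder chosen for the example.

```python
import copy
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyConfig:
    """Hypothetical stand-in for the few config fields involved."""

    num_attention_heads: int = 64  # full-attention head count from the text above
    num_query_groups: int = 8      # placeholder value
    kv_channels: int = 128         # placeholder value
    window_size: Optional[tuple] = None
    layer_types: Optional[list] = None
    attention_other_setting: Optional[dict] = None
    sliding_attention_setting: Optional[dict] = None


# Fields that a sliding-attention layer may override.
OVERRIDE_KEYS = ("window_size", "num_attention_heads", "num_query_groups", "kv_channels")


def config_for_layer(config: ToyConfig, layer_idx: int) -> ToyConfig:
    """Return the global config, or a deep copy with sliding overrides applied."""
    layer_types = config.layer_types or []
    in_range = 0 <= layer_idx < len(layer_types)
    is_sliding = in_range and layer_types[layer_idx] == "sliding_attention"
    if not (is_sliding and config.attention_other_setting):
        return config  # full-attention or unconfigured layer: use the global config

    layer_config = copy.deepcopy(config)  # never mutate the shared global config
    overrides: dict[str, Any] = config.sliding_attention_setting or {}
    for key in OVERRIDE_KEYS:
        if key in overrides:
            setattr(layer_config, key, overrides[key])
    return layer_config


cfg = ToyConfig(
    layer_types=["full_attention", "sliding_attention"],
    attention_other_setting={"enabled": True},  # placeholder; only truthiness matters here
    sliding_attention_setting={"num_attention_heads": 96, "window_size": (512, 0)},
)
full_cfg = config_for_layer(cfg, 0)
sliding_cfg = config_for_layer(cfg, 1)
```

In this sketch the full-attention layer receives the global config object unchanged, while the sliding layer receives an independent copy, so the override cannot leak into other layers.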

Module Contents#

Classes#

Step35DecoderLayer

Hybrid full/sliding attention decoder layer for Step-3.5-Flash.

Step35SharedExpertMLP

Shared-expert MLP for Step-3.5 honoring a per-shared-expert SwiGLU clamp.

Step35ModelProvider

Model provider for Step-3.5-Flash.

API#

class bridge.models.stepfun.step35_provider.Step35DecoderLayer(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules,
layer_number: int = 1,
hidden_dropout: Optional[float] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
vp_stage: Optional[int] = None,
is_mtp_layer: bool = False,
add_layer_offset: bool = True,
pp_layer_offset: Optional[int] = None,
name: str | None = None,
)#

Bases: megatron.core.transformer.transformer_layer.TransformerLayer

Hybrid full/sliding attention decoder layer for Step-3.5-Flash.

Resolves a global 0-indexed layer_idx and uses it for three per-layer config lookups before delegating to TransformerLayer.__init__. MTP layers are offset past the main decoder; for all other layers, layer_idx is layer_number + pp_offset - 1, or layer_number - 1 when add_layer_offset=False. The three lookups are:

  1. RoPE percentage — rotary_percents[layer_idx] overrides config.rotary_percent (Step-3.5 alternates 0.5 / 1.0). Out of range → reset to 1.0 (the sliding-layer default), so MTP / unconfigured layers don’t inherit the previous layer’s value.

  2. Attention type — when layer_types[layer_idx] == "sliding_attention" and attention_other_setting is truthy, the config is deep-copied and window_size / num_attention_heads / num_query_groups / kv_channels are overridden from sliding_attention_setting (already in Megatron-facing names; the HF→mcore renaming happens in Step35Bridge.provider_bridge). rotary_percent is not touched here.

  3. SwiGLU clamp — swiglu_limits[layer_idx] / swiglu_limits_shared[layer_idx] overwrite activation_func_clamp_value / activation_func_clamp_value_shared. Out of range → skipped, keeping the global value.

All three lookups are bounds-checked, so an out-of-range layer_idx takes the fallback described for each lookup instead of raising an IndexError. The spec-builder must keep these lists (and rotary_base_per_layer) indexed by the global 0-indexed layer id.
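The index resolution and the bounds-checked lookups can be sketched as below. The helper names (resolve_layer_idx, lookup, rotary_percent_for, clamp_for) are hypothetical, and the MTP formula assumes that MTP layers are numbered from 1 and placed directly after config.num_layers; the clamp values are placeholders.

```python
from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


def resolve_layer_idx(
    layer_number: int,
    num_layers: int,
    pp_offset: int = 0,
    is_mtp_layer: bool = False,
    add_layer_offset: bool = True,
) -> int:
    """Map a 1-indexed layer_number to a global 0-indexed layer id."""
    if is_mtp_layer:
        # Assumption: MTP layers sit right after the main decoder layers.
        return num_layers + layer_number - 1
    if add_layer_offset:
        return layer_number + pp_offset - 1
    return layer_number - 1


def lookup(values: Optional[Sequence[T]], idx: int) -> Optional[T]:
    """Return values[idx], or None when the list is missing or idx is out of range."""
    if values is not None and 0 <= idx < len(values):
        return values[idx]
    return None


def rotary_percent_for(rotary_percents: Optional[Sequence[float]], idx: int) -> float:
    """Out-of-range layers are reset to 1.0 rather than inheriting a stale value."""
    value = lookup(rotary_percents, idx)
    return 1.0 if value is None else value


def clamp_for(
    swiglu_limits: Optional[Sequence[float]], idx: int, global_value: Optional[float]
) -> Optional[float]:
    """Out-of-range layers keep the global clamp value."""
    value = lookup(swiglu_limits, idx)
    return global_value if value is None else value


rotary_percents = [0.5, 1.0, 0.5, 1.0]
main_idx = resolve_layer_idx(layer_number=1, num_layers=4, pp_offset=2)
mtp_idx = resolve_layer_idx(layer_number=1, num_layers=4, is_mtp_layer=True)
main_rope = rotary_percent_for(rotary_percents, main_idx)
mtp_rope = rotary_percent_for(rotary_percents, mtp_idx)
mtp_clamp = clamp_for([7.0, 7.0, 7.0, 7.0], mtp_idx, global_value=None)
```

Here the first layer of a pipeline stage with offset 2 resolves to global index 2, while the MTP layer resolves to index 4, which is past the end of the four-entry lists and therefore takes the fallbacks.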


class bridge.models.stepfun.step35_provider.Step35SharedExpertMLP#

Bases: megatron.core.transformer.moe.shared_experts.SharedExpertMLP

Shared-expert MLP for Step-3.5 honoring a per-shared-expert SwiGLU clamp.

SharedExpertMLP.__init__ makes a private deep copy of its config so that the shared expert can mutate ffn_hidden_size without affecting the routed experts. Step35DecoderLayer.__init__ sets a separate per-layer activation_func_clamp_value_shared_expert field on the config, which falls back to activation_func_clamp_value when it is None. MLP.forward reads only self.config.activation_func_clamp_value for SwiGLU clamping, so this subclass makes the shared-expert value visible by swapping it onto the private config for the duration of the forward pass.

forward(hidden_states: torch.Tensor) → torch.Tensor#

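Runs the shared-expert forward pass with the shared-expert SwiGLU clamp value swapped onto the private config, then restores the original value.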
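The swap-and-restore pattern used by this class can be illustrated with a small self-contained sketch. ToyConfig, ToyMLP and ToySharedExpertMLP are hypothetical stand-ins rather than the Megatron-Core classes, and the scalar clamp only mimics the fact that the parent forward reads a single config field.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyConfig:
    """Hypothetical stand-in for the private config copy."""

    activation_func_clamp_value: Optional[float] = None
    activation_func_clamp_value_shared_expert: Optional[float] = None


class ToyMLP:
    """Stand-in parent whose forward only reads activation_func_clamp_value."""

    def __init__(self, config: ToyConfig) -> None:
        self.config = config

    def forward(self, x: float) -> float:
        limit = self.config.activation_func_clamp_value
        if limit is None:
            return x
        return max(-limit, min(limit, x))  # toy scalar clamp, not real SwiGLU


class ToySharedExpertMLP(ToyMLP):
    """Temporarily exposes the shared-expert clamp to the parent forward."""

    def forward(self, x: float) -> float:
        shared = self.config.activation_func_clamp_value_shared_expert
        if shared is None:
            # Fallback: keep using activation_func_clamp_value as-is.
            return super().forward(x)
        saved = self.config.activation_func_clamp_value
        self.config.activation_func_clamp_value = shared
        try:
            return super().forward(x)
        finally:
            # Restore even if the parent forward raises.
            self.config.activation_func_clamp_value = saved


toy_cfg = ToyConfig(
    activation_func_clamp_value=10.0,
    activation_func_clamp_value_shared_expert=2.0,
)
shared_out = ToySharedExpertMLP(toy_cfg).forward(5.0)
routed_out = ToyMLP(toy_cfg).forward(5.0)
```

The try/finally in the sketch is a design choice for the illustration: it guarantees that the config is left unchanged after the call, which matters whenever anything else reads the same config object.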

class bridge.models.stepfun.step35_provider.Step35ModelProvider#

Bases: megatron.bridge.models.gpt_provider.GPTModelProvider

Model provider for Step-3.5-Flash.

Adds Step-3.5-specific fields on top of GPTModelProvider:

  • layer_types: 0-indexed list of attention types (e.g. "full_attention" / "sliding_attention"). The list may hold entries for the main decoder layers followed by entries for MTP layers, because Step35DecoderLayer indexes MTP layers after config.num_layers.

  • attention_other_setting: HF dict that enables and describes the sliding-attention override.

  • sliding_attention_setting: normalized Megatron-facing shape overrides derived from attention_other_setting.

  • head_wise_attn_gate: whether to map HF’s per-head g_proj gate through Megatron-Core’s attention_output_gate path.

These fields are populated from the HF config inside Step35Bridge.provider_bridge.

layer_types: list[str] | None#

None

attention_other_setting: dict[str, Any] | None#

None

sliding_attention_setting: dict[str, Any] | None#

None

rotary_base_per_layer: list[float] | None#

None

head_wise_attn_gate: Optional[bool]#

False