bridge.models.mimo_v2_flash.modeling_mimo_v2_flash#
MiMo-V2-Flash modeling building blocks.
Houses the custom modules used by MiMoV2FlashModelProvider:
MiMoV2FlashRotaryEmbedding: dual-base RoPE (local for SWA, global for full).
MiMoV2FlashSelfAttention: per-layer KV head switching and asymmetric V head dim.
MiMoV2FlashTEDotProductAttention: per-layer SWA window and learnable softmax for SWA layers (vanilla for full).
MiMoV2FlashMTPSelfAttention / MiMoV2FlashMTPTEDotProductAttention: MTP variants (all MTP layers behave like SWA layers).
mimo_v2_flash_layer_spec: GPT layer spec builder that injects the custom modules.
Module Contents#
Classes#
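MiMoV2FlashRotaryEmbedding | Dual-base rotary embeddings for MiMo-V2-Flash. This is the same pattern as Gemma3RotaryEmbedding.
MiMoV2FlashSelfAttention | MiMo-V2-Flash self attention.
MiMoV2FlashTEDotProductAttention | MiMoV2Flash core attention.
MiMoV2FlashMTPSelfAttention | Overrides attention module for MTP.
MiMoV2FlashMTPTEDotProductAttention | Overrides core attention for MTP.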
Functions#
mimo_v2_flash_layer_spec | Layer spec for MiMo-V2-Flash with custom hybrid attention modules.
API#
- bridge.models.mimo_v2_flash.modeling_mimo_v2_flash._is_local_attn_layer(
- layer_number: int,
- hybrid_attention_pattern: List[int],
- class bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashRotaryEmbedding(
- rotary_base: int = 5000000,
- rotary_base_local: int = 10000,
- **kwargs,
Bases:
megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding
Dual-base rotary embeddings for MiMo-V2-Flash. This is the same pattern as Gemma3RotaryEmbedding.
Initialization
- forward(
- max_seq_len: int,
- offset: int = 0,
- packed_seq: bool = False,
- cp_group: torch.distributed.ProcessGroup | None = None,
Get both local and global rope embeddings stacked [local, global].
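The stacking described above can be illustrated with a minimal pure-Python sketch. It is not the module's implementation: the helper names `rope_angles` and `dual_rope`, the `dim` argument, and the plain nested-list layout are assumptions made for illustration, and the real class returns torch tensors. Only the two default bases and the [local, global] order come from the signature and docstring.

```python
def rope_angles(max_seq_len, dim, base, offset=0):
    """Rotation angles for one RoPE base, as a [max_seq_len][dim // 2] table."""
    # Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
    inv_freq = [1.0 / (base ** (2 * i / dim)) for i in range(dim // 2)]
    return [[(pos + offset) * f for f in inv_freq] for pos in range(max_seq_len)]


def dual_rope(max_seq_len, dim, rotary_base=5_000_000, rotary_base_local=10_000, offset=0):
    """Build both tables and stack them in [local, global] order."""
    local = rope_angles(max_seq_len, dim, rotary_base_local, offset)
    global_ = rope_angles(max_seq_len, dim, rotary_base, offset)
    return [local, global_]


# dim=8 is a toy value chosen for readability, not the model's head dimension.
stacked = dual_rope(max_seq_len=4, dim=8)
local_angles, global_angles = stacked
```

Because the local base is smaller, its higher-index frequencies rotate faster than the global ones, which is the usual motivation for pairing a small base with short-range (sliding window) layers.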
- _forward_cached(
- max_seq_len: int,
- offset: int = 0,
- packed_seq: bool = False,
Cached forward for hashable parameters.
- class bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashSelfAttention(
- config: megatron.core.transformer.TransformerConfig,
- submodules: megatron.core.transformer.attention.SelfAttentionSubmodules,
- layer_number: int,
- *args,
- **kwargs,
Bases:
megatron.core.transformer.attention.SelfAttention
MiMo-V2-Flash self attention.
Customizations over standard SelfAttention:
Per-layer KV head count: SWA layers use swa_num_query_groups, full layers use full_attn_num_query_groups
Asymmetric V head dim: Q/K use qk_channels=192, V uses v_head_dim=128
Dual RoPE: local rope for SWA layers, global rope for full layers
Initialization
- get_query_key_value_tensors(
- hidden_states,
- key_value_states=None,
- **kwargs,
Split fused QKV with asymmetric V head dim.
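As a rough illustration of the asymmetric split, the sketch below computes per-group split sizes and slices a flat buffer. It assumes the fused projection follows the usual Megatron per-query-group layout (all query heads of the group, then one key head, then one value head) with only the value slice narrowed; that layout, the helper names, and the head counts are assumptions, while `qk_channels=192` and `v_head_dim=128` come from the class description.

```python
def qkv_split_sizes(num_heads, num_query_groups, qk_head_dim=192, v_head_dim=128):
    """Sizes of the Q, K and V slices inside one query group's fused output."""
    heads_per_group = num_heads // num_query_groups
    return [heads_per_group * qk_head_dim, qk_head_dim, v_head_dim]


def split_group(fused, sizes):
    """Slice one group's fused vector into (q, k, v) using the given sizes."""
    assert len(fused) == sum(sizes), "fused width must match the split sizes"
    q_end = sizes[0]
    k_end = q_end + sizes[1]
    return fused[:q_end], fused[q_end:k_end], fused[k_end:]


# Hypothetical head counts, chosen only to make the arithmetic easy to follow.
sizes = qkv_split_sizes(num_heads=8, num_query_groups=2)
fused = list(range(sum(sizes)))  # stand-in for one group's projection output
q, k, v = split_group(fused, sizes)
```

With symmetric head dims the three sizes would be `[heads_per_group * d, d, d]`; the only change here is that the last entry uses the smaller value dimension.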
- forward(
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- rotary_pos_emb: Optional[torch.Tensor] = None,
- rotary_pos_cos: Optional[torch.Tensor] = None,
- rotary_pos_sin: Optional[torch.Tensor] = None,
- rotary_pos_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- attention_bias: Optional[torch.Tensor] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- sequence_len_offset: Optional[int] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
Selects the local (SWA layers) or global (full-attention layers) RoPE embedding for this layer before running the forward pass.
- class bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashTEDotProductAttention(
- config: megatron.core.transformer.TransformerConfig,
- layer_number: int,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
- attention_type: str,
- attention_dropout: Optional[float] = None,
- **kwargs,
Bases:
bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.TEDotProductAttention
MiMoV2Flash core attention.
Switches between global and local sliding window attention based on the layer_number and pre-defined layer pattern. SWA layers use a learnable softmax (attention-sink bias); full-attention layers use vanilla softmax.
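The per-layer switch can be sketched in plain Python as follows. This is a toy model of the behavior described above, not the Transformer Engine code path. It assumes that `layer_number` is 1-based, that a pattern entry of 1 marks an SWA layer, and that the learnable softmax adds a single learned sink logit to the denominator; the function names are hypothetical.

```python
import math


def is_local_attn_layer_sketch(layer_number, hybrid_attention_pattern):
    # Assumption: 1-based layer numbers, and 1 in the pattern means SWA.
    return hybrid_attention_pattern[layer_number - 1] == 1


def softmax_with_optional_sink(scores, sink_logit=None):
    """Vanilla softmax, or softmax with one extra sink term in the denominator."""
    logits = list(scores) if sink_logit is None else list(scores) + [sink_logit]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    # The sink absorbs probability mass but produces no output weight.
    return [e / total for e in exps[: len(scores)]]


def attention_row(scores, query_pos, layer_number, pattern, window, sink_logit):
    """Attention weights for one causal query row."""
    if is_local_attn_layer_sketch(layer_number, pattern):
        start = max(0, query_pos - window + 1)  # keep only the last `window` keys
        return softmax_with_optional_sink(scores[start : query_pos + 1], sink_logit)
    return softmax_with_optional_sink(scores[: query_pos + 1])


pattern = [0, 1, 1]  # hypothetical: layer 1 full attention, layers 2 and 3 SWA
scores = [0.0, 0.0, 0.0, 0.0]
full_row = attention_row(scores, 3, 1, pattern, window=2, sink_logit=0.0)
swa_row = attention_row(scores, 3, 2, pattern, window=2, sink_logit=0.0)
```

In this toy case the full-attention row sums to 1, while the SWA row sums to less than 1 because the sink term takes a share of the probability mass.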
Initialization
- forward(query, key, value, attention_mask, attn_mask_type, **kwargs)
- class bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashMTPSelfAttention(
- config,
- submodules,
- layer_number,
- *args,
- **kwargs,
Bases:
bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashSelfAttention
Overrides the attention module for MTP layers, which all behave like SWA layers.
Initialization
- class bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashMTPTEDotProductAttention(
- config,
- layer_number,
- *args,
- **kwargs,
Bases:
bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashTEDotProductAttention
Overrides core attention for MTP layers, which all behave like SWA layers.
Initialization
- bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.mimo_v2_flash_layer_spec(
- config,
Layer spec for MiMo-V2-Flash with custom hybrid attention modules.
Builds the block spec (handles MoE/dense split) then injects custom self-attention and core-attention modules into every layer spec.
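The injection step can be pictured with the following self-contained sketch. It uses small stand-in dataclasses rather than Megatron's real spec types, and stores module names as strings instead of classes; `LayerSpecSketch`, `SelfAttnSpecSketch`, `AttnSubmodulesSketch`, and `inject_custom_attention` are hypothetical names introduced only for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class AttnSubmodulesSketch:
    core_attention: str = "DefaultCoreAttention"


@dataclass
class SelfAttnSpecSketch:
    module: str = "DefaultSelfAttention"
    submodules: AttnSubmodulesSketch = field(default_factory=AttnSubmodulesSketch)


@dataclass
class LayerSpecSketch:
    self_attention: SelfAttnSpecSketch = field(default_factory=SelfAttnSpecSketch)
    mlp: str = "dense"  # "dense" or "moe"; left untouched by the injection


def inject_custom_attention(layer_specs):
    """Swap the attention modules in every layer spec, whatever its MLP type."""
    for spec in layer_specs:
        spec.self_attention.module = "MiMoV2FlashSelfAttention"
        spec.self_attention.submodules.core_attention = "MiMoV2FlashTEDotProductAttention"
    return layer_specs


# Stand-in for a block spec that already encodes the MoE/dense split.
block = inject_custom_attention([LayerSpecSketch(mlp="dense"), LayerSpecSketch(mlp="moe")])
```

The point of the pattern is that the MoE/dense decision is made once when the block spec is built, and the attention override is applied uniformly afterwards.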