bridge.models.mimo_v2_flash.modeling_mimo_v2_flash

MiMo-V2-Flash modeling building blocks.

Houses the custom modules used by MiMoV2FlashModelProvider:

  • MiMoV2FlashRotaryEmbedding: dual-base RoPE (local base for SWA layers, global base for full-attention layers).

  • MiMoV2FlashSelfAttention: per-layer KV head switching and asymmetric V head dim.

  • MiMoV2FlashTEDotProductAttention: per-layer SWA window, with a learnable softmax on SWA layers and a vanilla softmax on full-attention layers.

  • MiMoV2FlashMTPSelfAttention / MiMoV2FlashMTPTEDotProductAttention: MTP variants (all MTP layers behave like SWA layers).

  • mimo_v2_flash_layer_spec: GPT layer spec builder that injects the custom modules.

Module Contents

Classes

MiMoV2FlashRotaryEmbedding

Dual-base rotary embeddings for MiMo-V2-Flash. This is the same pattern as Gemma3RotaryEmbedding.

MiMoV2FlashSelfAttention

MiMo-V2-Flash self-attention.

MiMoV2FlashTEDotProductAttention

MiMo-V2-Flash core attention.

MiMoV2FlashMTPSelfAttention

Self-attention override for MTP layers.

MiMoV2FlashMTPTEDotProductAttention

Core-attention override for MTP layers.

Functions

_is_local_attn_layer

mimo_v2_flash_layer_spec

Layer spec for MiMo-V2-Flash with custom hybrid attention modules.

API

bridge.models.mimo_v2_flash.modeling_mimo_v2_flash._is_local_attn_layer(
layer_number: int,
hybrid_attention_pattern: List[int],
) → bool
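_is_local_attn_layer carries no docstring, so its exact semantics are not stated on this page. The following is a minimal, hypothetical sketch of the lookup it likely performs, assuming that hybrid_attention_pattern holds one entry per decoder layer (1 marks a sliding-window layer, 0 a full-attention layer) and that layer_number is 1-based, as in Megatron-Core. The example pattern is made up for illustration.

```python
from typing import List


def is_local_attn_layer(layer_number: int, hybrid_attention_pattern: List[int]) -> bool:
    """Hypothetical re-implementation: True if the 1-based layer uses SWA.

    Assumes pattern[i] == 1 marks a sliding-window (local) layer and 0 a
    full-attention layer; the list is 0-indexed while layer_number starts at 1.
    """
    if not 1 <= layer_number <= len(hybrid_attention_pattern):
        raise IndexError(
            f"layer_number {layer_number} outside pattern of length "
            f"{len(hybrid_attention_pattern)}"
        )
    return hybrid_attention_pattern[layer_number - 1] == 1


# Made-up pattern: full attention on layers 1 and 6, SWA elsewhere.
pattern = [0, 1, 1, 1, 1, 0, 1, 1]
local_layers = [n for n in range(1, len(pattern) + 1) if is_local_attn_layer(n, pattern)]
print(local_layers)  # [2, 3, 4, 5, 7, 8]
```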
class bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashRotaryEmbedding(
rotary_base: int = 5000000,
rotary_base_local: int = 10000,
**kwargs,
)

Bases: megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding

Dual-base rotary embeddings for MiMo-V2-Flash. This is the same pattern as Gemma3RotaryEmbedding.

Initialization

forward(
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
cp_group: torch.distributed.ProcessGroup | None = None,
) → torch.Tensor

Return the local and global RoPE embeddings stacked as [local, global].
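To illustrate what "dual-base" means, here is a minimal pure-Python sketch of standard RoPE angles computed once per base and returned in the documented [local, global] order. The defaults reuse rotary_base=5000000 and rotary_base_local=10000 from the signature above; the helper names are hypothetical, and the real Megatron-Core RotaryEmbedding returns a torch.Tensor with extra broadcast dimensions rather than nested lists.

```python
import math
from typing import List


def rope_angles(seq_len: int, dim: int, base: float) -> List[List[float]]:
    """Standard RoPE angles: angle[p][i] = p * base ** (-2i / dim), i < dim // 2."""
    inv_freq = [base ** (-2.0 * i / dim) for i in range(dim // 2)]
    return [[p * f for f in inv_freq] for p in range(seq_len)]


def dual_base_rope(
    seq_len: int,
    dim: int,
    rotary_base: float = 5_000_000,
    rotary_base_local: float = 10_000,
) -> List[List[List[float]]]:
    """Hypothetical helper: angle tables for both bases, stacked [local, global]."""
    return [
        rope_angles(seq_len, dim, rotary_base_local),  # index 0: SWA layers
        rope_angles(seq_len, dim, rotary_base),        # index 1: full-attention layers
    ]


local, global_ = dual_base_rope(seq_len=4, dim=8)
# The first frequency is 1.0 for any base; higher frequencies rotate faster
# with the smaller local base than with the large global base.
print(local[3][0], global_[3][0])  # 3.0 3.0
print(local[1][1] > global_[1][1])  # True
```

The small local base gives SWA layers fast-rotating, short-range positional signal, while the large global base keeps angles distinguishable over long contexts for full-attention layers.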

_forward_cached(
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
) → torch.Tensor

Cached variant of forward, keyed on its hashable arguments.

class bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashSelfAttention(
config: megatron.core.transformer.TransformerConfig,
submodules: megatron.core.transformer.attention.SelfAttentionSubmodules,
layer_number: int,
*args,
**kwargs,
)

Bases: megatron.core.transformer.attention.SelfAttention

MiMo-V2-Flash self-attention.

Customizations over standard SelfAttention:

  • Per-layer KV head count: SWA layers use swa_num_query_groups; full-attention layers use full_attn_num_query_groups.

  • Asymmetric V head dim: Q/K use qk_channels=192; V uses v_head_dim=128.

  • Dual RoPE: local RoPE for SWA layers, global RoPE for full-attention layers.
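The first two customizations together determine how wide the fused QKV projection is on each layer. The sketch below works through that arithmetic under stated assumptions: ToyAttnConfig is a hypothetical stand-in for the relevant TransformerConfig fields, the head and group counts are made up, and only qk_channels=192 and v_head_dim=128 come from the description above.

```python
from dataclasses import dataclass


@dataclass
class ToyAttnConfig:
    # Hypothetical stand-in for the TransformerConfig fields named above.
    num_attention_heads: int = 8
    swa_num_query_groups: int = 8        # made-up value
    full_attn_num_query_groups: int = 2  # made-up value
    qk_channels: int = 192               # Q/K head dim (from the docs)
    v_head_dim: int = 128                # V head dim (from the docs)


def num_query_groups(cfg: ToyAttnConfig, is_local: bool) -> int:
    """Per-layer KV head count: SWA layers and full layers use different values."""
    return cfg.swa_num_query_groups if is_local else cfg.full_attn_num_query_groups


def fused_qkv_width(cfg: ToyAttnConfig, is_local: bool) -> int:
    """Output width of the fused QKV projection: all Q heads + K and V per group."""
    g = num_query_groups(cfg, is_local)
    return (
        cfg.num_attention_heads * cfg.qk_channels  # queries
        + g * cfg.qk_channels                      # keys share the Q/K head dim
        + g * cfg.v_head_dim                       # values use the smaller V dim
    )


def attn_output_width(cfg: ToyAttnConfig) -> int:
    """Each head outputs v_head_dim channels, so the output projection sees heads * V dim."""
    return cfg.num_attention_heads * cfg.v_head_dim


cfg = ToyAttnConfig()
print(fused_qkv_width(cfg, is_local=True))   # 4096
print(fused_qkv_width(cfg, is_local=False))  # 2176
print(attn_output_width(cfg))                # 1024
```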

Initialization

get_query_key_value_tensors(
hidden_states,
key_value_states=None,
**kwargs,
)

Split the fused QKV projection into query, key, and value tensors, using the asymmetric V head dim for the value split.
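A minimal sketch of such a split for a single token, in pure Python. It assumes the group-interleaved layout Megatron-Core uses for fused QKV (per query group: that group's query heads, then one key head, then one value head); with an asymmetric V head dim the per-group split sizes become [heads_per_group * qk_dim, qk_dim, v_dim]. The function name and toy sizes are hypothetical; the real method operates on torch tensors.

```python
from typing import List, Sequence, Tuple


def split_fused_qkv(
    row: Sequence[float], num_heads: int, num_groups: int, qk_dim: int, v_dim: int
) -> Tuple[List[Sequence[float]], List[Sequence[float]], List[Sequence[float]]]:
    """Split one token's fused QKV vector into per-head Q and per-group K/V.

    Assumed layout per query group: [q heads of the group | k | v].
    """
    heads_per_group = num_heads // num_groups
    q_width = heads_per_group * qk_dim
    group_width = q_width + qk_dim + v_dim
    if len(row) != num_groups * group_width:
        raise ValueError(f"expected width {num_groups * group_width}, got {len(row)}")
    q, k, v = [], [], []
    for g in range(num_groups):
        chunk = row[g * group_width:(g + 1) * group_width]
        # Queries: heads_per_group slices of qk_dim each.
        q.extend(chunk[h * qk_dim:(h + 1) * qk_dim] for h in range(heads_per_group))
        # One key head (qk_dim) and one value head (v_dim) per group.
        k.append(chunk[q_width:q_width + qk_dim])
        v.append(chunk[q_width + qk_dim:])
    return q, k, v


# Toy sizes: 4 query heads in 2 groups, Q/K dim 3, V dim 2 -> width 2 * 11 = 22.
q, k, v = split_fused_qkv(list(range(22)), num_heads=4, num_groups=2, qk_dim=3, v_dim=2)
print(q[2], k[1], v[1])  # [11, 12, 13] [17, 18, 19] [20, 21]
```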

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
rotary_pos_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_bias: Optional[torch.Tensor] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) → Tuple[torch.Tensor, torch.Tensor]

Select the local or global RoPE embedding for this layer, then run the standard attention forward.
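The selection step can be sketched as indexing into the [local, global] stack returned by MiMoV2FlashRotaryEmbedding.forward. This is a hypothetical illustration: strings stand in for the two embedding tensors, and the 0/1 pattern convention (1 = SWA) is an assumption.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def select_rope(stacked: Sequence[T], is_local: bool) -> T:
    """Pick this layer's RoPE from a [local, global] stack (order per the docs)."""
    if len(stacked) != 2:
        raise ValueError("expected exactly [local, global]")
    local, global_ = stacked
    return local if is_local else global_


stacked = ["local-emb", "global-emb"]  # placeholders for the two embeddings
pattern = [0, 1, 1]                    # assumed: 1 = SWA layer, 0 = full layer
choices = [select_rope(stacked, pattern[n - 1] == 1) for n in (1, 2, 3)]
print(choices)  # ['global-emb', 'local-emb', 'local-emb']
```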

class bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashTEDotProductAttention(
config: megatron.core.transformer.TransformerConfig,
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
attention_type: str,
attention_dropout: Optional[float] = None,
**kwargs,
)

Bases: bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.TEDotProductAttention

MiMo-V2-Flash core attention.

Switches between global attention and local sliding-window attention (SWA) based on layer_number and the predefined hybrid layer pattern. SWA layers use a learnable softmax (attention-sink bias); full-attention layers use a vanilla softmax.
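The two softmax variants can be contrasted in a small pure-Python sketch. It assumes that the "learnable softmax" means an extra learned sink logit that joins the softmax denominator without a matching value, so attention weights over real keys can sum to less than 1, and it assumes a causal window that counts the current token. Both interpretations, and all names below, are illustrative rather than taken from the implementation; in a real model the sink logit would be a trained per-head parameter.

```python
import math
from typing import List, Optional


def sliding_window_causal_mask(seq_len: int, window: int) -> List[List[bool]]:
    """mask[i][j] is True when query i may attend to key j (causal, last `window` keys)."""
    return [[(j <= i) and (i - j < window) for j in range(seq_len)] for i in range(seq_len)]


def softmax_with_sink(
    scores: List[float], allowed: List[bool], sink_logit: Optional[float] = None
) -> List[float]:
    """Masked softmax; a sink logit, if given, is added to the denominator only."""
    kept = [s for s, ok in zip(scores, allowed) if ok]
    m = max(kept + ([sink_logit] if sink_logit is not None else []))  # for stability
    exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(scores, allowed)]
    denom = sum(exps)
    if sink_logit is not None:
        denom += math.exp(sink_logit - m)
    return [e / denom for e in exps]


mask = sliding_window_causal_mask(seq_len=4, window=2)
row = mask[3]  # query 3 sees keys 2 and 3 only
scores = [0.0, 0.0, 1.0, 1.0]
vanilla = softmax_with_sink(scores, row)                  # full-attention style
with_sink = softmax_with_sink(scores, row, sink_logit=1.0)  # SWA style
print(vanilla)  # [0.0, 0.0, 0.5, 0.5]
print(sum(with_sink) < 1.0)  # True: the sink absorbs part of the probability mass
```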

Initialization

forward(query, key, value, attention_mask, attn_mask_type, **kwargs)
class bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashMTPSelfAttention(
config,
submodules,
layer_number,
*args,
**kwargs,
)

Bases: bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashSelfAttention

Self-attention override for MTP layers; all MTP layers are treated as SWA layers.

Initialization

class bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashMTPTEDotProductAttention(
config,
layer_number,
*args,
**kwargs,
)

Bases: bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.MiMoV2FlashTEDotProductAttention

Core-attention override for MTP layers; all MTP layers are treated as SWA layers.

Initialization
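Both MTP classes specialize a parent that decides per layer between SWA and full attention. One way to express "all MTP layers behave like SWA layers" is to override only that decision hook. The classes below are hypothetical toys that show this subclassing pattern, not the actual implementation.

```python
from typing import List


class ToyHybridAttention:
    """Hypothetical stand-in: picks SWA vs full attention from a per-layer pattern."""

    def __init__(self, layer_number: int, pattern: List[int]):
        self.layer_number = layer_number
        self.pattern = pattern
        self.is_local = self._decide_is_local()

    def _decide_is_local(self) -> bool:
        # Assumed convention: 1 = sliding-window layer, layer_number is 1-based.
        return self.pattern[self.layer_number - 1] == 1


class ToyMTPAttention(ToyHybridAttention):
    """MTP variant: every MTP layer is treated as an SWA layer."""

    def _decide_is_local(self) -> bool:
        return True


pattern = [0, 1]
print(ToyHybridAttention(1, pattern).is_local)  # False
print(ToyMTPAttention(1, pattern).is_local)     # True
```

In this toy, overriding the hook also means the MTP subclass never indexes the decoder pattern, so a layer number outside that pattern does not matter.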

bridge.models.mimo_v2_flash.modeling_mimo_v2_flash.mimo_v2_flash_layer_spec(
config,
) → megatron.core.transformer.ModuleSpec

Layer spec for MiMo-V2-Flash with custom hybrid attention modules.

Builds the block spec (handling the MoE/dense layer split), then injects the custom self-attention and core-attention modules into every layer spec.
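The injection step can be sketched with small dataclasses standing in for Megatron-Core's ModuleSpec and submodule containers. ToySpec, ToyAttnSubmodules, ToyLayerSubmodules, and inject_attention are hypothetical; strings stand in for module classes. The sketch swaps the attention entries on every layer while leaving the MLP (MoE or dense) untouched.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict, List, Optional


@dataclass
class ToySpec:
    # Hypothetical stand-in for megatron.core.transformer.ModuleSpec.
    module: Any
    params: Dict[str, Any] = field(default_factory=dict)
    submodules: Any = None


@dataclass
class ToyAttnSubmodules:
    core_attention: Any = None


@dataclass
class ToyLayerSubmodules:
    self_attention: Optional[ToySpec] = None
    mlp: Any = None


def inject_attention(layer_specs: List[ToySpec], self_attn_cls: Any, core_attn_cls: Any) -> List[ToySpec]:
    """Return copies of the layer specs with custom attention modules swapped in."""
    out = []
    for spec in layer_specs:
        attn = spec.submodules.self_attention
        new_attn = replace(
            attn,
            module=self_attn_cls,
            submodules=replace(attn.submodules, core_attention=core_attn_cls),
        )
        out.append(replace(spec, submodules=replace(spec.submodules, self_attention=new_attn)))
    return out


def toy_layer(mlp: str) -> ToySpec:
    attn = ToySpec("SelfAttention", submodules=ToyAttnSubmodules("TEDotProductAttention"))
    return ToySpec("TransformerLayer", submodules=ToyLayerSubmodules(attn, mlp))


specs = [toy_layer("MoE"), toy_layer("Dense")]
new_specs = inject_attention(specs, "MiMoV2FlashSelfAttention", "MiMoV2FlashTEDotProductAttention")
print([s.submodules.self_attention.module for s in new_specs])
```

Returning new specs through dataclasses.replace, rather than mutating in place, is a choice made for this sketch; it keeps the original block spec reusable.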