bridge.models.minimax_m2.minimax_m2_provider#

MiniMax-M2 custom layer spec with full-dimension QK normalization.

MiniMax-M2 applies RMSNorm to the entire Q/K projection (weight shape = num_heads * head_dim) before splitting into heads. Megatron’s built-in QK norm applies per-head (weight shape = head_dim). This module bridges the gap by applying full-partition-dimension RMSNorm inside the standard SelfAttention flow.
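The difference is easy to see numerically. Below is a small illustrative sketch (not this module's code) comparing a weightless per-head RMSNorm against a full-dimension one on the same flat Q/K vector:

```python
import torch

torch.manual_seed(0)
num_heads, head_dim = 4, 8
x = torch.randn(num_heads * head_dim)

def rms_norm(v: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Weightless RMSNorm: v / sqrt(mean(v**2) + eps)
    return v * torch.rsqrt(v.pow(2).mean() + eps)

# Per-head norm (Megatron built-in): each head_dim slice normalized on its own.
per_head = torch.cat([rms_norm(h) for h in x.view(num_heads, head_dim)])

# Full-dimension norm (what MiniMax-M2 expects): one RMS over all heads at once.
full_dim = rms_norm(x)

# The two disagree whenever head magnitudes differ, which is why the per-head
# built-in cannot reproduce MiniMax-M2's full-dimension normalization.
print(torch.allclose(per_head, full_dim))
```

The outputs match only in the degenerate case where every head happens to have the same RMS.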

Module Contents#

Classes#

_FullDimRMSNorm

RMSNorm applied across all attention heads (full Q/K dimension).

FullDimQNorm

Factory callable that creates a full-dimension RMSNorm for Q heads.

FullDimKNorm

Factory callable that creates a full-dimension RMSNorm for K heads.

Functions#

_get_tp_group

Lazy accessor for the TP process group (not available at module init time).

minimax_m2_layer_spec

Build a TE layer spec for MiniMax-M2 with full-dimension QK norm.

API#

class bridge.models.minimax_m2.minimax_m2_provider._FullDimRMSNorm(
local_dim: int,
global_dim: int,
tp_group_getter,
eps: float = 1e-06,
dtype: torch.dtype | None = None,
)#

Bases: torch.nn.Module

RMSNorm applied across all attention heads (full Q/K dimension).

Standard per-head QK norm normalizes over head_dim independently per head. This module normalizes over the full num_heads * head_dim dimension, matching HuggingFace models that use nn.RMSNorm(num_heads * head_dim) on the full Q/K vector before reshaping into heads.

With TP > 1, each rank holds only num_heads_per_partition heads, so the sum of squares is all-reduced across the TP group before the RMS is computed. This keeps the normalization denominator identical to the single-GPU case.
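The all-reduce trick can be checked without a distributed setup by simulating the TP shards in one process; the sizes below are illustrative, and a plain sum stands in for `all_reduce(SUM)`:

```python
import torch

torch.manual_seed(0)
num_heads, head_dim, tp = 4, 8, 2
global_dim = num_heads * head_dim
x = torch.randn(global_dim)

# Single-GPU reference: mean of squares over the full Q/K vector.
ref_mean_sq = x.pow(2).mean()

# Simulated TP=2: each "rank" holds num_heads // tp contiguous heads.
shards = x.view(tp, global_dim // tp)
local_sums = [s.pow(2).sum() for s in shards]   # per-rank partial sums
reduced = torch.stack(local_sums).sum()         # stands in for all_reduce(SUM)
tp_mean_sq = reduced / global_dim               # divide by the *global* dim

print(torch.allclose(ref_mean_sq, tp_mean_sq))
```

Note the division by the global dimension, not the local one: each rank normalizes with the same denominator it would see on a single GPU.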

Initialization

forward(x: torch.Tensor) → torch.Tensor#

Apply RMSNorm over the full (TP-global) Q/K dimension of x.
sharded_state_dict(
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int], ...] = (),
metadata: Optional[Dict] = None,
) → Dict[str, ShardedTensor]#

Weight is TP-sharded along axis 0 (same as ColumnParallelLinear).
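As a toy illustration (hypothetical sizes, no Megatron dependency), axis-0 sharding gives each rank a contiguous block of weight rows, exactly as ColumnParallelLinear splits its output dimension:

```python
# Illustrative shard layout for the norm weight, split along axis 0
# (hypothetical global_dim and TP size; not the real ShardedTensor API).
global_dim, tp = 1024, 4
local_dim = global_dim // tp

# (rank, row offset into the global weight, rows held by that rank)
shards = [(rank, rank * local_dim, local_dim) for rank in range(tp)]
for rank, offset, size in shards:
    print(f"rank {rank}: weight rows [{offset}, {offset + size})")
```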

bridge.models.minimax_m2.minimax_m2_provider._get_tp_group()#

Lazy accessor for the TP process group (not available at module init time).

class bridge.models.minimax_m2.minimax_m2_provider.FullDimQNorm#

Factory callable that creates a full-dimension RMSNorm for Q heads.

Passed as q_layernorm in the layer spec. The SelfAttention constructor calls submodules.q_layernorm(hidden_size=head_dim, config=..., eps=...); this factory ignores the per-head hidden_size and computes the correct full partition dimension from config.

__new__(
hidden_size: int,
config: megatron.core.transformer.TransformerConfig,
eps: float = 1e-06,
)#
class bridge.models.minimax_m2.minimax_m2_provider.FullDimKNorm#

Factory callable that creates a full-dimension RMSNorm for K heads.

Same as FullDimQNorm but uses num_query_groups (GQA key-value heads) instead of num_attention_heads.

__new__(
hidden_size: int,
config: megatron.core.transformer.TransformerConfig,
eps: float = 1e-06,
)#
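A quick arithmetic sketch of the resulting per-rank widths, with hypothetical GQA sizes: the Q norm scales with num_attention_heads, the K norm with num_query_groups.

```python
# Per-rank norm widths under GQA (hypothetical MiniMax-M2-like sizes).
num_attention_heads, num_query_groups, head_dim, tp = 32, 8, 128, 4

q_local_dim = (num_attention_heads // tp) * head_dim  # FullDimQNorm: 8 * 128
k_local_dim = (num_query_groups // tp) * head_dim     # FullDimKNorm: 2 * 128

print(q_local_dim, k_local_dim)  # 1024 256
```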
bridge.models.minimax_m2.minimax_m2_provider.minimax_m2_layer_spec(
config: GPTModelProvider,
) → megatron.core.transformer.ModuleSpec#

Build a TE layer spec for MiniMax-M2 with full-dimension QK norm.

Starts from the standard TE MoE spec (which handles grouped-GEMM experts, the router, etc.) and replaces the per-head TENorm Q/K layernorm with FullDimQNorm / FullDimKNorm.
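The swap can be sketched with toy dataclass stand-ins for the spec objects; the real ModuleSpec nesting is deeper, and the string values below stand in for the actual classes:

```python
from dataclasses import dataclass, field

# Toy stand-ins for the relevant slice of megatron.core's spec nesting
# (illustrative only; not the real ModuleSpec API).
@dataclass
class SelfAttentionSubmodules:
    q_layernorm: object = "TENorm"   # per-head norm in the stock TE MoE spec
    k_layernorm: object = "TENorm"

@dataclass
class LayerSpec:
    self_attention: SelfAttentionSubmodules = field(
        default_factory=SelfAttentionSubmodules
    )

def swap_in_full_dim_qk_norm(spec: LayerSpec) -> LayerSpec:
    # minimax_m2_layer_spec performs the equivalent swap on the real TE spec.
    spec.self_attention.q_layernorm = "FullDimQNorm"
    spec.self_attention.k_layernorm = "FullDimKNorm"
    return spec

spec = swap_in_full_dim_qk_norm(LayerSpec())
print(spec.self_attention.q_layernorm, spec.self_attention.k_layernorm)
```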