bridge.models.minimax_m2.minimax_m2_provider#
MiniMax-M2 custom layer spec with full-dimension QK normalization.
MiniMax-M2 applies RMSNorm to the entire Q/K projection (weight shape = `num_heads * head_dim`) before splitting into heads. Megatron’s built-in QK norm applies per head (weight shape = `head_dim`). This module bridges the gap by applying full-partition-dimension RMSNorm inside the standard `SelfAttention` flow.
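The difference between the two normalization schemes can be sketched in plain PyTorch (a minimal illustration with toy sizes; the helper `rms_norm` below is a weight-free stand-in, not the module's actual implementation):

```python
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm over the last dimension (no learned weight, for clarity).
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

torch.manual_seed(0)
num_heads, head_dim = 4, 8
q = torch.randn(num_heads * head_dim)

# Megatron's built-in per-head QK norm: each head_dim slice is normalized on its own.
per_head = rms_norm(q.view(num_heads, head_dim)).view(-1)

# MiniMax-M2 / HuggingFace style: one RMS statistic over the full
# num_heads * head_dim vector, computed before reshaping into heads.
full_dim = rms_norm(q)
```

The two outputs differ whenever the heads have unequal magnitudes, which is why the per-head norm cannot simply be reused here.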
Module Contents#
Classes#
- `_FullDimRMSNorm`: RMSNorm applied across all attention heads (full Q/K dimension).
- `FullDimQNorm`: Factory callable that creates a full-dimension RMSNorm for Q heads.
- `FullDimKNorm`: Factory callable that creates a full-dimension RMSNorm for K heads.
Functions#
- `_get_tp_group`: Lazy accessor for the TP process group (not available at module init time).
- `minimax_m2_layer_spec`: Build a TE layer spec for MiniMax-M2 with full-dimension QK norm.
API#
- class bridge.models.minimax_m2.minimax_m2_provider._FullDimRMSNorm(
- local_dim: int,
- global_dim: int,
- tp_group_getter,
- eps: float = 1e-06,
- dtype: torch.dtype | None = None,
- )
Bases: `torch.nn.Module`

RMSNorm applied across all attention heads (full Q/K dimension).

Standard per-head QK norm normalizes over `head_dim` independently per head. This module normalizes over the full `num_heads * head_dim` dimension, matching HuggingFace models that use `nn.RMSNorm(num_heads * head_dim)` on the full Q/K vector before reshaping into heads.

With TP > 1 each rank holds only `num_heads_per_partition` heads, so the sum of squares is all-reduced across the TP group before computing the RMS. This keeps the normalization denominator identical to the single-GPU case.

Initialization
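The TP-aware statistic can be sketched in plain PyTorch (simulated here with two local shards instead of a real process group; in practice `torch.distributed.all_reduce` would replace the explicit sum of partial results):

```python
import torch

# Hypothetical sizes: 8 heads of dim 4, split across tp=2 ranks (4 heads each).
tp, num_heads, head_dim = 2, 8, 4
eps = 1e-6
torch.manual_seed(0)
full = torch.randn(num_heads * head_dim)
shards = full.chunk(tp)  # per-rank local Q partitions

# Each rank computes a local sum of squares; the all-reduce (here: a plain
# Python sum over the per-rank partials) recovers the global statistic.
local_sq = [s.pow(2).sum() for s in shards]
global_sq = sum(local_sq)
rms = torch.rsqrt(global_sq / full.numel() + eps)
sharded_out = torch.cat([s * rms for s in shards])

# Reference: single-GPU full-dimension RMSNorm over the unsharded vector.
ref = full * torch.rsqrt(full.pow(2).mean() + eps)
```

Because only a scalar sum of squares crosses the TP group, the sharded result matches the single-GPU computation exactly.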
- forward(x: torch.Tensor) → torch.Tensor#
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: Tuple[Tuple[int, int, int], ...] = (),
- metadata: Optional[Dict] = None,
- )
Weight is TP-sharded along axis 0 (same as ColumnParallelLinear).
- bridge.models.minimax_m2.minimax_m2_provider._get_tp_group()#
Lazy accessor for the TP process group (not available at module init time).
- class bridge.models.minimax_m2.minimax_m2_provider.FullDimQNorm#
Factory callable that creates a full-dimension RMSNorm for Q heads.
Passed as `q_layernorm` in the layer spec. The `SelfAttention` constructor calls `submodules.q_layernorm(hidden_size=head_dim, config=..., eps=...)`; this factory ignores the per-head `hidden_size` and computes the correct full partition dimension from `config`.

- __new__(
- hidden_size: int,
- config: megatron.core.transformer.TransformerConfig,
- eps: float = 1e-06,
- )
- class bridge.models.minimax_m2.minimax_m2_provider.FullDimKNorm#
Factory callable that creates a full-dimension RMSNorm for K heads.
Same as `FullDimQNorm` but uses `num_query_groups` (GQA key-value heads) instead of `num_attention_heads`.

- __new__(
- hidden_size: int,
- config: megatron.core.transformer.TransformerConfig,
- eps: float = 1e-06,
- )
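The resulting Q/K partition sizes differ under GQA, since the key projection only spans the query groups (hypothetical sizes below, not MiniMax-M2's actual configuration):

```python
# Per-rank Q/K partition dims under GQA and tensor parallelism.
num_attention_heads = 32   # query heads   -> used by FullDimQNorm
num_query_groups = 8       # KV heads      -> used by FullDimKNorm
head_dim = 128
tp = 4

q_local_dim = (num_attention_heads // tp) * head_dim  # 8 * 128 = 1024
k_local_dim = (num_query_groups // tp) * head_dim     # 2 * 128 = 256
```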
- bridge.models.minimax_m2.minimax_m2_provider.minimax_m2_layer_spec(
- config: GPTModelProvider,
- )
Build a TE layer spec for MiniMax-M2 with full-dimension QK norm.
Starts from the standard TE MoE spec (which handles grouped-GEMM experts, the router, etc.) and replaces the per-head `TENorm` Q/K layernorm with `FullDimQNorm` / `FullDimKNorm`.
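The swap itself is a small field replacement on the submodule spec. A hedged sketch with toy stand-ins (the dataclass below is an assumption in place of Megatron's `ModuleSpec`/`SelfAttentionSubmodules`, and strings stand in for the real norm classes):

```python
from dataclasses import dataclass

@dataclass
class SelfAttnSubmodulesSketch:
    # Toy stand-in for Megatron's SelfAttentionSubmodules (assumption:
    # the real spec exposes q_layernorm / k_layernorm fields like this).
    q_layernorm: object = None
    k_layernorm: object = None

def base_te_moe_spec() -> SelfAttnSubmodulesSketch:
    # Pretend this is the standard TE MoE spec with per-head TENorm.
    return SelfAttnSubmodulesSketch(q_layernorm="TENorm", k_layernorm="TENorm")

def minimax_m2_layer_spec_sketch() -> SelfAttnSubmodulesSketch:
    spec = base_te_moe_spec()
    spec.q_layernorm = "FullDimQNorm"  # factory callables in the real module
    spec.k_layernorm = "FullDimKNorm"
    return spec

spec = minimax_m2_layer_spec_sketch()
```

Everything else in the base spec (experts, router, attention core) is left untouched, so only the QK-norm behavior changes.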