# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
def calculate_mfu(tflops, world_size, time_seconds, reference_mfu=1979.0):
"""Calculate Model FLOPs Utilization (MFU).
Args:
tflops: TFLOPs per GPU
world_size: Total number of GPUs
time_seconds: Time taken for computation
reference_mfu: Peak TFLOPs of the hardware (default: H100)
Returns:
MFU as a percentage
"""
mfu = tflops / (world_size * time_seconds)
mfu = mfu / reference_mfu
return mfu * 100
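# Usage sketch with hypothetical numbers (not a benchmark result): a training
# step that performs 1.0e4 total TFLOPs across 8 GPUs in 3.2 s gives
#   calculate_mfu(1.0e4, world_size=8, time_seconds=3.2)
#   = 1.0e4 / (8 * 3.2) / 1979.0 * 100 ≈ 19.7 (percent of peak)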
def gpt3_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for GPT3 family - accepts either AutoConfig or normalized config"""
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
hs = config.hidden_size
layers = config.num_hidden_layers
vocab_size = config.vocab_size
causal_self_attn = True
return (24 * gbs * seq_len * hs * hs + 4 * gbs * seq_len * seq_len * hs * (0.5 if causal_self_attn else 1)) * (
3 * layers
) + (6 * gbs * seq_len * hs * vocab_size)
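# Minimal call sketch (hypothetical values; any object exposing the attributes
# read above works, e.g. an HF AutoConfig or a SimpleNamespace):
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(hidden_size=768, num_hidden_layers=12,
#                         vocab_size=50257, max_position_embeddings=2048)
#   gpt3_flops(cfg, gbs=8)  # seq_len defaults to cfg.max_position_embeddings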
def llama2_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for llama2 family - accepts either AutoConfig or normalized config"""
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
layers = config.num_hidden_layers
hs = config.hidden_size
attention_heads = config.num_attention_heads
query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else attention_heads
ffn_hs = config.intermediate_size
vocab_size = config.vocab_size
causal_self_attn = True
return (
gbs
* seq_len
* layers
* hs
* hs
* (
12
+ (12 * query_groups / attention_heads)
+ (18 * ffn_hs / hs)
+ (12 * seq_len / hs) * (0.5 if causal_self_attn else 1)
+ (6 * vocab_size / (layers * hs))
)
)
def llama3_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for llama3 family - accepts either AutoConfig or normalized config"""
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
layers = config.num_hidden_layers
hs = config.hidden_size
attention_heads = config.num_attention_heads
query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else attention_heads
ffn_hs = config.intermediate_size
vocab_size = config.vocab_size
causal_self_attn = True
return (
gbs
* seq_len
* layers
* hs
* hs
* (
12
+ (12 * query_groups / attention_heads)
+ (18 * ffn_hs / hs)
+ (12 * seq_len / hs) * (0.5 if causal_self_attn else 1)
+ (6 * vocab_size / (layers * hs))
)
)
def nemotron_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for nemotron family - accepts either AutoConfig or normalized config"""
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
layers = config.num_hidden_layers
hs = config.hidden_size
attention_heads = config.num_attention_heads
query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else attention_heads
ffn_hs = config.intermediate_size
vocab_size = config.vocab_size
causal_self_attn = True
return (
gbs
* seq_len
* layers
* hs
* hs
* (
12
+ (12 * query_groups / attention_heads)
+ (12 * ffn_hs / hs)
+ (12 * seq_len / hs) * (0.5 if causal_self_attn else 1)
+ (6 * vocab_size / (layers * hs))
)
)
def mixtral_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for mixtral family - accepts either AutoConfig or normalized config"""
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
layers = config.num_hidden_layers
hs = config.hidden_size
attention_heads = config.num_attention_heads
query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else attention_heads
ffn_hs = config.intermediate_size
vocab_size = config.vocab_size
moe_router_topk = config.num_experts_per_tok if hasattr(config, "num_experts_per_tok") else 2
causal_self_attn = True
return (
gbs
* seq_len
* layers
* hs
* hs
* (
12
+ (12 * query_groups / attention_heads)
+ (18 * moe_router_topk * ffn_hs / hs)
+ (12 * seq_len / hs) * (0.5 if causal_self_attn else 1)
+ (6 * vocab_size / (layers * hs))
)
)
def qwen3_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for Qwen3 family - accepts either AutoConfig or normalized config"""
# For VL composite configs, use the text_config sub-config
if hasattr(config, "text_config") and not hasattr(config, "num_hidden_layers"):
config = config.text_config
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
layers = config.num_hidden_layers
hs = config.hidden_size
attention_heads = config.num_attention_heads
query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else attention_heads
vocab_size = config.vocab_size
# Calculate head_dim if not present (for Qwen2) or use directly (for Qwen3)
head_dim = config.head_dim if hasattr(config, "head_dim") else (hs // attention_heads)
query_projection_to_hidden_size_ratio = (head_dim * attention_heads) / hs
    # MoE fields - use num_experts_per_tok if present, otherwise treat as dense (top-k = 1)
moe_router_topk = config.num_experts_per_tok if hasattr(config, "num_experts_per_tok") else 1
moe_ffn_hidden_size = (
config.moe_intermediate_size if hasattr(config, "moe_intermediate_size") else config.intermediate_size
)
causal_self_attn = True
hidden_size = hs
gated_linear_multiplier = 2
# attention flops for GQA
attention_flops = (
3
* 2
* gbs
* layers
* seq_len
* hidden_size
* hidden_size
* query_projection_to_hidden_size_ratio
* (
(query_groups / attention_heads * 2 + 1) # QKV gemm
+ (seq_len / hidden_size * 2 * (0.5 if causal_self_attn else 1)) # attention
+ 1 # attention proj gemm
)
)
# mlp flops
mlp_flops = (
3
* 2
* gbs
* layers
* seq_len
* hidden_size
* (1 + gated_linear_multiplier)
* (moe_ffn_hidden_size * moe_router_topk) # MoE layers
)
# vocab flops
vocab_flops = 3 * 2 * gbs * seq_len * hidden_size * vocab_size
return attention_flops + mlp_flops + vocab_flops
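# The leading 3 * 2 factors above follow the usual convention: 3x for
# forward + backward (backward ~= 2x forward) and 2 FLOPs per multiply-add.
# Call sketch with hypothetical dense-Qwen-like values (no MoE fields, so
# moe_router_topk falls back to 1 and intermediate_size is used):
#   cfg = SimpleNamespace(num_hidden_layers=28, hidden_size=3584,
#                         num_attention_heads=28, num_key_value_heads=4,
#                         intermediate_size=18944, vocab_size=152064,
#                         max_position_embeddings=4096)  # from types import SimpleNamespace
#   qwen3_flops(cfg, gbs=4)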
def bert_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for BERT family - accepts either AutoConfig or normalized config"""
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 512
layers = config.num_hidden_layers
hs = config.hidden_size
vocab_size = config.vocab_size
return 72 * gbs * layers * seq_len * hs * hs * (1 + (seq_len / (6 * hs)) + (vocab_size / (12 * hs * layers)))
def clip_vit_l_flops(config):
"""Model FLOPs for CLIP ViT"""
if config.img_seq_len is None:
config.img_seq_len = (config.img_h * config.img_w) / (
config.patch_dim * config.patch_dim
) + config.class_token_len
return config.gbs * config.layers * config.hs * config.hs * config.img_seq_len * (
24 + (4 * config.img_seq_len / config.hs)
) + (2 * config.gbs * config.hs * config.in_channels * config.img_h * config.img_w)
def neva_projection_flops(config):
"""Model FLOPs for NeVA Projection"""
if "mlp" in config.projector_type:
return 6 * config.gbs * config.img_seq_len * config.ffn_hs * (config.inp_s + config.hs)
elif config.projector_type == "affine":
return 6 * config.gbs * config.img_seq_len * config.inp_s * config.hs
else:
raise ValueError(
f"NeVA Projections FLOPs calculator only supports 'mlp', 'mcore_mlp'"
f" or 'affine' projector_type but found {config.projector_type}"
)
def flux_flops(config):
"""Model FLOPs for FLUX"""
hs = config.hs
seq_len = config.model_channels + config.inp_s
base_factor = 6 * config.gbs # common multiplier for most terms
# Joint layer computations
joint_layer_flops = (
base_factor
* config.layers[0]
* (
10 * hs * hs # hidden size operations
+ 2 * hs * (config.model_channels + config.inp_s) * (1 + hs * 7) # channel and context joint attention
+ 2 * (config.model_channels + config.inp_s) * hs # final projection
)
)
# Single layer computations
single_layer_flops = (
base_factor
* config.layers[1]
* seq_len
* hs
* (
3 # linear Y
+ 1 # Modulation
+ 4 * hs # Linear computations
+ (3 * hs + 2 * seq_len) # attention operations
+ 5 * hs # feed-forward
+ 1 # Modulation
)
)
# Embedding and projection layers
other_flops = base_factor * (
config.inp_s * config.in_channels * hs # image embedding
+ config.inp_s * hs * config.model_channels # text embedding
+ config.vec_in_dim * hs
+ hs * hs # vector embedding
+ 2 * (config.model_channels * hs + hs * hs) # guidance + timestep embedding
+ (config.inp_s * config.in_channels * hs) / config.gbs # final projection
)
return joint_layer_flops + single_layer_flops + other_flops
def deepseekv3_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for DeepSeek V3 - accepts either AutoConfig or normalized config"""
    if seq_len is None:
        seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
    hs = config.hidden_size
layers = config.num_hidden_layers
attention_heads = config.num_attention_heads
ffn_hs = config.intermediate_size
vocab_size = config.vocab_size
# DeepSeek V3 specific fields
q_lora_rank = config.q_lora_rank if hasattr(config, "q_lora_rank") else None
kv_lora_rank = config.kv_lora_rank
qk_rope_head_dim = config.qk_rope_head_dim
qk_nope_head_dim = config.qk_nope_head_dim if hasattr(config, "qk_nope_head_dim") else None
v_head_dim = config.v_head_dim
# MoE fields
moe_intermediate_size = config.moe_intermediate_size
moe_shared_expert_intermediate_size = moe_intermediate_size
moe_ffn_hidden_size = moe_intermediate_size
moe_router_topk = config.num_experts_per_tok
# MoE layer pattern
first_k_dense_replace = config.first_k_dense_replace if hasattr(config, "first_k_dense_replace") else 0
if hasattr(config, "moe_layer_freq"):
moe_layer_freq = config.moe_layer_freq
else:
moe_layer_freq = [0] * first_k_dense_replace + [1] * (layers - first_k_dense_replace)
# MTP layers (optional)
mtp_num_layers = config.mtp_num_layers if hasattr(config, "mtp_num_layers") else None
# DSA / sparse attention (DeepSeek V3.2)
index_topk = getattr(config, "index_topk", None)
index_n_heads = getattr(config, "index_n_heads", 0)
index_head_dim = getattr(config, "index_head_dim", 0)
# self-attention flops
if index_topk is not None and index_topk > 0:
# Sparse: each query attends to index_topk keys
bmm1_flops = (qk_nope_head_dim + qk_rope_head_dim) * attention_heads * seq_len * index_topk
bmm2_flops = v_head_dim * attention_heads * seq_len * index_topk
else:
# Full causal
bmm1_flops = 0.5 * (qk_nope_head_dim + qk_rope_head_dim) * attention_heads * (seq_len**2)
bmm2_flops = 0.5 * v_head_dim * attention_heads * (seq_len**2)
per_input_attention_flops = 6 * (bmm1_flops + bmm2_flops) * layers
if mtp_num_layers is not None:
per_input_attention_flops += 6 * (bmm1_flops + bmm2_flops) * mtp_num_layers
# DSA indexer overhead (projections + full-sequence BMM per layer)
if index_topk is not None and index_topk > 0 and index_n_heads > 0:
idx_proj_params = (q_lora_rank or 0) * index_n_heads * index_head_dim + hs * index_head_dim + hs * index_n_heads
idx_bmm = index_n_heads * index_head_dim * seq_len * seq_len
per_layer_indexer = 6 * (idx_proj_params * seq_len + idx_bmm)
total_indexer_layers = layers + (mtp_num_layers or 0)
per_input_attention_flops += per_layer_indexer * total_indexer_layers
# linear layer flops
if q_lora_rank is not None:
per_layer_mla_params = hs * q_lora_rank + q_lora_rank * (
(qk_nope_head_dim + qk_rope_head_dim) * attention_heads
) # Q
else:
per_layer_mla_params = hs * ((qk_nope_head_dim + qk_rope_head_dim) * attention_heads) # Q
per_layer_mla_params += hs * qk_rope_head_dim # K^R
per_layer_mla_params += hs * kv_lora_rank + kv_lora_rank * (
(qk_nope_head_dim + v_head_dim) * attention_heads
) # K^C and V^C
per_layer_mla_params += v_head_dim * attention_heads * hs # Proj
mla_params = per_layer_mla_params * layers
if mtp_num_layers is not None:
mla_params += per_layer_mla_params * mtp_num_layers
dense_layer_ffn_params = hs * ffn_hs * 3 # gated linear unit
per_shared_expert_params = hs * moe_shared_expert_intermediate_size * 3
per_selected_expert_params = hs * moe_ffn_hidden_size * 3
ffn_params = 0
if isinstance(moe_layer_freq, int):
moe_layer_pattern = [1 if (i % moe_layer_freq == 0) else 0 for i in range(layers)]
else:
moe_layer_pattern = moe_layer_freq
for i in moe_layer_pattern:
if i == 0:
ffn_params += dense_layer_ffn_params
else:
ffn_params += per_shared_expert_params + (per_selected_expert_params * moe_router_topk)
if mtp_num_layers is not None:
for i in range(mtp_num_layers):
ffn_params += per_shared_expert_params + (per_selected_expert_params * moe_router_topk)
per_input_params = mla_params + ffn_params
per_input_linear_flops = 6 * per_input_params * seq_len
# vocab flops
per_input_vocab_flops = 6 * vocab_size * hs * seq_len
if mtp_num_layers is not None:
for i in range(mtp_num_layers):
per_input_vocab_flops += 6 * vocab_size * hs * seq_len
per_input_vocab_flops += 6 * hs * 2 * hs * seq_len
return (per_input_attention_flops + per_input_linear_flops + per_input_vocab_flops) * gbs
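# Call sketch (hypothetical, abbreviated): in addition to the usual GPT-style
# fields, the config must expose the MLA projection ranks read above, e.g.
#   cfg = SimpleNamespace(hidden_size=7168, num_hidden_layers=61,
#                         num_attention_heads=128, intermediate_size=18432,
#                         vocab_size=129280, q_lora_rank=1536, kv_lora_rank=512,
#                         qk_rope_head_dim=64, qk_nope_head_dim=128, v_head_dim=128,
#                         moe_intermediate_size=2048, num_experts_per_tok=8,
#                         first_k_dense_replace=3, max_position_embeddings=4096)
#   deepseekv3_flops(cfg, gbs=1, seq_len=4096)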
def _nemotronh_mlp_layer_flops(config, gbs, seq_len):
"""Model FLOPs for MLP layer. Assume gated linear unit."""
return 6 * gbs * seq_len * config.hidden_size * config.intermediate_size * 3
def _nemotronh_moe_layer_flops(config, gbs, seq_len):
"""Model FLOPs for a MoE layer in Nemotron V3/Super V3 (hybrid Mamba/Attention/MoE).
Nemotron V3 uses relu2 (non-gated) for both routed and shared experts,
so each expert has 2 linear projections (up_proj + down_proj), not 3.
When moe_latent_size is set (Super V3), routed experts operate in a reduced
latent space with additional projection layers (fc1_latent_proj, fc2_latent_proj).
The shared expert and gate always operate in the full hidden_size dimension.
Accounts for:
1. Routed experts: only num_experts_per_tok activated per token.
2. Shared expert: always active for every token (full hidden_size).
3. Router/gate: linear projection hidden_size -> n_routed_experts.
4. Latent projections (if moe_latent_size is set): down and up projections.
"""
hs = config.hidden_size
num_tokens = gbs * seq_len
# Determine if latent MoE is used
moe_latent_size = getattr(config, "moe_latent_size", None)
if moe_latent_size is not None:
# Latent MoE: experts operate in reduced latent space
expert_dim = moe_latent_size
# fc1_latent_proj (hs -> latent) + fc2_latent_proj (latent -> hs)
latent_proj_flops = 6 * num_tokens * hs * moe_latent_size * 2
else:
expert_dim = hs
latent_proj_flops = 0
# Routed experts: num_experts_per_tok activated, each up_proj + down_proj
routed_expert_flops = 6 * num_tokens * config.num_experts_per_tok * expert_dim * config.moe_intermediate_size * 2
# Shared expert: always active on full hidden_size, up_proj + down_proj
shared_expert_flops = 6 * num_tokens * hs * config.moe_shared_expert_intermediate_size * 2
# Router/gate: hidden_size -> n_routed_experts (always full dimension)
gate_flops = 6 * num_tokens * hs * config.n_routed_experts
return routed_expert_flops + shared_expert_flops + gate_flops + latent_proj_flops
def _non_mla_attn_layer_flops(config, gbs, seq_len):
"""Model FLOPs for attention layer"""
hs = config.hidden_size
attention_heads = config.num_attention_heads
query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else attention_heads
return (
6
* gbs
* seq_len
* hs
* (
hs # Q
+ query_groups / attention_heads * hs * 2 # KV
            + seq_len / 2 * 2  # attention BMMs (score + context, 0.5 causal factor)
            + hs  # output projection
)
)
def _mamba_layer_flops(config, gbs, seq_len):
"""Model FLOPs for Mamba layer.
Three components:
- in_proj: input projections (x_proj, z_proj, dt_proj, B_proj, C_proj)
- scan: SSM scan kernel (7x factor accounts for the full SSD scan cost)
- out_proj: output projection back to hidden_size
Multiplied by 6 (3x fwd+bwd * 2x FMA) for in_proj/out_proj (standard GEMMs),
and 7 * 3 = 21 for scan (non-GEMM kernel, higher op count per element).
"""
hs = config.hidden_size
if hasattr(config, "mamba_state_dim"):
mamba_state_dim = config.mamba_state_dim
elif hasattr(config, "ssm_state_size"):
mamba_state_dim = config.ssm_state_size
else:
raise ValueError("Expected config to have 'mamba_state_dim' or 'ssm_state_size'")
mamba_head_dim = config.mamba_head_dim
if hasattr(config, "mamba_num_groups"):
mamba_num_groups = config.mamba_num_groups
elif hasattr(config, "n_groups"):
mamba_num_groups = config.n_groups
else:
raise ValueError("Expected config to have 'mamba_num_groups' or 'n_groups'")
if hasattr(config, "mamba_num_heads") and config.mamba_num_heads:
nheads = config.mamba_num_heads
else:
nheads = 2 * hs // mamba_head_dim # default expand is 2
d_in = nheads * mamba_head_dim
in_proj = 6 * gbs * seq_len * hs * (2 * d_in + 2 * mamba_num_groups * mamba_state_dim + nheads)
scan = 7 * 3 * gbs * seq_len * d_in * mamba_state_dim
out_proj = 6 * gbs * seq_len * d_in * hs
return in_proj + scan + out_proj
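# Example of the nheads fallback above (hypothetical sizes): with
# hidden_size=4096, mamba_head_dim=64 and no mamba_num_heads set,
# nheads = 2 * 4096 // 64 = 128 and d_in = 128 * 64 = 8192 (expand factor 2).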
def _hybrid_model_flops(config, gbs, seq_len):
"""Model FLOPs for hybrid model"""
if hasattr(config, "is_hybrid_model"):
if not config.is_hybrid_model:
raise ValueError("Config must have is_hybrid_model=True")
elif not hasattr(config, "hybrid_override_pattern"):
raise ValueError("Expected config to have `is_hybrid_model` or `hybrid_override_pattern`")
hybrid_override_pattern = config.hybrid_override_pattern
hs = config.hidden_size
vocab_size = config.vocab_size
num_attn_layers, num_mamba_layers, num_mlp_layers, num_moe_layers = 0, 0, 0, 0
for c in hybrid_override_pattern:
if c == "M":
num_mamba_layers += 1
elif c == "-":
num_mlp_layers += 1
elif c == "*":
num_attn_layers += 1
elif c == "E":
num_moe_layers += 1
total = 6 * gbs * seq_len * hs * vocab_size
if num_attn_layers:
total += num_attn_layers * _non_mla_attn_layer_flops(config, gbs, seq_len)
if num_mamba_layers:
total += num_mamba_layers * _mamba_layer_flops(config, gbs, seq_len)
if num_mlp_layers:
total += num_mlp_layers * _nemotronh_mlp_layer_flops(config, gbs, seq_len)
if num_moe_layers:
total += num_moe_layers * _nemotronh_moe_layer_flops(config, gbs, seq_len)
return total
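# Pattern sketch: hybrid_override_pattern is a per-layer string, e.g.
# "M-M-M*-" counts 3 Mamba ("M") layers, 1 attention ("*") layer and
# 3 MLP ("-") layers; "E" marks MoE layers in Nemotron V3-style configs.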
def nemotronh_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for NemotronH"""
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
return _hybrid_model_flops(config, gbs, seq_len)
def attention_flops_calculator(
seqlen,
hidden_size,
num_attention_heads,
num_query_groups,
kv_channels: Optional[int] = None,
is_swa: bool = False,
swa_window_size: int = 128,
):
"""Calculate the flops for the attention part."""
kv_channels = kv_channels or (hidden_size // num_attention_heads)
linear_qkv = seqlen * hidden_size * (kv_channels * (num_attention_heads + num_query_groups * 2))
linear_proj = seqlen * hidden_size * (kv_channels * num_attention_heads)
if is_swa:
attention_mask_nz_elem = (
swa_window_size * (swa_window_size + 1) / 2 + (seqlen - swa_window_size) * swa_window_size
)
attention = num_attention_heads * (attention_mask_nz_elem * kv_channels) * 2
else:
bmm_k = kv_channels
bmm_b = num_attention_heads
attention_mask_nz_elem = seqlen * (seqlen + 1) / 2
attention = bmm_b * attention_mask_nz_elem * bmm_k * 2
return (linear_qkv + linear_proj + attention) * 6
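# Worked example of the SWA nonzero-element count above (hypothetical sizes):
# with seqlen=4096 and swa_window_size=128, the attended pairs are
#   128 * 129 / 2 + (4096 - 128) * 128 = 8256 + 507904 = 516160
# versus 4096 * 4097 / 2 = 8390656 for full causal attention.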
def moe_mlp_flops_calculator(
seqlen,
hidden_size,
moe_ffn_hidden_size,
moe_router_topk,
gated_linear_unit: bool = True,
):
"""Calculate the flops for the MLP"""
total_num_tokens = seqlen * moe_router_topk
linear_fc1 = total_num_tokens * hidden_size * moe_ffn_hidden_size * (2 if gated_linear_unit else 1)
linear_fc2 = total_num_tokens * moe_ffn_hidden_size * hidden_size
return (linear_fc1 + linear_fc2) * 6
def loss_flops_calculator(
seqlen,
hidden_size,
vocab_size,
):
"""Calculate the flops for the loss"""
return (seqlen * hidden_size * vocab_size) * 6
def gpt_oss_flops_calculator(
gbs,
num_layers,
seqlen,
hidden_size,
num_attention_heads,
num_query_groups,
moe_ffn_hidden_size,
moe_router_topk,
vocab_size,
kv_channels: Optional[int] = None,
swa_window_size: int = 128,
window_attn_skip_freq: Optional[int] = 2,
):
"""Calculate the flops for the GPT-OSS model"""
flops = 0
for i in range(num_layers):
if i % window_attn_skip_freq == 0:
flops += attention_flops_calculator(
seqlen,
hidden_size,
num_attention_heads,
num_query_groups,
kv_channels,
is_swa=False,
)
else:
flops += attention_flops_calculator(
seqlen,
hidden_size,
num_attention_heads,
num_query_groups,
kv_channels,
is_swa=True,
swa_window_size=swa_window_size,
)
flops += moe_mlp_flops_calculator(
seqlen,
hidden_size,
moe_ffn_hidden_size,
moe_router_topk,
)
flops += loss_flops_calculator(seqlen, hidden_size, vocab_size)
flops *= gbs
return flops
def gpt_oss_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for GPT-OSS"""
    if seq_len is None:
        seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
    # Map config fields
    num_layers = config.num_hidden_layers
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
num_query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else num_attention_heads
vocab_size = config.vocab_size
# GPT-OSS specific fields
moe_ffn_hidden_size = (
config.moe_ffn_hidden_size if hasattr(config, "moe_ffn_hidden_size") else config.intermediate_size
)
moe_router_topk = config.num_experts_per_tok
kv_channels = config.kv_channels if hasattr(config, "kv_channels") else (hidden_size // num_attention_heads)
swa_window_size = config.window_size[0] if hasattr(config, "window_size") and config.window_size else 128
window_attn_skip_freq = config.window_attn_skip_freq if hasattr(config, "window_attn_skip_freq") else 2
return gpt_oss_flops_calculator(
gbs=gbs,
num_layers=num_layers,
seqlen=seq_len,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_query_groups=num_query_groups,
moe_ffn_hidden_size=moe_ffn_hidden_size,
moe_router_topk=moe_router_topk,
vocab_size=vocab_size,
kv_channels=kv_channels,
swa_window_size=swa_window_size,
window_attn_skip_freq=window_attn_skip_freq,
)
def glm4_moe_flops(config, gbs=1, seq_len=None):
    """Model FLOPs for GLM4 MoE family - accepts either AutoConfig or normalized config"""
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
layers = config.num_hidden_layers
hs = config.hidden_size
attention_heads = config.num_attention_heads
query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else attention_heads
vocab_size = config.vocab_size
# GLM4 MoE attention config
head_dim = getattr(config, "head_dim", hs // attention_heads)
query_projection_to_hidden_size_ratio = (head_dim * attention_heads) / hs
# MoE config
ffn_hs = config.intermediate_size # for dense layers
moe_intermediate_size = config.moe_intermediate_size if hasattr(config, "moe_intermediate_size") else ffn_hs
moe_router_topk = config.num_experts_per_tok if hasattr(config, "num_experts_per_tok") else 1
n_shared_experts = config.n_shared_experts if hasattr(config, "n_shared_experts") else 0
first_k_dense_replace = config.first_k_dense_replace if hasattr(config, "first_k_dense_replace") else 0
causal_self_attn = True
hidden_size = hs
gated_linear_multiplier = 2 # SwiGLU
# Attention flops for GQA (Qwen3-style)
attention_flops = (
3
* 2
* gbs
* layers
* seq_len
* hidden_size
* hidden_size
* query_projection_to_hidden_size_ratio
* (
(query_groups / attention_heads * 2 + 1) # QKV gemm
+ (seq_len / hidden_size * 2 * (0.5 if causal_self_attn else 1)) # attention
+ 1 # attention proj gemm
)
)
# MLP flops (DeepSeek V3-style MoE)
# Dense layers: first_k_dense_replace layers
dense_mlp_flops = (
3 * 2 * gbs * first_k_dense_replace * seq_len * hidden_size * (1 + gated_linear_multiplier) * ffn_hs
)
# MoE layers: (layers - first_k_dense_replace) layers
# Each MoE layer has: shared experts + routed experts (topk selected)
num_moe_layers = layers - first_k_dense_replace
# Shared expert flops (always computed)
shared_expert_flops = (
3
* 2
* gbs
* num_moe_layers
* seq_len
* hidden_size
* (1 + gated_linear_multiplier)
* (moe_intermediate_size * n_shared_experts)
)
# Routed expert flops (topk selected)
routed_expert_flops = (
3
* 2
* gbs
* num_moe_layers
* seq_len
* hidden_size
* (1 + gated_linear_multiplier)
* (moe_intermediate_size * moe_router_topk)
)
mlp_flops = dense_mlp_flops + shared_expert_flops + routed_expert_flops
# Vocab flops
vocab_flops = 3 * 2 * gbs * seq_len * hidden_size * vocab_size
return attention_flops + mlp_flops + vocab_flops
def minimax_m2_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for MiniMax-M2 family - accepts either AutoConfig or normalized config.
Architecture: GQA attention (Q/K/V/O separate projections, head_dim may differ from
hidden_size // num_heads) + MoE with SwiGLU (no shared experts by default).
Optionally includes MTP (Multi-Token Prediction) modules gated by use_mtp.
"""
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
layers = config.num_hidden_layers
hs = config.hidden_size
attention_heads = config.num_attention_heads
query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else attention_heads
vocab_size = config.vocab_size
head_dim = getattr(config, "head_dim", hs // attention_heads)
query_projection_to_hidden_size_ratio = (head_dim * attention_heads) / hs
# MoE config — all layers are MoE, no shared experts by default
ffn_hs = config.intermediate_size
moe_router_topk = config.num_experts_per_tok if hasattr(config, "num_experts_per_tok") else 8
shared_intermediate_size = getattr(config, "shared_intermediate_size", 0)
# MTP config (optional, gated by use_mtp)
use_mtp = getattr(config, "use_mtp", False)
num_mtp_modules = getattr(config, "num_mtp_modules", 0) if use_mtp else 0
mtp_transformer_layers = getattr(config, "mtp_transformer_layers", 1)
causal_self_attn = True
gated_linear_multiplier = 2 # SwiGLU: gate + up projections
# --- Attention flops (GQA with separate Q/K/V/O projections) ---
def _attention_flops_per_layer():
return (
6
* gbs
* seq_len
* hs
* hs
* query_projection_to_hidden_size_ratio
* (
(query_groups / attention_heads * 2 + 1) # QKV gemm
+ (seq_len / hs * 2 * (0.5 if causal_self_attn else 1)) # BMM (causal)
+ 1 # output proj gemm
)
)
attention_flops = _attention_flops_per_layer() * layers
# --- MoE MLP flops (SwiGLU, all layers) ---
def _moe_mlp_flops_per_layer():
# Routed experts (topk selected)
routed = 6 * gbs * seq_len * hs * (1 + gated_linear_multiplier) * (ffn_hs * moe_router_topk)
# Shared experts (if any)
shared = (
6 * gbs * seq_len * hs * (1 + gated_linear_multiplier) * shared_intermediate_size
if shared_intermediate_size > 0
else 0
)
return routed + shared
mlp_flops = _moe_mlp_flops_per_layer() * layers
# --- Vocab flops (lm_head) ---
vocab_flops = 6 * gbs * seq_len * hs * vocab_size
# --- MTP module flops (optional) ---
mtp_flops = 0
if num_mtp_modules > 0:
total_mtp_layers = num_mtp_modules * mtp_transformer_layers
# Embedding projection per module: concat(hidden, next_embed) -> hidden (2*hs -> hs)
mtp_flops += 6 * gbs * seq_len * hs * 2 * hs * num_mtp_modules
# Transformer layers (attention + MoE MLP)
mtp_flops += _attention_flops_per_layer() * total_mtp_layers
mtp_flops += _moe_mlp_flops_per_layer() * total_mtp_layers
# Vocab projection per module
mtp_flops += 6 * gbs * seq_len * hs * vocab_size * num_mtp_modules
return attention_flops + mlp_flops + vocab_flops + mtp_flops
def _gdn_attention_per_layer_flops(
gbs,
seq_len,
hidden_size,
linear_key_head_dim,
linear_value_head_dim,
linear_num_key_heads,
linear_num_value_heads,
linear_conv_kernel_dim,
):
"""FLOPs for a single Gated DeltaNet (GDN / linear attention) layer.
Based on the GDN FLOPs calculator from Megatron-Bridge PR #2925.
"""
qk_dim = linear_key_head_dim * linear_num_key_heads
v_dim = linear_value_head_dim * linear_num_value_heads
return (
3
* 2
* gbs
* seq_len
* (
hidden_size * (2 * qk_dim + 2 * v_dim + 2 * linear_num_value_heads)
+ linear_conv_kernel_dim * (2 * qk_dim + v_dim)
+ linear_num_value_heads * (linear_value_head_dim**2) * 4
+ hidden_size * v_dim
)
)
def qwen3_5_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for Qwen3.5 family (MoE and Dense) with hybrid GDN/full attention.
Qwen3.5 uses a hybrid attention pattern: 75% GDN (linear attention) layers
and 25% standard GQA (full attention) layers (full_attention_interval=4).
Supports both the MoE variant (Qwen3.5-35B-A3B) and Dense variant (Qwen3.5-27B).
"""
# For VL composite configs, use the text_config sub-config
if hasattr(config, "text_config") and not hasattr(config, "num_hidden_layers"):
config = config.text_config
if seq_len is None:
seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048
layers = config.num_hidden_layers
hs = config.hidden_size
attention_heads = config.num_attention_heads
query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else attention_heads
vocab_size = config.vocab_size
head_dim = getattr(config, "head_dim", hs // attention_heads)
# GDN (linear attention) parameters
linear_key_head_dim = config.linear_key_head_dim
linear_value_head_dim = config.linear_value_head_dim
linear_num_key_heads = config.linear_num_key_heads
linear_num_value_heads = config.linear_num_value_heads
linear_conv_kernel_dim = getattr(config, "linear_conv_kernel_dim", 4)
# Determine layer counts from layer_types or full_attention_interval
if hasattr(config, "layer_types") and config.layer_types:
layer_types = config.layer_types
num_full_attn_layers = sum(1 for lt in layer_types if lt == "full_attention")
num_gdn_layers = layers - num_full_attn_layers
else:
full_attention_interval = getattr(config, "full_attention_interval", 4)
num_full_attn_layers = layers // full_attention_interval
num_gdn_layers = layers - num_full_attn_layers
# MoE fields
is_moe = hasattr(config, "num_experts") and config.num_experts is not None and config.num_experts > 1
moe_router_topk = getattr(config, "num_experts_per_tok", 1) if is_moe else 1
moe_intermediate_size = getattr(config, "moe_intermediate_size", 0) if is_moe else 0
shared_expert_intermediate_size = getattr(config, "shared_expert_intermediate_size", 0) if is_moe else 0
ffn_hs = getattr(config, "intermediate_size", 0) if not is_moe else 0
# MTP layers
mtp_num_layers = getattr(config, "mtp_num_hidden_layers", 0) or 0
causal_self_attn = True
gated_linear_multiplier = 2 # SwiGLU: gate + up projections
query_projection_to_hidden_size_ratio = (head_dim * attention_heads) / hs
# Qwen3.5 uses gated attention: Q proj outputs 2x (query + gate), applied as sigmoid(gate)*attn
attn_output_gate = getattr(config, "attn_output_gate", True)
q_gate_multiplier = 2 if attn_output_gate else 1
# --- Standard (full) attention flops per layer ---
full_attn_per_layer = (
6
* gbs
* seq_len
* hs
* hs
* query_projection_to_hidden_size_ratio
* (
(query_groups / attention_heads * 2 + q_gate_multiplier) # QKV gemm (Q is 2x with gate)
+ (seq_len / hs * 2 * (0.5 if causal_self_attn else 1)) # attention BMM
+ 1 # output proj gemm
)
)
# --- GDN (linear attention) flops per layer ---
gdn_attn_per_layer = _gdn_attention_per_layer_flops(
gbs,
seq_len,
hs,
linear_key_head_dim,
linear_value_head_dim,
linear_num_key_heads,
linear_num_value_heads,
linear_conv_kernel_dim,
)
# Total attention flops
attention_flops = full_attn_per_layer * num_full_attn_layers + gdn_attn_per_layer * num_gdn_layers
# --- MLP flops ---
if is_moe:
# Routed experts (topk selected) + shared experts, all layers are MoE
routed_expert_flops = (
6 * gbs * layers * seq_len * hs * (1 + gated_linear_multiplier) * (moe_intermediate_size * moe_router_topk)
)
shared_expert_flops = (
6 * gbs * layers * seq_len * hs * (1 + gated_linear_multiplier) * shared_expert_intermediate_size
)
mlp_flops = routed_expert_flops + shared_expert_flops
else:
# Dense MLP with SwiGLU
mlp_flops = 6 * gbs * layers * seq_len * hs * (1 + gated_linear_multiplier) * ffn_hs
# --- Vocab flops ---
vocab_flops = 6 * gbs * seq_len * hs * vocab_size
# --- MTP flops ---
mtp_flops = 0
if mtp_num_layers > 0:
# Embedding projection per MTP layer: 2*hs -> hs
mtp_flops += 6 * gbs * seq_len * hs * 2 * hs * mtp_num_layers
# MTP layers reuse the last transformer layer pattern (assumed full attention)
mtp_flops += full_attn_per_layer * mtp_num_layers
# MTP MLP (same as main model's last layer)
if is_moe:
mtp_mlp_per_layer = (
6
* gbs
* seq_len
* hs
* (1 + gated_linear_multiplier)
* (moe_intermediate_size * moe_router_topk + shared_expert_intermediate_size)
)
else:
mtp_mlp_per_layer = 6 * gbs * seq_len * hs * (1 + gated_linear_multiplier) * ffn_hs
mtp_flops += mtp_mlp_per_layer * mtp_num_layers
# Vocab projection per MTP layer
mtp_flops += 6 * gbs * seq_len * hs * vocab_size * mtp_num_layers
return attention_flops + mlp_flops + vocab_flops + mtp_flops
# ---------------------------------------------------------------------------
# Shared helpers for MLA (Multi-Latent Attention) + MoE models
# ---------------------------------------------------------------------------
def _mla_attention_per_layer_flops(
gbs,
seq_len,
hs,
attention_heads,
q_lora_rank,
kv_lora_rank,
qk_rope_head_dim,
qk_nope_head_dim,
v_head_dim,
index_topk=None,
index_n_heads=0,
index_head_dim=0,
):
"""Per-layer FLOPs for Multi-Latent Attention (MLA).
Shared by DeepSeek V3, Kimi K2.5, Mistral Small 4, GLM-5, etc.
When index_topk is set (DSA / sparse attention), accounts for:
- Sparse main attention BMM: S * index_topk instead of 0.5 * S^2
- DSA indexer overhead: Q/K/weights projections + full S^2 indexer BMM
"""
# --- Main MLA attention BMM ---
if index_topk is not None and index_topk > 0:
# Sparse attention: each query attends to index_topk keys (not full causal)
bmm1 = (qk_nope_head_dim + qk_rope_head_dim) * attention_heads * seq_len * index_topk
bmm2 = v_head_dim * attention_heads * seq_len * index_topk
else:
# Full causal attention
bmm1 = 0.5 * (qk_nope_head_dim + qk_rope_head_dim) * attention_heads * (seq_len**2)
bmm2 = 0.5 * v_head_dim * attention_heads * (seq_len**2)
bmm_flops = 6 * gbs * (bmm1 + bmm2)
# --- MLA linear projections ---
if q_lora_rank is not None:
q_params = hs * q_lora_rank + q_lora_rank * ((qk_nope_head_dim + qk_rope_head_dim) * attention_heads)
else:
q_params = hs * ((qk_nope_head_dim + qk_rope_head_dim) * attention_heads)
kr_params = hs * qk_rope_head_dim
kv_params = hs * kv_lora_rank + kv_lora_rank * ((qk_nope_head_dim + v_head_dim) * attention_heads)
o_params = v_head_dim * attention_heads * hs
linear_flops = 6 * gbs * seq_len * (q_params + kr_params + kv_params + o_params)
# --- DSA indexer overhead ---
indexer_flops = 0
if index_topk is not None and index_topk > 0 and index_n_heads > 0:
# Indexer projections: wq_b (q_lora -> idx_heads*idx_hd),
# wk (hs -> idx_hd), weights_proj (hs -> idx_heads)
idx_proj_params = (
(q_lora_rank or 0) * index_n_heads * index_head_dim # wq_b
+ hs * index_head_dim # wk
+ hs * index_n_heads # weights_proj
)
# Indexer full-sequence BMM: Q@K^T over all positions to find top-k
idx_bmm = index_n_heads * index_head_dim * seq_len * seq_len
indexer_flops = 6 * gbs * (idx_proj_params * seq_len + idx_bmm)
return bmm_flops + linear_flops + indexer_flops
def _mla_moe_model_flops(
gbs,
seq_len,
hs,
layers,
attention_heads,
vocab_size,
q_lora_rank,
kv_lora_rank,
qk_rope_head_dim,
qk_nope_head_dim,
v_head_dim,
dense_ffn_hs,
moe_ffn_hs,
moe_router_topk,
moe_shared_expert_hs,
moe_layer_pattern,
mtp_num_layers=0,
index_topk=None,
index_n_heads=0,
index_head_dim=0,
):
"""FLOPs for MLA + MoE transformer models (DeepSeek-V3 style).
Args:
moe_layer_pattern: List of 0/1 per layer (0=dense, 1=MoE).
moe_shared_expert_hs: Total intermediate size for all shared experts combined.
index_topk: If set, use DSA sparse attention with this many selected positions.
index_n_heads: Number of heads in the DSA indexer.
index_head_dim: Head dimension of the DSA indexer.
"""
# --- Attention (MLA on every layer) ---
mla_per_layer = _mla_attention_per_layer_flops(
gbs,
seq_len,
hs,
attention_heads,
q_lora_rank,
kv_lora_rank,
qk_rope_head_dim,
qk_nope_head_dim,
v_head_dim,
index_topk=index_topk,
index_n_heads=index_n_heads,
index_head_dim=index_head_dim,
)
attention_flops = mla_per_layer * layers
# --- FFN (dense or MoE with shared experts, SwiGLU = 3 projections) ---
dense_layer_ffn_params = hs * dense_ffn_hs * 3
per_shared_expert_params = hs * moe_shared_expert_hs * 3
per_selected_expert_params = hs * moe_ffn_hs * 3
ffn_params = 0
for is_moe in moe_layer_pattern:
if is_moe == 0:
ffn_params += dense_layer_ffn_params
else:
ffn_params += per_shared_expert_params + (per_selected_expert_params * moe_router_topk)
ffn_flops = 6 * gbs * seq_len * ffn_params
# --- Vocab ---
vocab_flops = 6 * gbs * seq_len * hs * vocab_size
# --- MTP ---
mtp_flops = 0
if mtp_num_layers > 0:
mtp_flops += mla_per_layer * mtp_num_layers
last_is_moe = moe_layer_pattern[-1] if moe_layer_pattern else 0
if last_is_moe:
mtp_ffn_params = per_shared_expert_params + (per_selected_expert_params * moe_router_topk)
else:
mtp_ffn_params = dense_layer_ffn_params
mtp_flops += 6 * gbs * seq_len * mtp_ffn_params * mtp_num_layers
mtp_flops += 6 * gbs * seq_len * hs * vocab_size * mtp_num_layers
mtp_flops += 6 * gbs * seq_len * hs * 2 * hs * mtp_num_layers # embedding projection
return attention_flops + ffn_flops + vocab_flops + mtp_flops
def _build_moe_layer_pattern(config, layers):
"""Build a list of 0/1 indicating dense(0) vs MoE(1) per layer.
Handles multiple config styles: first_k_dense_replace + moe_layer_freq,
mlp_layer_types list, etc.
"""
mlp_layer_types = getattr(config, "mlp_layer_types", None)
if mlp_layer_types is not None:
return [0 if lt == "dense" else 1 for lt in mlp_layer_types]
first_k_dense = getattr(config, "first_k_dense_replace", 0)
moe_layer_freq = getattr(config, "moe_layer_freq", 1)
if isinstance(moe_layer_freq, list):
return moe_layer_freq
return [0] * first_k_dense + [
1 if ((i - first_k_dense) % moe_layer_freq == 0) else 0 for i in range(first_k_dense, layers)
]
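# Example (hypothetical): first_k_dense_replace=3, moe_layer_freq=1, layers=6
# yields [0, 0, 0, 1, 1, 1]; a moe_layer_freq list or an mlp_layer_types entry
# is passed through / mapped directly instead.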
def mla_moe_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for MLA + MoE models (Kimi K2, GLM-5, Mistral Small 4, etc.).
Handles VL wrappers by extracting text_config if present.
"""
# Handle VL wrappers with nested text_config
cfg = config
if hasattr(config, "text_config") and not hasattr(config, "num_hidden_layers"):
cfg = config.text_config
if seq_len is None:
seq_len = getattr(cfg, "max_position_embeddings", 2048)
layers = cfg.num_hidden_layers
hs = cfg.hidden_size
n_shared = getattr(cfg, "n_shared_experts", 0)
# MoE intermediate size: try multiple field names
moe_int_size = getattr(cfg, "moe_intermediate_size", None)
if moe_int_size is None:
moe_int_size = getattr(cfg, "expert_ffn_hidden_size", cfg.intermediate_size)
# Dense FFN intermediate size
dense_ffn_hs = getattr(cfg, "intermediate_size", None)
if dense_ffn_hs is None:
dense_ffn_hs = getattr(cfg, "ffn_hidden_size", moe_int_size)
# Router top-k: try multiple field names
moe_topk = getattr(cfg, "num_experts_per_tok", None)
if moe_topk is None:
moe_topk = getattr(cfg, "moe_topk", 1)
moe_layer_pattern = _build_moe_layer_pattern(cfg, layers)
# MTP: try multiple field names used by different models
mtp = getattr(cfg, "num_nextn_predict_layers", None)
if mtp is None:
mtp = getattr(cfg, "mtp_num_layers", 0)
mtp = mtp or 0
    # DSA (DeepSeek Sparse Attention) indexer fields
idx_topk = getattr(cfg, "index_topk", None)
idx_n_heads = getattr(cfg, "index_n_heads", 0)
idx_head_dim = getattr(cfg, "index_head_dim", 0)
return _mla_moe_model_flops(
gbs=gbs,
seq_len=seq_len,
hs=hs,
layers=layers,
attention_heads=cfg.num_attention_heads,
vocab_size=cfg.vocab_size,
q_lora_rank=getattr(cfg, "q_lora_rank", None),
kv_lora_rank=cfg.kv_lora_rank,
qk_rope_head_dim=cfg.qk_rope_head_dim,
qk_nope_head_dim=cfg.qk_nope_head_dim,
v_head_dim=cfg.v_head_dim,
dense_ffn_hs=dense_ffn_hs,
moe_ffn_hs=moe_int_size,
moe_router_topk=moe_topk,
moe_shared_expert_hs=moe_int_size * n_shared,
moe_layer_pattern=moe_layer_pattern,
mtp_num_layers=mtp,
index_topk=idx_topk,
index_n_heads=idx_n_heads,
index_head_dim=idx_head_dim,
)
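# Usage sketch: any MLA + MoE config that exposes kv_lora_rank,
# qk_rope_head_dim, qk_nope_head_dim and v_head_dim can be passed directly;
# VL wrappers are unwrapped via text_config above, e.g.
#   mla_moe_flops(AutoConfig.from_pretrained(model_name), gbs=1, seq_len=4096)
# where model_name is a placeholder for the checkpoint being profiled.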
def step3_5_flash_flops(config, gbs=1, seq_len=None):
"""Model FLOPs for Step3.5-Flash (GQA + sliding-window / full attention + MoE).
Architecture: hybrid full/SWA attention with different head counts per type,
MoE with shared expert on most layers, first few layers dense, SwiGLU.
"""
if seq_len is None:
seq_len = getattr(config, "max_position_embeddings", 2048)
layers = config.num_hidden_layers
hs = config.hidden_size
vocab_size = config.vocab_size
# Attention heads: full vs sliding may differ
full_attn_heads = config.num_attention_heads
attn_other = getattr(config, "attention_other_setting", None)
if attn_other is not None and isinstance(attn_other, dict):
sliding_attn_heads = attn_other.get("num_attention_heads", full_attn_heads)
else:
sliding_attn_heads = full_attn_heads
num_query_groups = getattr(config, "num_attention_groups", full_attn_heads)
head_dim = getattr(config, "head_dim", hs // full_attn_heads)
sliding_window = getattr(config, "sliding_window", 512)
# MoE config
moe_top_k = getattr(config, "moe_top_k", 8)
moe_ffn_hs = getattr(config, "moe_intermediate_size", 1280)
share_expert_dim = getattr(config, "share_expert_dim", moe_ffn_hs)
dense_ffn_hs = config.intermediate_size
# Which layers are MoE? Parse moe_layers_enum (comma-separated string or list)
moe_layers_raw = getattr(config, "moe_layers_enum", None)
if moe_layers_raw is not None:
if isinstance(moe_layers_raw, str):
moe_layers_set = set(int(x.strip()) for x in moe_layers_raw.split(",") if x.strip())
else:
moe_layers_set = set(int(x) for x in moe_layers_raw)
else:
# Default: first 3 dense, rest MoE
moe_layers_set = set(range(3, layers))
# Layer types (first `layers` entries; remaining are MTP layers)
layer_types = getattr(config, "layer_types", None)
# MTP
mtp_num_layers = getattr(config, "num_nextn_predict_layers", 0) or 0
# --- Per-layer FLOPs ---
total_attn = 0
total_mlp = 0
for i in range(layers):
# Determine attention type
if layer_types and i < len(layer_types):
is_full = layer_types[i] == "full_attention"
else:
is_full = i % 4 == 0 # default: every 4th starting from 0
if is_full:
total_attn += attention_flops_calculator(
seq_len,
hs,
full_attn_heads,
num_query_groups,
head_dim,
is_swa=False,
)
else:
total_attn += attention_flops_calculator(
seq_len,
hs,
sliding_attn_heads,
num_query_groups,
head_dim,
is_swa=True,
swa_window_size=sliding_window,
)
# MLP: MoE or dense (SwiGLU = gate + up + down = 3 projections)
if i in moe_layers_set:
total_mlp += moe_mlp_flops_calculator(
seq_len,
hs,
moe_ffn_hs,
moe_top_k,
gated_linear_unit=True,
)
# Shared expert (SwiGLU)
total_mlp += 6 * seq_len * hs * share_expert_dim * 3
else:
total_mlp += 6 * seq_len * hs * dense_ffn_hs * 3
# Vocab
total_vocab = loss_flops_calculator(seq_len, hs, vocab_size)
# MTP
mtp_total = 0
if mtp_num_layers > 0:
# Embedding projection per MTP module (2*hs -> hs)
mtp_total += 6 * seq_len * hs * 2 * hs * mtp_num_layers
# Each MTP module has one transformer layer (attention + MoE MLP)
mtp_total += (
attention_flops_calculator(
seq_len,
hs,
full_attn_heads,
num_query_groups,
head_dim,
is_swa=False,
)
* mtp_num_layers
)
mtp_total += (
moe_mlp_flops_calculator(
seq_len,
hs,
moe_ffn_hs,
moe_top_k,
gated_linear_unit=True,
)
* mtp_num_layers
)
mtp_total += 6 * seq_len * hs * share_expert_dim * 3 * mtp_num_layers
# Vocab per MTP module
mtp_total += loss_flops_calculator(seq_len, hs, vocab_size) * mtp_num_layers
return gbs * (total_attn + total_mlp + total_vocab + mtp_total)