Source code for nemo_rl.utils.flops_formulas

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional, Union


# lifted from NeMo/nemo/utils/flops_formulas.py
@dataclass
class FLOPSConfig:
    """Contains the model hparams needed for FLOPS computations."""

    gbs: int
    enc_seq_len: Optional[int] = None
    hs: Optional[int] = None
    layers: Optional[int] = None
    ffn_hs: Optional[int] = None
    attention_heads: Optional[int] = None
    moe_router_topk: Optional[int] = None
    query_groups: Optional[int] = None
    img_seq_len: Optional[int] = None
    img_h: Optional[int] = None
    img_w: Optional[int] = None
    in_channels: Optional[int] = None
    patch_dim: Optional[int] = None
    class_token_len: Optional[int] = None
    projector_type: Optional[str] = None
    inp_s: Optional[int] = None
    model_pattern: Optional[str] = None
    vocab_size: Optional[int] = None
    model_channels: Optional[int] = None
    vec_in_dim: Optional[int] = None
    q_lora_rank: Optional[int] = None
    kv_lora_rank: Optional[int] = None
    qk_head_dim: Optional[int] = None
    qk_pos_emb_head_dim: Optional[int] = None
    v_head_dim: Optional[int] = None
    moe_layer_freq: Optional[Union[int, List[int]]] = None
    moe_shared_expert_intermediate_size: Optional[int] = None
    moe_ffn_hidden_size: Optional[int] = None
    mtp_num_layers: Optional[int] = None
    causal_self_attn: Optional[bool] = None
    gated_linear_unit: Optional[bool] = None  # read by _mlp_layer_flops; None/False means a non-gated MLP
    is_hybrid_model: bool = False
    hybrid_override_pattern: Optional[str] = None
    mamba_state_dim: Optional[int] = None
    mamba_head_dim: Optional[int] = None
    mamba_num_groups: Optional[int] = None
    mamba_num_heads: Optional[int] = None

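# Illustrative sketch (not part of the original module): how a FLOPSConfig might be
# filled in for a dense decoder-only model. All hyperparameter values below are
# hypothetical placeholders; substitute your model's actual values. Field names follow
# the abbreviations used throughout this file (gbs = global batch size, hs = hidden
# size, ffn_hs = FFN hidden size, enc_seq_len = sequence length).
def _example_dense_config() -> FLOPSConfig:
    return FLOPSConfig(
        gbs=128,             # sequences per training step (global batch size)
        enc_seq_len=4096,    # tokens per sequence
        hs=4096,             # transformer hidden size
        layers=32,           # number of transformer layers
        ffn_hs=14336,        # FFN (MLP) hidden size
        attention_heads=32,  # query heads
        query_groups=8,      # KV heads (GQA)
        vocab_size=128256,   # vocabulary size
    )
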
def gpt3(config: FLOPSConfig):
    """Model FLOPs for GPT3 family."""
    return (
        24 * config.gbs * config.enc_seq_len * config.hs * config.hs
        + 4 * config.gbs * config.enc_seq_len * config.enc_seq_len * config.hs
    ) * (3 * config.layers) + (
        6 * config.gbs * config.enc_seq_len * config.hs * config.vocab_size
    )

def llama2(config: FLOPSConfig):
    """Model FLOPs for llama2 family."""
    return (
        config.gbs
        * config.enc_seq_len
        * config.layers
        * config.hs
        * config.hs
        * (
            12
            + (12 * config.query_groups / config.attention_heads)
            + (18 * config.ffn_hs / config.hs)
            + (12 * config.enc_seq_len / config.hs)
            + (6 * config.vocab_size / (config.layers * config.hs))
        )
    )

def llama3(config: FLOPSConfig):
    """Model FLOPs for llama3 family."""
    return (
        config.gbs
        * config.enc_seq_len
        * config.layers
        * config.hs
        * config.hs
        * (
            12
            + (12 * config.query_groups / config.attention_heads)
            + (18 * config.ffn_hs / config.hs)
            + (12 * config.enc_seq_len / config.hs)
            + (6 * config.vocab_size / (config.layers * config.hs))
        )
    )

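# Illustrative sketch (not part of the original module): converting the per-step model
# FLOPs returned by llama3() into an achieved TFLOP/s-per-GPU figure. The step time,
# GPU count, and the _example_dense_config() helper above are all hypothetical; this
# only shows how these formulas are typically consumed.
def _example_llama3_tflops_per_gpu(step_time_s: float = 10.0, num_gpus: int = 64) -> float:
    total_flops = llama3(_example_dense_config())  # model FLOPs for one training step
    return total_flops / (step_time_s * num_gpus) / 1e12
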
def nemotron(config: FLOPSConfig):
    """Model FLOPs for nemotron family."""
    return (
        config.gbs
        * config.enc_seq_len
        * config.layers
        * config.hs
        * config.hs
        * (
            12
            + (12 * config.query_groups / config.attention_heads)
            + (12 * config.ffn_hs / config.hs)
            + (12 * config.enc_seq_len / config.hs)
            + (6 * config.vocab_size / (config.layers * config.hs))
        )
    )

def mixtral(config: FLOPSConfig):
    """Model FLOPs for mixtral family."""
    return (
        config.gbs
        * config.enc_seq_len
        * config.layers
        * config.hs
        * config.hs
        * (
            12
            + (12 * config.query_groups / config.attention_heads)
            + (18 * config.moe_router_topk * config.ffn_hs / config.hs)
            + (12 * config.enc_seq_len / config.hs)
            + (6 * config.vocab_size / (config.layers * config.hs))
        )
    )

def qwen2(config: FLOPSConfig):
    """Model FLOPs for Qwen2 family."""
    causal_self_attn = True
    seq_len = config.enc_seq_len
    hidden_size = config.hs
    gated_linear_multiplier = 2

    # attention flops for GQA
    attention_flops = (
        3
        * 2
        * config.gbs
        * config.layers
        * seq_len
        * hidden_size
        * hidden_size
        * (
            (2 + 1)  # QKV gemm
            + (seq_len / hidden_size * 2 * (0.5 if causal_self_attn else 1))  # attention
            + 1  # attention proj gemm
        )
    )

    # mlp flops
    mlp_flops = (
        3
        * 2
        * config.gbs
        * config.layers
        * seq_len
        * hidden_size
        * (1 + gated_linear_multiplier)
        * config.ffn_hs
    )

    # vocab flops
    vocab_flops = 3 * 2 * config.gbs * seq_len * hidden_size * config.vocab_size

    return attention_flops + mlp_flops + vocab_flops

def qwen3(config: FLOPSConfig):
    """Model FLOPs for Qwen3 family."""
    causal_self_attn = True
    seq_len = config.enc_seq_len
    hidden_size = config.hs
    gated_linear_multiplier = 2

    # attention flops for GQA
    attention_flops = (
        3
        * 2
        * config.gbs
        * config.layers
        * seq_len
        * hidden_size
        * hidden_size
        * (
            (config.query_groups / config.attention_heads * 2 + 1)  # QKV gemm
            + (seq_len / hidden_size * 2 * (0.5 if causal_self_attn else 1))  # attention
            + 1  # attention proj gemm
        )
    )

    # mlp flops
    mlp_flops = (
        3
        * 2
        * config.gbs
        * config.layers
        * seq_len
        * hidden_size
        * (1 + gated_linear_multiplier)
        * (config.moe_ffn_hidden_size * config.moe_router_topk)  # MoE layers
    )

    # vocab flops
    vocab_flops = 3 * 2 * config.gbs * seq_len * hidden_size * config.vocab_size

    return attention_flops + mlp_flops + vocab_flops

def bert(config: FLOPSConfig):
    """Model FLOPs for BERT family."""
    return (
        72
        * config.gbs
        * config.layers
        * config.enc_seq_len
        * config.hs
        * config.hs
        * (
            1
            + (config.enc_seq_len / (6 * config.hs))
            + (config.vocab_size / (12 * config.hs * config.layers))
        )
    )

def transformer(config: FLOPSConfig):
    """Calculate FLOPs for a standard Transformer model.

    Note: This does not cover encoder-decoder models.
    """
    # Extract parameters from config
    batch_size = config.gbs
    hidden_size = config.hs
    seq_length = config.enc_seq_len
    num_layers = config.layers
    num_attention_heads = config.attention_heads
    ffn_hidden_size = config.ffn_hs
    vocab_size = config.vocab_size

    if vocab_size is None:
        raise ValueError("vocab_size is required for transformer FLOPs calculation")

    # Handle optional parameters with reasonable defaults
    query_groups = (
        config.query_groups if config.query_groups is not None else num_attention_heads
    )
    causal_self_attn = (
        config.causal_self_attn if config.causal_self_attn is not None else False
    )
    moe_router_topk = (
        config.moe_router_topk if config.moe_router_topk is not None else 0
    )
    kv_channels = hidden_size // num_attention_heads  # Standard dimension per head

    # Calculate query projection size and ratio
    query_projection_size = kv_channels * num_attention_heads
    query_projection_to_hidden_size_ratio = query_projection_size / hidden_size

    # MoE parameters - simplified for NeMo config
    # In this implementation, we assume all layers are dense if num_experts is None
    if moe_router_topk == 0:
        num_dense_layers = num_layers
        num_moe_layers = 0
        num_experts_routed_to = 0
    else:
        # Simplified MoE handling - assuming uniform distribution of MoE layers
        # This can be expanded based on NeMo's actual MoE implementation
        num_moe_layers = num_layers // 2  # Simplified assumption
        num_dense_layers = num_layers - num_moe_layers
        num_experts_routed_to = moe_router_topk

    # Handle SwiGLU vs standard GELU/ReLU
    # Default to standard activation (no SwiGLU)
    gated_linear_multiplier = 1

    # Define the expansion factor as described in the paper
    # 3x: Each GEMM needs forward pass, backward wgrad, and backward dgrad
    # 2x: GEMMs are stacked twice in standard Transformer architectures
    # 2x: A GEMM of m*n with n*k requires 2mnk floating-point operations
    expansion_factor = 3 * 2 * 2

    # Attention
    # Only half of the attention matrix is non-zero and needs to be multiplied with V
    # when the self-attention is causal -> divide the sequence term by 2.
    if not causal_self_attn:
        attention_component = (
            1
            + (query_groups / num_attention_heads)
            + (seq_length / hidden_size)
        ) * query_projection_to_hidden_size_ratio
    else:
        attention_component = (
            1
            + (query_groups / num_attention_heads)
            + (seq_length / hidden_size / 2)
        ) * query_projection_to_hidden_size_ratio

    # Calculate total FLOPs
    total_flops = (
        expansion_factor
        * batch_size
        * seq_length
        * num_layers
        * hidden_size
        * hidden_size
        * (
            attention_component
            # MLP component
            + (
                (
                    # Dense layers
                    (ffn_hidden_size * num_dense_layers)
                    # MoE layers
                    + (
                        (
                            # Routed experts
                            ffn_hidden_size * num_experts_routed_to
                            # Note: Shared experts are not implemented in this version
                        )
                        * num_moe_layers
                    )
                )
                * gated_linear_multiplier
                / (num_layers * hidden_size)
            )
            # Logit component
            + (vocab_size / (2 * num_layers * hidden_size))
        )
    )

    return total_flops

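# Illustrative sketch (not part of the original module): applying the generic
# transformer() formula to the hypothetical dense config defined above. Setting
# causal_self_attn=True halves the attention score/context term, and leaving
# moe_router_topk unset keeps every layer dense. All values are placeholders.
def _example_transformer_flops() -> float:
    cfg = _example_dense_config()
    cfg.causal_self_attn = True  # decoder-style causal attention
    return transformer(cfg)
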
def flux(config: FLOPSConfig):
    """Model FLOPs for FLUX."""
    hs = config.hs
    seq_len = config.model_channels + config.inp_s
    base_factor = 6 * config.gbs  # common multiplier for most terms

    # Joint layer computations
    joint_layer_flops = (
        base_factor
        * config.layers[0]
        * (
            10 * hs * hs  # hidden size operations
            + 2 * hs * (config.model_channels + config.inp_s) * (1 + hs * 7)  # channel and context joint attention
            + 2 * (config.model_channels + config.inp_s) * hs  # final projection
        )
    )

    # Single layer computations
    single_layer_flops = (
        base_factor
        * config.layers[1]
        * seq_len
        * hs
        * (
            3  # linear Y
            + 1  # Modulation
            + 4 * hs  # Linear computations
            + (3 * hs + 2 * seq_len)  # attention operations
            + 5 * hs  # feed-forward
            + 1  # Modulation
        )
    )

    # Embedding and projection layers
    other_flops = base_factor * (
        config.inp_s * config.in_channels * hs  # image embedding
        + config.inp_s * hs * config.model_channels  # text embedding
        + config.vec_in_dim * hs + hs * hs  # vector embedding
        + 2 * (config.model_channels * hs + hs * hs)  # guidance + timestep embedding
        + (config.inp_s * config.in_channels * hs) / config.gbs  # final projection
    )

    return joint_layer_flops + single_layer_flops + other_flops

def deepseekv3(config: FLOPSConfig):
    """Model FLOPs for DeepSeek V3."""
    # self-attention flops
    bmm1_flops = (
        0.5
        * (config.qk_head_dim + config.qk_pos_emb_head_dim)
        * config.attention_heads
        * (config.enc_seq_len**2)
    )
    bmm2_flops = (
        0.5 * config.v_head_dim * config.attention_heads * (config.enc_seq_len**2)
    )
    per_input_attention_flops = 6 * (bmm1_flops + bmm2_flops) * config.layers
    if config.mtp_num_layers is not None:
        per_input_attention_flops += (
            6 * (bmm1_flops + bmm2_flops) * config.mtp_num_layers
        )

    # linear layer flops
    per_layer_mla_params = config.hs * config.q_lora_rank + config.q_lora_rank * (
        (config.qk_head_dim + config.qk_pos_emb_head_dim) * config.attention_heads
    )  # Q
    per_layer_mla_params += config.hs * config.qk_pos_emb_head_dim  # K^R
    per_layer_mla_params += config.hs * config.kv_lora_rank + config.kv_lora_rank * (
        (config.qk_head_dim + config.v_head_dim) * config.attention_heads
    )  # K^C and V^C
    per_layer_mla_params += (
        config.v_head_dim * config.attention_heads * config.hs
    )  # Proj
    mla_params = per_layer_mla_params * config.layers
    if config.mtp_num_layers is not None:
        mla_params += per_layer_mla_params * config.mtp_num_layers

    dense_layer_ffn_params = config.hs * config.ffn_hs * 3  # gated linear unit
    per_shared_expert_params = (
        config.hs * config.moe_shared_expert_intermediate_size * 3
    )
    per_selected_expert_params = config.hs * config.moe_ffn_hidden_size * 3
    ffn_params = 0
    if isinstance(config.moe_layer_freq, int):
        moe_layer_pattern = [
            1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.layers)
        ]
    else:
        moe_layer_pattern = config.moe_layer_freq
    for i in moe_layer_pattern:
        if i == 0:
            ffn_params += dense_layer_ffn_params
        else:
            ffn_params += per_shared_expert_params + (
                per_selected_expert_params * config.moe_router_topk
            )
    if config.mtp_num_layers is not None:
        for i in range(config.mtp_num_layers):
            ffn_params += per_shared_expert_params + (
                per_selected_expert_params * config.moe_router_topk
            )
    per_input_params = mla_params + ffn_params
    per_input_linear_flops = 6 * per_input_params * config.enc_seq_len

    # vocab flops
    per_input_vocab_flops = 6 * config.vocab_size * config.hs * config.enc_seq_len
    if config.mtp_num_layers is not None:
        for i in range(config.mtp_num_layers):
            per_input_vocab_flops += (
                6 * config.vocab_size * config.hs * config.enc_seq_len
            )
            per_input_vocab_flops += 6 * config.hs * 2 * config.hs * config.enc_seq_len

    return (
        per_input_attention_flops + per_input_linear_flops + per_input_vocab_flops
    ) * config.gbs

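# Illustrative note (not part of the original module) on the two accepted forms of
# moe_layer_freq in deepseekv3(): an int N marks every N-th layer (layer 0 included)
# as MoE, while an explicit 0/1 list pins the pattern per layer. The list below is an
# assumption mirroring the published DeepSeek-V3 layout (61 layers, first 3 dense);
# take the actual pattern from your model config.
#
#   moe_layer_freq = 1                    # every layer is MoE
#   moe_layer_freq = [0] * 3 + [1] * 58   # first 3 layers dense, remaining 58 MoE
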
def _mlp_layer_flops(config: FLOPSConfig):
    """Model FLOPs for MLP layer."""
    return (
        6
        * config.gbs
        * config.enc_seq_len
        * config.hs
        * config.ffn_hs
        * (2 if config.gated_linear_unit else 1)
    )

def _non_mla_attn_layer_flops(config: FLOPSConfig):
    """Model FLOPs for attention layer."""
    return (
        6
        * config.gbs
        * config.enc_seq_len
        * config.hs
        * (
            config.hs  # Q
            + config.query_groups / config.attention_heads * config.hs * 2  # KV
            + config.enc_seq_len / 2 * 2  # score and context bmms (causal)
            + config.hs  # attention proj
        )
    )

def _mamba_layer_flops(config: FLOPSConfig):
    """Model FLOPs for Mamba layer.

    We ignore part of the flops of scan because the chunk size is not known from
    model config.
    """
    assert config.mamba_state_dim is not None
    assert config.mamba_head_dim is not None
    if config.mamba_num_heads:
        nheads = config.mamba_num_heads
    else:
        nheads = 2 * config.hs // config.mamba_head_dim  # default expand is 2
    d_in = nheads * config.mamba_head_dim
    return (
        (
            6
            * config.gbs
            * config.enc_seq_len
            * config.hs
            * (2 * d_in + 2 * config.mamba_num_groups * config.mamba_state_dim + nheads)
        )
        + (3 * 2 * config.gbs * config.enc_seq_len * d_in * config.mamba_state_dim)
        + (6 * config.gbs * config.enc_seq_len * d_in * config.hs)
    )

def _hybrid_model_flops(config: FLOPSConfig):
    """Model FLOPs for hybrid model."""
    assert config.is_hybrid_model
    assert config.hybrid_override_pattern is not None
    num_attn_layers, num_mamba_layers, num_mlp_layers = 0, 0, 0
    for c in config.hybrid_override_pattern:
        if c == "M":
            num_mamba_layers += 1
        elif c == "-":
            num_mlp_layers += 1
        elif c == "*":
            num_attn_layers += 1
    return (
        num_attn_layers * _non_mla_attn_layer_flops(config)
        + num_mamba_layers * _mamba_layer_flops(config)
        + num_mlp_layers * _mlp_layer_flops(config)
        + 6 * config.gbs * config.enc_seq_len * config.hs * config.vocab_size
    )

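# Illustrative sketch (not part of the original module): how a hybrid override pattern
# string is interpreted by _hybrid_model_flops(). "M" counts as a Mamba layer, "*" as a
# self-attention layer, and "-" as an MLP layer. The default pattern below is a made-up
# 8-layer example, not a real NemotronH configuration.
def _example_hybrid_layer_counts(pattern: str = "MM-*MM-*") -> dict:
    return {
        "mamba": pattern.count("M"),
        "attention": pattern.count("*"),
        "mlp": pattern.count("-"),
    }
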
def nemotronh(config: FLOPSConfig):
    """Model FLOPs for NemotronH."""
    return _hybrid_model_flops(config)