nemo_automodel.components.moe.load_balance_metrics
MoE load balance metrics utilities.
Provides functions to enable load balance tracking on Gate modules, collect per-layer expert load data, and compute brief/detailed metrics suitable for wandb logging.
Expert utilization is a ratio of current_load / ideal_load where
ideal_load = total_tokens / n_experts. A value of 1.0 means the
expert receives exactly its fair share; >1 = overloaded, <1 = underloaded,
0 = dead expert.
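The ratio defined above can be illustrated with a minimal sketch. The function name `utilization_ratios` and the plain-list input are hypothetical choices made for this example; the module itself operates on tensors, and its handling of the zero-token case is not stated in the source, so returning all zeros here is an assumption.

```python
def utilization_ratios(tokens_per_expert):
    """Return current_load / ideal_load for each expert (hypothetical helper)."""
    n_experts = len(tokens_per_expert)
    total_tokens = sum(tokens_per_expert)
    if total_tokens == 0:
        # Assumption: with no routed tokens, report every expert as dead.
        return [0.0] * n_experts
    ideal_load = total_tokens / n_experts
    return [load / ideal_load for load in tokens_per_expert]


# 400 tokens over 4 experts gives ideal_load = 100.
ratios = utilization_ratios([300, 100, 0, 0])
print(ratios)  # [3.0, 1.0, 0.0, 0.0]
```

Because the ratios are normalized by the fair share, they always average to 1.0 within a layer whenever at least one token was routed.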
Expert diversity measures how many experts are meaningfully used:
- dead_expert_frac: fraction of experts receiving zero tokens.
- expert_diversity: exp(H) / N, where H is the Shannon entropy of the routing distribution and N is the number of experts (1.0 = all experts equally used; values near 0 = routing collapsed onto very few experts).
Modes:
- brief: Aggregated scalars (mean/median/min/max of cv and expert utilization, diversity metrics) plus top-K/bottom-K individual expert utilization ratios.
- detailed: Everything in brief, plus per-layer breakdowns (moe/layer_{i}/cv, moe/layer_{i}/utilization_mean, moe/layer_{i}/expert_diversity, etc.).
Module Contents
Functions
API
Aggregate per-layer diversity metrics into model-level summaries.
Parameters:
Output of _compute_diversity_per_layer().
If True, also emit per-layer keys (for detailed mode).
Returns: dict[str, float]
Dict with moe/dead_expert_frac_mean, moe/expert_diversity_mean,
Compute mean, median, min, max for a list of per-layer values.
Parameters:
List of per-layer scalar values.
Key prefix, e.g. "moe/cv".
Returns: dict[str, float]
Dict like {"moe/cv_mean": .., "moe/cv_median": .., "moe/cv_min": .., "moe/cv_max": ..}.
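The mean/median/min/max aggregation described above could look roughly like the following sketch. The name `aggregate_stats` is hypothetical, and returning an empty dict for an empty input is an assumption; the source does not say how that case is handled.

```python
import statistics


def aggregate_stats(values, prefix):
    """Summarize per-layer scalars under a shared key prefix (hypothetical helper)."""
    if not values:
        # Assumption: emit nothing when there are no MoE layers to summarize.
        return {}
    return {
        f"{prefix}_mean": statistics.fmean(values),
        f"{prefix}_median": statistics.median(values),
        f"{prefix}_min": min(values),
        f"{prefix}_max": max(values),
    }


# Three layers with coefficient-of-variation values 0.1, 0.2 and 0.6.
stats = aggregate_stats([0.1, 0.2, 0.6], "moe/cv")
print(sorted(stats))
# ['moe/cv_max', 'moe/cv_mean', 'moe/cv_median', 'moe/cv_min']
```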
Compute per-layer expert diversity metrics from utilization ratios.
For each layer, computes:
- dead_expert_frac: fraction of experts with zero load.
- expert_diversity: exp(H) / N, where H is the Shannon entropy of the token routing distribution. 1.0 = perfectly uniform, 1/N = all tokens routed to a single expert, 0 = all-zero load.
Parameters:
List of Tensor[n_experts] with utilization
ratios (1.0 = ideal load).
Returns: list[dict[str, float]]
List of dicts (one per layer) with the metrics above.
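As a worked illustration of the two per-layer metrics, here is a minimal sketch in plain Python. The name `diversity_for_layer` is hypothetical, the input is a non-empty list rather than a tensor, and the use of the natural logarithm is an assumption that is consistent with the exp(H) / N definition.

```python
import math


def diversity_for_layer(utilization):
    """Compute dead_expert_frac and expert_diversity for one layer (hypothetical helper).

    Assumes a non-empty list of non-negative utilization ratios.
    """
    n = len(utilization)
    total = sum(utilization)
    dead_frac = sum(1 for u in utilization if u == 0) / n
    if total == 0:
        # All-zero load: diversity is defined as 0 in this case.
        return {"dead_expert_frac": dead_frac, "expert_diversity": 0.0}
    # Normalize to a routing distribution and skip zero entries (0 * log 0 := 0).
    probs = [u / total for u in utilization if u > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return {"dead_expert_frac": dead_frac, "expert_diversity": math.exp(entropy) / n}


print(diversity_for_layer([4.0, 0.0, 0.0, 0.0]))
# {'dead_expert_frac': 0.75, 'expert_diversity': 0.25}
```

The single-expert case yields exactly 1/N = 0.25 for N = 4, while a uniform layer such as `[1.0, 1.0, 1.0, 1.0]` yields a diversity of approximately 1.0.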
Compute top-K and bottom-K expert utilization ratios globally.
Flattens utilization across all layers and experts, then emits only the
top_k highest and top_k lowest entries. This keeps the total
number of wandb keys to at most 2 * top_k regardless of model size.
Values are ratios relative to ideal load: 1.0 = perfect balance, >1 = overloaded, <1 = underloaded, 0 = dead expert.
All keys share the moe_expert_utilization/ prefix so wandb
renders them on a single chart.
Parameters:
List of Tensor[n_experts] utilization ratio per layer.
Number of top (highest) and bottom (lowest) experts to emit.
Returns: dict[str, float]
Dict like {"moe_expert_utilization/layer_0_expert_5": 1.23, ...}.
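The global top-K/bottom-K selection can be sketched as follows. The name `top_bottom_k` and the list-of-lists input are hypothetical; the key format follows the example above. How ties are broken is not stated in the source, so this sketch simply relies on a stable sort.

```python
def top_bottom_k(per_layer_utilization, top_k):
    """Return the top_k highest and top_k lowest utilization entries (hypothetical helper)."""
    if top_k <= 0:
        return {}
    # Flatten to (ratio, layer index, expert index) across all layers.
    entries = [
        (ratio, layer_idx, expert_idx)
        for layer_idx, layer in enumerate(per_layer_utilization)
        for expert_idx, ratio in enumerate(layer)
    ]
    entries.sort(key=lambda entry: entry[0])
    selected = entries[:top_k] + entries[-top_k:]
    # Dict keys deduplicate entries when fewer than 2 * top_k experts exist.
    return {
        f"moe_expert_utilization/layer_{layer_idx}_expert_{expert_idx}": ratio
        for ratio, layer_idx, expert_idx in selected
    }


utilization = [[1.5, 0.5, 1.0, 1.0], [0.0, 2.0, 1.2, 0.8]]
print(top_bottom_k(utilization, top_k=1))
# {'moe_expert_utilization/layer_1_expert_0': 0.0,
#  'moe_expert_utilization/layer_1_expert_1': 2.0}
```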
Compute per-layer CV (coefficient of variation), aux_loss, and per-layer utilization ratios.
Returns:
(per_layer_metrics, per_layer_utilizations) where:
Compute aggregate utilization stats across all experts globally.
Parameters:
List of Tensor[n_experts] utilization ratio per layer.
If True, also emit per-layer utilization means.
Returns: dict[str, float]
Dict with moe/expert_utilization_{p25,median,p75,min,max} and optionally
Collect the most recent expert load data from all Gate modules.
When dp_group is provided, expert loads are all-reduced across the
data-parallel group so the metrics reflect global token routing rather
than a single rank’s view. This is important when DP > 1 or EP > 1
because each rank only routes its local shard of tokens through the
(replicated) gate.
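To see why the reduction matters, the sketch below simulates it with plain lists instead of a real process group. It assumes the reduction is an element-wise sum, which the text above implies but does not state; in PyTorch this would typically be done with torch.distributed.all_reduce on the given group. The helper name `simulated_all_reduce_sum` is hypothetical.

```python
def simulated_all_reduce_sum(per_rank_loads):
    """Element-wise sum across ranks: the value every rank would hold
    after a SUM all-reduce (simulation only, no process group involved)."""
    return [sum(column) for column in zip(*per_rank_loads)]


# Each rank routes only its local shard of tokens through the replicated gate.
rank0_loads = [6, 2, 0, 0]
rank1_loads = [0, 0, 5, 3]

global_loads = simulated_all_reduce_sum([rank0_loads, rank1_loads])
print(global_loads)  # [6, 2, 5, 3]

# Rank 0 alone would report two dead experts; the global view reports none.
print(rank0_loads.count(0), global_loads.count(0))  # 2 0
```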
Parameters:
The model (or model part) to collect from.
Optional DP (or DP+CP) process group for all-reducing
expert loads. Pass None to skip reduction (rank-local view).
Returns: dict[str, dict]
Dictionary mapping layer names to dicts with keys:
Compute brief load-balance metrics: aggregated scalars + top-K/bottom-K utilization.
Metrics produced:
- moe/cv_{mean,median,min,max}: CV aggregated across all MoE layers.
- moe/expert_utilization_{p25,median,p75,min,max}: utilization ratio stats across all experts globally (1.0 = ideal).
- moe/aux_loss_mean: aux loss averaged across layers (when available).
- moe_expert_utilization/layer_{i}_expert_{j}: top-K highest and bottom-K lowest utilization experts globally.
Parameters:
Output of collect_expert_loads().
Number of top/bottom experts to emit globally.
Returns: dict[str, float]
Flat dictionary suitable for wandb.log().
Compute detailed load-balance metrics: per-layer scalars + aggregates + utilization.
Includes everything from compute_brief_metrics() plus per-layer breakdowns:
- moe/layer_{i}/cv, moe/layer_{i}/aux_loss
- moe/layer_{i}/utilization_mean: per-layer mean utilization.
Parameters:
Output of collect_expert_loads().
Number of top/bottom experts to emit globally.
Returns: dict[str, float]
Flat dictionary suitable for wandb.log().
Enable load balance tracking on all Gate modules in the model.
Sets _track_load_balance = True on every Gate instance found via
model.modules(). This causes each Gate to store its most recent
expert_load tensor after every forward pass with negligible overhead
(one .detach() copy per layer).
Parameters:
The model (or model part) to enable tracking on.