nemo_automodel.components.moe.load_balance_metrics#
MoE load balance metrics utilities.
Provides functions to enable load balance tracking on Gate modules, collect per-layer expert load data, and compute brief/detailed metrics suitable for wandb logging.
Expert utilization is the ratio current_load / ideal_load, where ideal_load = total_tokens / n_experts. A value of 1.0 means the expert receives exactly its fair share; >1 = overloaded, <1 = underloaded, 0 = dead expert.
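As a small worked example of the ratio above (token counts chosen purely for illustration), utilization can be computed directly from per-expert token counts:

```python
# Hypothetical per-expert token counts for one MoE layer (illustrative values).
expert_load = [120, 90, 50, 140]

total_tokens = sum(expert_load)          # 400
n_experts = len(expert_load)             # 4
ideal_load = total_tokens / n_experts    # 100.0

# Utilization ratio per expert: 1.0 = fair share, >1 overloaded, <1 underloaded.
utilization = [load / ideal_load for load in expert_load]
print(utilization)  # [1.2, 0.9, 0.5, 1.4]
```

Here expert 2 (ratio 0.5) receives half its fair share, while expert 3 (ratio 1.4) is overloaded.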
Modes:
brief: Aggregated scalars (mean/median/min/max of CV and expert utilization) plus top-K/bottom-K individual expert utilization ratios.
detailed: Everything in brief, plus per-layer breakdowns (moe/layer_{i}/cv, moe/layer_{i}/utilization_mean, etc.).
Module Contents#
Functions#
enable_load_balance_tracking – Enable load balance tracking on all Gate modules in the model.
collect_expert_loads – Collect the most recent expert load data from all Gate modules.
_compute_per_layer_stats – Compute per-layer CV, aux_loss and per-layer utilization ratios.
_aggregate_stats – Compute mean, median, min, max for a list of per-layer values.
_compute_expert_utilization – Compute top-K and bottom-K expert utilization ratios globally.
_compute_utilization_aggregates – Compute aggregate utilization stats across all experts globally.
compute_brief_metrics – Compute brief load-balance metrics: aggregated scalars + top-K/bottom-K utilization.
compute_detailed_metrics – Compute detailed load-balance metrics: per-layer scalars + aggregates + utilization.
API#
- nemo_automodel.components.moe.load_balance_metrics.enable_load_balance_tracking(model: torch.nn.Module) None#
Enable load balance tracking on all Gate modules in the model.
Sets _track_load_balance = True on every Gate instance found via model.modules(). This causes each Gate to store its most recent expert_load tensor after every forward pass with negligible overhead (one .detach() copy per layer).
- Parameters:
model – The model (or model part) to enable tracking on.
- nemo_automodel.components.moe.load_balance_metrics.collect_expert_loads(
- model: torch.nn.Module,
- dp_group: torch.distributed.ProcessGroup | None = None,
- )#
Collect the most recent expert load data from all Gate modules.
When dp_group is provided, expert loads are all-reduced across the data-parallel group so the metrics reflect global token routing rather than a single rank's view. This is important when DP > 1 or EP > 1 because each rank only routes its local shard of tokens through the (replicated) gate.
- Parameters:
model – The model (or model part) to collect from.
dp_group – Optional DP (or DP+CP) process group for all-reducing expert loads. Pass None to skip reduction (rank-local view).
- Returns:
Dictionary mapping layer names to dicts with keys:
"expert_load" – Tensor[n_experts] with token counts per expert.
"aux_loss" – Optional[Tensor] scalar aux loss (if computed).
"n_experts" – int number of routed experts.
- Return type:
dict[str, dict]
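A minimal sketch of that returned structure, using plain Python lists in place of tensors and an invented layer name for illustration:

```python
# Hypothetical collect_expert_loads output (lists stand in for tensors;
# the layer name "layers.3.mlp.gate" is made up for this example).
layer_loads = {
    "layers.3.mlp.gate": {
        "expert_load": [130.0, 70.0, 100.0, 100.0],  # token counts per expert
        "aux_loss": 0.012,                            # scalar aux loss, or None
        "n_experts": 4,
    },
}

for name, info in layer_loads.items():
    # Each entry is internally consistent: one count per routed expert.
    assert len(info["expert_load"]) == info["n_experts"]
    print(name, sum(info["expert_load"]), "tokens routed")
```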
- nemo_automodel.components.moe.load_balance_metrics._compute_per_layer_stats(layer_loads: dict[str, dict])#
Compute per-layer CV, aux_loss and per-layer utilization ratios.
- Returns:
Tuple (per_layer_metrics, per_layer_utilizations) where:
per_layer_metrics – list of dicts with keys cv, aux_loss.
per_layer_utilizations – list of Tensor[n_experts] with the utilization ratio per layer (1.0 = ideal, >1 = overloaded, <1 = underloaded).
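The CV here is the coefficient of variation of a layer's expert loads (std / mean); a perfectly balanced layer has CV = 0. A self-contained sketch of the statistic, assuming population standard deviation (whether the library uses population or sample std is not specified here):

```python
import statistics

def load_cv(expert_load: list[float]) -> float:
    """Coefficient of variation (std / mean) of per-expert token counts."""
    mean = statistics.mean(expert_load)
    return statistics.pstdev(expert_load) / mean

print(load_cv([100.0, 100.0, 100.0, 100.0]))  # 0.0 (perfect balance)
print(round(load_cv([150.0, 50.0, 100.0, 100.0]), 3))
```

Higher CV means more skew in token routing across the layer's experts.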
- nemo_automodel.components.moe.load_balance_metrics._aggregate_stats(values: list[float], prefix: str) dict[str, float]#
Compute mean, median, min, max for a list of per-layer values.
- Parameters:
values – List of per-layer scalar values.
prefix – Key prefix, e.g. "moe/cv".
- Returns:
Dict like {"moe/cv_mean": .., "moe/cv_median": .., "moe/cv_min": .., "moe/cv_max": ..}.
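A plausible implementation sketch of this aggregation (the library's actual code may differ, e.g. in empty-input handling):

```python
import statistics

def aggregate_stats(values: list[float], prefix: str) -> dict[str, float]:
    """Reduce per-layer scalars to four summary keys under the given prefix."""
    if not values:
        return {}  # assumption: no keys emitted when there is nothing to aggregate
    return {
        f"{prefix}_mean": statistics.mean(values),
        f"{prefix}_median": statistics.median(values),
        f"{prefix}_min": min(values),
        f"{prefix}_max": max(values),
    }

print(aggregate_stats([0.1, 0.3, 0.2], "moe/cv"))
```

This keeps the wandb key count constant (four keys per statistic) no matter how many MoE layers the model has.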
- nemo_automodel.components.moe.load_balance_metrics._compute_expert_utilization(
- per_layer_utilizations: list[torch.Tensor],
- top_k: int = 5,
- )#
Compute top-K and bottom-K expert utilization ratios globally.
Flattens utilization across all layers and experts, then emits only the top_k highest and top_k lowest entries. This keeps the total number of wandb keys to at most 2 * top_k regardless of model size.
Values are ratios relative to ideal load: 1.0 = perfect balance, >1 = overloaded, <1 = underloaded, 0 = dead expert.
All keys share the moe_expert_utilization/ prefix so wandb renders them on a single chart.
- Parameters:
per_layer_utilizations – List of Tensor[n_experts] utilization ratio per layer.
top_k – Number of top (highest) and bottom (lowest) experts to emit.
- Returns:
Dict like {"moe_expert_utilization/layer_0_expert_5": 1.23, ...}.
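The flatten-sort-select step can be sketched as follows, using plain nested lists in place of tensors (a hedged reconstruction, not the library's exact code):

```python
def top_bottom_utilization(per_layer_utilizations, top_k=5):
    """Emit only the top_k highest and top_k lowest expert utilization ratios."""
    # Flatten (layer, expert) pairs across all layers into wandb-style keys.
    flat = [
        (f"moe_expert_utilization/layer_{i}_expert_{j}", u)
        for i, layer in enumerate(per_layer_utilizations)
        for j, u in enumerate(layer)
    ]
    flat.sort(key=lambda kv: kv[1])
    # Bottom-K then top-K; at most 2 * top_k entries regardless of model size.
    return dict(flat[:top_k] + flat[-top_k:])

metrics = top_bottom_utilization([[1.2, 0.9], [0.5, 1.4]], top_k=1)
print(metrics)
# {'moe_expert_utilization/layer_1_expert_0': 0.5,
#  'moe_expert_utilization/layer_1_expert_1': 1.4}
```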
- nemo_automodel.components.moe.load_balance_metrics._compute_utilization_aggregates(
- per_layer_utilizations: list[torch.Tensor],
- per_layer: bool = False,
- )#
Compute aggregate utilization stats across all experts globally.
- Parameters:
per_layer_utilizations – List of Tensor[n_experts] utilization ratio per layer.
per_layer – If True, also emit per-layer utilization means.
- Returns:
Dict with moe/expert_utilization_{p25,median,p75,min,max} and optionally moe/layer_{i}/utilization_mean when per_layer=True.
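A sketch of how these percentile aggregates could be computed with the stdlib; the quantile interpolation method is an assumption (torch-based percentiles may differ slightly):

```python
import statistics

def utilization_aggregates(per_layer_utilizations, per_layer=False):
    """Global percentile summary of expert utilization, optionally per layer too."""
    all_utils = sorted(u for layer in per_layer_utilizations for u in layer)
    # quantiles with n=4 yields the three quartile cut points [p25, median, p75].
    p25, median, p75 = statistics.quantiles(all_utils, n=4)
    out = {
        "moe/expert_utilization_p25": p25,
        "moe/expert_utilization_median": median,
        "moe/expert_utilization_p75": p75,
        "moe/expert_utilization_min": all_utils[0],
        "moe/expert_utilization_max": all_utils[-1],
    }
    if per_layer:
        for i, layer in enumerate(per_layer_utilizations):
            out[f"moe/layer_{i}/utilization_mean"] = statistics.mean(layer)
    return out

agg = utilization_aggregates([[1.2, 0.9], [0.5, 1.4]], per_layer=True)
```

A low p25 together with a high p75 signals a wide spread: a few hot experts absorbing tokens that starved experts never see.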
- nemo_automodel.components.moe.load_balance_metrics.compute_brief_metrics(
- layer_loads: dict[str, dict],
- top_k: int = 5,
- )#
Compute brief load-balance metrics: aggregated scalars + top-K/bottom-K utilization.
Metrics produced:
moe/cv_{mean,median,min,max} – CV aggregated across all MoE layers.
moe/expert_utilization_{p25,median,p75,min,max} – utilization ratio stats across all experts globally (1.0 = ideal).
moe/aux_loss_mean – aux loss averaged across layers (when available).
moe_expert_utilization/layer_{i}_expert_{j} – top-K highest and bottom-K lowest utilization experts globally.
- Parameters:
layer_loads – Output of :func:`collect_expert_loads`.
top_k – Number of top/bottom experts to emit globally.
- Returns:
Flat dictionary suitable for wandb.log().
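Putting the pieces together, a hedged end-to-end sketch of the brief-metrics computation, with a hand-built layer_loads dict (illustrative layer names and values) standing in for real collect_expert_loads output; the library's actual aggregation may differ in detail:

```python
import statistics

# Fake collect_expert_loads output: two MoE layers, four experts each
# (layer names and numbers are invented for this example).
layer_loads = {
    "layers.0.gate": {"expert_load": [130.0, 70.0, 100.0, 100.0],
                      "aux_loss": 0.010, "n_experts": 4},
    "layers.1.gate": {"expert_load": [100.0, 100.0, 100.0, 100.0],
                      "aux_loss": 0.014, "n_experts": 4},
}

cvs, aux_losses, all_utils = [], [], []
for info in layer_loads.values():
    load = info["expert_load"]
    ideal = sum(load) / info["n_experts"]        # ideal_load = total / n_experts
    all_utils.extend(l / ideal for l in load)    # utilization ratios
    cvs.append(statistics.pstdev(load) / statistics.mean(load))
    aux_losses.append(info["aux_loss"])

# Flat dict in the shape wandb.log() expects.
metrics = {
    "moe/cv_mean": statistics.mean(cvs),
    "moe/aux_loss_mean": statistics.mean(aux_losses),
    "moe/expert_utilization_min": min(all_utils),
    "moe/expert_utilization_max": max(all_utils),
}
print(metrics)
```

In actual use, one would call enable_load_balance_tracking(model) once before training, run forward passes, then pass collect_expert_loads(model, dp_group) straight into compute_brief_metrics and hand the result to wandb.log().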
- nemo_automodel.components.moe.load_balance_metrics.compute_detailed_metrics(
- layer_loads: dict[str, dict],
- top_k: int = 5,
- )#
Compute detailed load-balance metrics: per-layer scalars + aggregates + utilization.
Includes everything from :func:`compute_brief_metrics` plus per-layer breakdowns:
moe/layer_{i}/cv, moe/layer_{i}/aux_loss
moe/layer_{i}/utilization_mean – per-layer mean utilization.
- Parameters:
layer_loads – Output of :func:`collect_expert_loads`.
top_k – Number of top/bottom experts to emit globally.
- Returns:
Flat dictionary suitable for wandb.log().