nemo_automodel.components.moe.load_balance_metrics

MoE load balance metrics utilities.

Provides functions to enable load balance tracking on Gate modules, collect per-layer expert load data, and compute brief/detailed metrics suitable for wandb logging.

Expert utilization is the ratio current_load / ideal_load, where ideal_load = total_tokens / n_experts. A value of 1.0 means the expert receives exactly its fair share; >1 means overloaded, <1 means underloaded, and 0 means a dead expert.
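The following is a minimal pure-Python sketch of the ratio for one layer. It makes three assumptions. First, total_tokens is the sum of the layer's per-expert load vector; with top-k routing, that sum counts token-to-expert assignments. Second, a layer with no routed tokens reports all zeros. Third, the helper name utilization_ratios is hypothetical. The real module operates on torch tensors rather than lists.

```python
def utilization_ratios(expert_load: list[float]) -> list[float]:
    """Return current_load / ideal_load for every expert in one layer.

    Assumption: total_tokens is taken as sum(expert_load).
    """
    n_experts = len(expert_load)
    total_tokens = sum(expert_load)
    if total_tokens == 0:
        # Assumed convention: nothing was routed, so every ratio is 0.
        return [0.0] * n_experts
    ideal_load = total_tokens / n_experts
    return [load / ideal_load for load in expert_load]


# 4 experts, 16 routed tokens -> ideal_load = 4.0
print(utilization_ratios([8, 4, 4, 0]))  # [2.0, 1.0, 1.0, 0.0]
```

In this example, expert 0 is overloaded (2.0), experts 1 and 2 get exactly their fair share (1.0), and expert 3 is dead (0.0).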

Expert diversity measures how many experts are meaningfully used:

  • dead_expert_frac: fraction of experts receiving zero tokens.
  • expert_diversity: exp(H) / N, where H is the Shannon entropy of the routing distribution and N is the number of experts (1.0 = all experts equally used, 1/N = all tokens routed to a single expert, 0 = no tokens routed).

Modes:

  • brief: Aggregated scalars (mean/median/min/max of CV (coefficient of variation), p25/median/p75/min/max of expert utilization, and diversity metrics) plus the top-K highest and bottom-K lowest individual expert utilization ratios.
  • detailed: Everything in brief, plus per-layer breakdowns (moe/layer_{i}/cv, moe/layer_{i}/utilization_mean, moe/layer_{i}/expert_diversity, etc.).

Module Contents

Functions

| Name | Description |
|---|---|
| _aggregate_diversity_metrics | Aggregate per-layer diversity metrics into model-level summaries. |
| _aggregate_stats | Compute mean, median, min, max for a list of per-layer values. |
| _compute_diversity_per_layer | Compute per-layer expert diversity metrics from utilization ratios. |
| _compute_expert_utilization | Compute top-K and bottom-K expert utilization ratios globally. |
| _compute_per_layer_stats | Compute per-layer CV, aux_loss and per-layer utilization ratios. |
| _compute_utilization_aggregates | Compute aggregate utilization stats across all experts globally. |
| collect_expert_loads | Collect the most recent expert load data from all Gate modules. |
| compute_brief_metrics | Compute brief load-balance metrics: aggregated scalars + top-K/bottom-K utilization. |
| compute_detailed_metrics | Compute detailed load-balance metrics: per-layer scalars + aggregates + utilization. |
| enable_load_balance_tracking | Enable load balance tracking on all Gate modules in the model. |

API

nemo_automodel.components.moe.load_balance_metrics._aggregate_diversity_metrics(
per_layer_diversity: list[dict[str, float]],
per_layer: bool = False
) -> dict[str, float]

Aggregate per-layer diversity metrics into model-level summaries.

Parameters:

per_layer_diversity
list[dict[str, float]]

Output of _compute_diversity_per_layer().

per_layer
bool, defaults to False

If True, also emit per-layer keys (for detailed mode).

Returns: dict[str, float]

Dict with moe/dead_expert_frac_mean, moe/expert_diversity_mean,

nemo_automodel.components.moe.load_balance_metrics._aggregate_stats(
values: list[float],
prefix: str
) -> dict[str, float]

Compute mean, median, min, max for a list of per-layer values.

Parameters:

values
list[float]

List of per-layer scalar values.

prefix
str

Key prefix, e.g. "moe/cv".

Returns: dict[str, float]

Dict like &#123;"moe/cv_mean": .., "moe/cv_median": .., "moe/cv_min": .., "moe/cv_max": ..&#125;.
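Below is a sketch of the same aggregation using Python's statistics module. The function name aggregate_stats is hypothetical. Note that the real helper may compute these values with torch. For even-length input, torch.median returns the lower of the two middle values, whereas statistics.median averages them.

```python
import statistics


def aggregate_stats(values: list[float], prefix: str) -> dict[str, float]:
    """Summarize per-layer scalars under keys built from `prefix`."""
    return {
        f"{prefix}_mean": statistics.fmean(values),
        f"{prefix}_median": statistics.median(values),  # averages the middle pair
        f"{prefix}_min": min(values),
        f"{prefix}_max": max(values),
    }


# Four MoE layers' CV values
print(aggregate_stats([0.25, 0.5, 0.75, 1.0], "moe/cv"))
# {'moe/cv_mean': 0.625, 'moe/cv_median': 0.625, 'moe/cv_min': 0.25, 'moe/cv_max': 1.0}
```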

nemo_automodel.components.moe.load_balance_metrics._compute_diversity_per_layer(
per_layer_utilizations: list[torch.Tensor]
) -> list[dict[str, float]]

Compute per-layer expert diversity metrics from utilization ratios.

For each layer, computes:

  • dead_expert_frac: fraction of experts with zero load.
  • expert_diversity: exp(H) / N where H is Shannon entropy of the token routing distribution. 1.0 = perfectly uniform, 1/N = all tokens routed to a single expert, 0 = all-zero load.
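The two metrics can be sketched in plain Python as follows. This sketch normalizes one layer's utilization ratios into a probability distribution, which equals load / total_load because every ratio shares the same ideal_load denominator. It then uses the natural-log Shannon entropy. The function name diversity_metrics and the list-based input are assumptions; the documented helper takes Tensor[n_experts].

```python
import math


def diversity_metrics(utilization: list[float]) -> dict[str, float]:
    """dead_expert_frac and exp(H) / N for one layer's utilization ratios."""
    n = len(utilization)
    dead_expert_frac = sum(1 for u in utilization if u == 0) / n
    total = sum(utilization)
    if total == 0:
        # All-zero load: diversity is defined as 0.
        return {"dead_expert_frac": dead_expert_frac, "expert_diversity": 0.0}
    probs = [u / total for u in utilization]
    # Shannon entropy; zero-probability experts contribute nothing.
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return {"dead_expert_frac": dead_expert_frac, "expert_diversity": math.exp(entropy) / n}


# Collapsed routing: one expert takes everything -> diversity = 1/N
print(diversity_metrics([4.0, 0.0, 0.0, 0.0]))
# {'dead_expert_frac': 0.75, 'expert_diversity': 0.25}
```

exp(H) is the "effective number of experts". Dividing it by N maps it onto the 1/N..1.0 range described above.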

Parameters:

per_layer_utilizations
list[torch.Tensor]

List of Tensor[n_experts] with utilization ratios (1.0 = ideal load).

Returns: list[dict[str, float]]

List of dicts (one per layer) with the metrics above.

nemo_automodel.components.moe.load_balance_metrics._compute_expert_utilization(
per_layer_utilizations: list[torch.Tensor],
top_k: int = 5
) -> dict[str, float]

Compute top-K and bottom-K expert utilization ratios globally.

Flattens utilization across all layers and experts, then emits only the top_k highest and top_k lowest entries. This keeps the total number of wandb keys to at most 2 * top_k regardless of model size.

Values are ratios relative to ideal load: 1.0 = perfect balance, >1 = overloaded, <1 = underloaded, 0 = dead expert.

All keys share the moe_expert_utilization/ prefix so wandb renders them on a single chart.

Parameters:

per_layer_utilizations
list[torch.Tensor]

List of Tensor[n_experts] utilization ratio per layer.

top_k
int, defaults to 5

Number of top (highest) and bottom (lowest) experts to emit.

Returns: dict[str, float]

Dict like &#123;"moe_expert_utilization/layer_0_expert_5": 1.23, ...&#125;.

nemo_automodel.components.moe.load_balance_metrics._compute_per_layer_stats(
layer_loads: dict[str, dict]
)

Compute per-layer CV, aux_loss and per-layer utilization ratios.
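This page does not define CV. The sketch below assumes the usual coefficient of variation of a layer's expert loads, the population standard deviation divided by the mean, with 0.0 reported when the mean is zero. It emits keys in the moe/layer_{i}/cv form used in detailed mode. The function name per_layer_cv is hypothetical, and aux_loss is omitted because its source is not described here.

```python
import statistics


def per_layer_cv(per_layer_loads: list[list[float]]) -> dict[str, float]:
    """Coefficient of variation (pstdev / mean) of expert loads per layer."""
    metrics = {}
    for i, loads in enumerate(per_layer_loads):
        mean = statistics.fmean(loads)
        # Assumed convention: an idle layer has CV 0.0.
        cv = statistics.pstdev(loads) / mean if mean > 0 else 0.0
        metrics[f"moe/layer_{i}/cv"] = cv
    return metrics


# Layer 0 is perfectly balanced, layer 1 is skewed.
print(per_layer_cv([[4, 4, 4, 4], [8, 4, 4, 0]]))
# layer_0 -> 0.0, layer_1 -> about 0.707
```

A CV of 0 means every expert received the same load. Larger values indicate a more uneven split.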

Returns:

(per_layer_metrics, per_layer_utilizations) where:

nemo_automodel.components.moe.load_balance_metrics._compute_utilization_aggregates(
per_layer_utilizations: list[torch.Tensor],
per_layer: bool = False
) -> dict[str, float]

Compute aggregate utilization stats across all experts globally.

Parameters:

per_layer_utilizations
list[torch.Tensor]

List of Tensor[n_experts] utilization ratio per layer.

per_layer
bool, defaults to False

If True, also emit per-layer utilization means.

Returns: dict[str, float]

Dict with moe/expert_utilization_{p25,median,p75,min,max} and, when per_layer is True, per-layer utilization means.
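Here is a sketch of the global aggregates. It pools every expert's ratio from every layer and takes quartiles with linear interpolation: statistics.quantiles with method="inclusive", which matches the default of numpy.percentile and torch.quantile. Whether the real helper uses this interpolation is an assumption. The function name utilization_aggregates is hypothetical, and the sketch needs at least two experts in total.

```python
import statistics


def utilization_aggregates(per_layer_utilizations: list[list[float]]) -> dict[str, float]:
    """Quartiles, min, and max of utilization ratios pooled over all layers."""
    flat = [u for layer in per_layer_utilizations for u in layer]
    # n=4 yields the three cut points p25, median, p75.
    p25, median, p75 = statistics.quantiles(flat, n=4, method="inclusive")
    return {
        "moe/expert_utilization_p25": p25,
        "moe/expert_utilization_median": median,
        "moe/expert_utilization_p75": p75,
        "moe/expert_utilization_min": min(flat),
        "moe/expert_utilization_max": max(flat),
    }


print(utilization_aggregates([[2.0, 1.0, 1.0, 0.0], [1.5, 0.5, 1.0, 1.0]]))
# p25 = 0.875, median = 1.0, p75 = 1.125, min = 0.0, max = 2.0
```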

nemo_automodel.components.moe.load_balance_metrics.collect_expert_loads(
model: torch.nn.Module,
dp_group: torch.distributed.ProcessGroup | None = None
) -> dict[str, dict]

Collect the most recent expert load data from all Gate modules.

When dp_group is provided, expert loads are all-reduced across the data-parallel group so the metrics reflect global token routing rather than a single rank’s view. This is important when DP > 1 or EP > 1 because each rank only routes its local shard of tokens through the (replicated) gate.
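The next example shows why the reduction matters, using a plain-Python simulation with two data-parallel ranks. The sketch assumes the loads are combined by summation. In a real run the equivalent step would presumably be torch.distributed.all_reduce(load, op=torch.distributed.ReduceOp.SUM, group=dp_group), but that is an assumption about this module's internals. The helper names sum_across_ranks and ratios are hypothetical.

```python
def sum_across_ranks(per_rank_loads: list[list[float]]) -> list[float]:
    """Element-wise sum: what a SUM all-reduce leaves on every rank."""
    return [sum(column) for column in zip(*per_rank_loads)]


def ratios(load: list[float]) -> list[float]:
    """Utilization ratios (load / ideal_load) for one layer."""
    ideal = sum(load) / len(load)
    return [x / ideal for x in load]


rank0 = [6, 2, 0, 0]  # rank 0's local shard only reaches experts 0 and 1
rank1 = [0, 0, 2, 6]  # rank 1's shard mostly reaches experts 2 and 3

print(ratios(rank0))                             # [3.0, 1.0, 0.0, 0.0]
print(ratios(sum_across_ranks([rank0, rank1])))  # [1.5, 0.5, 0.5, 1.5]
```

Rank 0's local view would report two dead experts, while the reduced view shows that every expert receives tokens.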

Parameters:

model
nn.Module

The model (or model part) to collect from.

dp_group
torch.distributed.ProcessGroup | None, defaults to None

Optional DP (or DP+CP) process group for all-reducing expert loads. Pass None to skip reduction (rank-local view).

Returns: dict[str, dict]

Dictionary mapping layer names to dicts with keys:

nemo_automodel.components.moe.load_balance_metrics.compute_brief_metrics(
layer_loads: dict[str, dict],
top_k: int = 5
) -> dict[str, float]

Compute brief load-balance metrics: aggregated scalars + top-K/bottom-K utilization.

Metrics produced:

  • moe/cv_{mean,median,min,max}: CV aggregated across all MoE layers.
  • moe/expert_utilization_{p25,median,p75,min,max}: utilization ratio stats across all experts globally (1.0 = ideal).
  • moe/aux_loss_mean: aux loss averaged across layers (when available).
  • moe_expert_utilization/layer_{i}_expert_{j}: the top-K highest and bottom-K lowest utilization experts globally.

Parameters:

layer_loads
dict[str, dict]

Output of collect_expert_loads().

top_k
int, defaults to 5

Number of top/bottom experts to emit globally.

Returns: dict[str, float]

Flat dictionary suitable for wandb.log().

nemo_automodel.components.moe.load_balance_metrics.compute_detailed_metrics(
layer_loads: dict[str, dict],
top_k: int = 5
) -> dict[str, float]

Compute detailed load-balance metrics: per-layer scalars + aggregates + utilization.

Includes everything from compute_brief_metrics() plus per-layer breakdowns:

  • moe/layer_{i}/cv, moe/layer_{i}/aux_loss
  • moe/layer_{i}/utilization_mean: per-layer mean utilization.

Parameters:

layer_loads
dict[str, dict]

Output of collect_expert_loads().

top_k
int, defaults to 5

Number of top/bottom experts to emit globally.

Returns: dict[str, float]

Flat dictionary suitable for wandb.log().

nemo_automodel.components.moe.load_balance_metrics.enable_load_balance_tracking(
model: torch.nn.Module
) -> None

Enable load balance tracking on all Gate modules in the model.

Sets _track_load_balance = True on every Gate instance found via model.modules(). Each Gate then stores its most recent expert_load tensor after every forward pass, with negligible overhead (a single .detach() per layer).

Parameters:

model
nn.Module

The model (or model part) to enable tracking on.
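This sketch illustrates the opt-in flag pattern with dependency-free stand-ins. Gate and TinyModel are hypothetical substitutes for the real Gate and torch.nn.Module. TinyModel.modules() imitates nn.Module.modules() by yielding the module itself and then its submodules. The attribute name last_expert_load is invented for illustration. To keep the example short, routing decisions are passed in directly instead of being computed from hidden states.

```python
class Gate:
    """Hypothetical stand-in for the real Gate module."""

    def __init__(self, n_experts: int):
        self.n_experts = n_experts
        self._track_load_balance = False  # off by default
        self.last_expert_load = None      # hypothetical attribute name

    def forward(self, expert_ids: list[int]) -> list[int]:
        if self._track_load_balance:
            # Record how many tokens each expert received in this pass.
            load = [0] * self.n_experts
            for e in expert_ids:
                load[e] += 1
            self.last_expert_load = load
        return expert_ids


class TinyModel:
    """Hypothetical container mimicking nn.Module.modules()."""

    def __init__(self, gates: list[Gate]):
        self.gates = gates

    def modules(self):
        yield self
        yield from self.gates


def enable_tracking(model) -> None:
    """Flip the flag on every Gate reachable through model.modules()."""
    for module in model.modules():
        if isinstance(module, Gate):
            module._track_load_balance = True


model = TinyModel([Gate(4), Gate(4)])
enable_tracking(model)
model.gates[0].forward([0, 0, 1, 3])
print(model.gates[0].last_expert_load)  # [2, 1, 0, 1]
```

Keeping tracking behind a per-module flag means gates that were never enabled do no extra bookkeeping.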