# nemo_automodel.components.moe.load_balance_metrics

MoE load balance metrics utilities.

Provides functions to enable load balance tracking on Gate modules,
collect per-layer expert load data, and compute brief/detailed metrics
suitable for wandb logging.

Expert utilization is a ratio of `current_load / ideal_load` where
`ideal_load = total_tokens / n_experts`.  A value of 1.0 means the
expert receives exactly its fair share; >1 = overloaded, \<1 = underloaded,
0 = dead expert.
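The ratio above can be illustrated with a minimal, dependency-free sketch. The function name `utilization_ratios` and the handling of an all-zero load are assumptions made for illustration, not part of this module.

```python
def utilization_ratios(expert_loads: list[float]) -> list[float]:
    """Return current_load / ideal_load for each expert (hypothetical helper)."""
    n_experts = len(expert_loads)
    total_tokens = sum(expert_loads)
    if n_experts == 0 or total_tokens == 0:
        # Assumption: with no routed tokens, report every expert as dead.
        return [0.0] * n_experts
    ideal_load = total_tokens / n_experts
    return [load / ideal_load for load in expert_loads]


# Four experts share 100 routed tokens, so the ideal load is 25 tokens each.
ratios = utilization_ratios([50, 25, 25, 0])
print(ratios)  # [2.0, 1.0, 1.0, 0.0]
```

Here the first expert is overloaded (2.0), the middle two receive exactly their fair share (1.0), and the last is a dead expert (0.0).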

Expert diversity measures how many experts are meaningfully used:

* `dead_expert_frac`: fraction of experts receiving zero tokens.
* `expert_diversity`: `exp(H) / N` where `H` is Shannon entropy of the
  routing distribution and `N` is the number of experts (1.0 = all experts
  equally used, 1/N = all tokens routed to a single expert, 0 = all-zero load).
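As a rough illustration of the two diversity metrics, the sketch below computes them from a list of per-expert loads in plain Python. The helper name `diversity_metrics` is hypothetical, and the use of the natural logarithm for `H` is inferred from the `exp(H)` form rather than confirmed by the source.

```python
import math


def diversity_metrics(expert_loads: list[float]) -> dict[str, float]:
    """Compute dead_expert_frac and expert_diversity for one layer (sketch)."""
    n = len(expert_loads)
    if n == 0:
        raise ValueError("expert_loads must not be empty")
    total = sum(expert_loads)
    dead_frac = sum(1 for load in expert_loads if load == 0) / n
    if total == 0:
        # All-zero load: no routing distribution exists, so diversity is 0.
        return {"dead_expert_frac": dead_frac, "expert_diversity": 0.0}
    # Normalize loads into a routing distribution; skip zeros since 0*log(0) -> 0.
    probs = [load / total for load in expert_loads if load > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return {"dead_expert_frac": dead_frac, "expert_diversity": math.exp(entropy) / n}


uniform = diversity_metrics([10, 10, 10, 10])   # diversity close to 1.0
collapsed = diversity_metrics([40, 0, 0, 0])    # diversity 1/N = 0.25, 75% dead
```

Because the distribution is normalized first, passing raw token counts or utilization ratios gives the same result.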

Modes:

* **brief**: Aggregated scalars (mean/median/min/max of cv and expert
  utilization, diversity metrics) plus top-K/bottom-K individual expert
  utilization ratios.
* **detailed**: Everything in brief, plus per-layer breakdowns
  (`moe/layer_{i}/cv`, `moe/layer_{i}/utilization_mean`,
  `moe/layer_{i}/expert_diversity`, etc.).

## Module Contents

### Functions

| Name                                                                                                                     | Description                                                                          |
| ------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------ |
| [`_aggregate_diversity_metrics`](#nemo_automodel-components-moe-load_balance_metrics-_aggregate_diversity_metrics)       | Aggregate per-layer diversity metrics into model-level summaries.                    |
| [`_aggregate_stats`](#nemo_automodel-components-moe-load_balance_metrics-_aggregate_stats)                               | Compute mean, median, min, max for a list of per-layer values.                       |
| [`_compute_diversity_per_layer`](#nemo_automodel-components-moe-load_balance_metrics-_compute_diversity_per_layer)       | Compute per-layer expert diversity metrics from utilization ratios.                  |
| [`_compute_expert_utilization`](#nemo_automodel-components-moe-load_balance_metrics-_compute_expert_utilization)         | Compute top-K and bottom-K expert utilization ratios globally.                       |
| [`_compute_per_layer_stats`](#nemo_automodel-components-moe-load_balance_metrics-_compute_per_layer_stats)               | Compute per-layer CV, aux\_loss and per-layer utilization ratios.                    |
| [`_compute_utilization_aggregates`](#nemo_automodel-components-moe-load_balance_metrics-_compute_utilization_aggregates) | Compute aggregate utilization stats across all experts globally.                     |
| [`collect_expert_loads`](#nemo_automodel-components-moe-load_balance_metrics-collect_expert_loads)                       | Collect the most recent expert load data from all Gate modules.                      |
| [`compute_brief_metrics`](#nemo_automodel-components-moe-load_balance_metrics-compute_brief_metrics)                     | Compute brief load-balance metrics: aggregated scalars + top-K/bottom-K utilization. |
| [`compute_detailed_metrics`](#nemo_automodel-components-moe-load_balance_metrics-compute_detailed_metrics)               | Compute detailed load-balance metrics: per-layer scalars + aggregates + utilization. |
| [`enable_load_balance_tracking`](#nemo_automodel-components-moe-load_balance_metrics-enable_load_balance_tracking)       | Enable load balance tracking on all Gate modules in the model.                       |

### API

```python
nemo_automodel.components.moe.load_balance_metrics._aggregate_diversity_metrics(
    per_layer_diversity: list[dict[str, float]],
    per_layer: bool = False
) -> dict[str, float]
```

Aggregate per-layer diversity metrics into model-level summaries.

**Parameters:**

* `per_layer_diversity`: Output of `_compute_diversity_per_layer`.
* `per_layer`: If `True`, also emit per-layer keys (for detailed mode).

**Returns:** `dict[str, float]`

Dict with `moe/dead_expert_frac_mean`, `moe/expert_diversity_mean`,

```python
nemo_automodel.components.moe.load_balance_metrics._aggregate_stats(
    values: list[float],
    prefix: str
) -> dict[str, float]
```

Compute mean, median, min, max for a list of per-layer values.

**Parameters:**

* `values`: List of per-layer scalar values.
* `prefix`: Key prefix, e.g. `"moe/cv"`.

**Returns:** `dict[str, float]`

Dict like `{"moe/cv_mean": .., "moe/cv_median": .., "moe/cv_min": .., "moe/cv_max": ..}`.
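The shape of the returned dictionary can be sketched with the standard library. This is an illustrative stand-in, not the module's implementation; the name `aggregate_stats` is hypothetical and the sketch assumes a non-empty list of values.

```python
import statistics


def aggregate_stats(values: list[float], prefix: str) -> dict[str, float]:
    """Summarize per-layer values under keys derived from `prefix` (sketch)."""
    return {
        f"{prefix}_mean": statistics.fmean(values),
        f"{prefix}_median": statistics.median(values),
        f"{prefix}_min": min(values),
        f"{prefix}_max": max(values),
    }


# Three MoE layers with per-layer CV values of 0.1, 0.3, and 0.2.
stats = aggregate_stats([0.1, 0.3, 0.2], "moe/cv")
print(sorted(stats))  # ['moe/cv_max', 'moe/cv_mean', 'moe/cv_median', 'moe/cv_min']
```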

```python
nemo_automodel.components.moe.load_balance_metrics._compute_diversity_per_layer(
    per_layer_utilizations: list[torch.Tensor]
) -> list[dict[str, float]]
```

Compute per-layer expert diversity metrics from utilization ratios.

For each layer, computes:

* `dead_expert_frac`: fraction of experts with zero load.
* `expert_diversity`: `exp(H) / N` where `H` is Shannon entropy of
  the token routing distribution.  1.0 = perfectly uniform, 1/N = all
  tokens routed to a single expert, 0 = all-zero load.

**Parameters:**

* `per_layer_utilizations`: List of `Tensor[n_experts]` with utilization
  ratios (1.0 = ideal load).

**Returns:** `list[dict[str, float]]`

List of dicts (one per layer) with the metrics above.

```python
nemo_automodel.components.moe.load_balance_metrics._compute_expert_utilization(
    per_layer_utilizations: list[torch.Tensor],
    top_k: int = 5
) -> dict[str, float]
```

Compute top-K and bottom-K expert utilization ratios globally.

Flattens utilization across all layers and experts, then emits only the
`top_k` highest and `top_k` lowest entries.  This keeps the total
number of wandb keys to at most `2 * top_k` regardless of model size.

Values are ratios relative to ideal load: 1.0 = perfect balance,
\>1 = overloaded, \<1 = underloaded, 0 = dead expert.

All keys share the `moe_expert_utilization/` prefix so wandb
renders them on a single chart.

**Parameters:**

* `per_layer_utilizations`: List of `Tensor[n_experts]` utilization ratio per layer.
* `top_k`: Number of top (highest) and bottom (lowest) experts to emit.

**Returns:** `dict[str, float]`

Dict like `{"moe_expert_utilization/layer_0_expert_5": 1.23, ...}`.
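The flatten-then-select step described above might look like the following sketch, which uses plain lists instead of tensors. The helper name `top_bottom_utilization`, the use of positional layer indices in the keys, and the tie-breaking order are assumptions for illustration.

```python
def top_bottom_utilization(
    per_layer_utilizations: list[list[float]], top_k: int = 5
) -> dict[str, float]:
    """Keep only the top_k highest and top_k lowest ratios across all layers."""
    # Flatten to (ratio, layer index, expert index) so origin survives sorting.
    entries = [
        (ratio, layer_idx, expert_idx)
        for layer_idx, layer in enumerate(per_layer_utilizations)
        for expert_idx, ratio in enumerate(layer)
    ]
    entries.sort(key=lambda entry: entry[0])
    k = min(top_k, len(entries))
    # Lowest k first, then highest k; duplicates collapse on the dict key.
    selected = entries[:k] + entries[len(entries) - k:]
    return {
        f"moe_expert_utilization/layer_{layer}_expert_{expert}": ratio
        for ratio, layer, expert in selected
    }


layers = [[2.0, 1.0, 1.0, 0.0], [0.5, 1.5, 1.0, 1.0]]
print(top_bottom_utilization(layers, top_k=1))
# {'moe_expert_utilization/layer_0_expert_3': 0.0, 'moe_expert_utilization/layer_0_expert_0': 2.0}
```

When `2 * top_k` exceeds the number of experts, the two selections overlap and the result simply contains every expert once.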

```python
nemo_automodel.components.moe.load_balance_metrics._compute_per_layer_stats(
    layer_loads: dict[str, dict]
)
```

Compute per-layer CV, aux\_loss and per-layer utilization ratios.
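This page does not spell out what CV stands for. Assuming it denotes the coefficient of variation (standard deviation divided by mean) of a layer's expert loads, a minimal sketch is shown below. The helper name `load_cv`, the choice of population standard deviation, and the zero-mean fallback are all assumptions; the sketch expects a non-empty list.

```python
import statistics


def load_cv(expert_loads: list[float]) -> float:
    """Coefficient of variation of one layer's expert loads (assumed definition)."""
    mean = statistics.fmean(expert_loads)
    if mean == 0:
        # Assumption: treat an all-zero layer as having no measurable imbalance.
        return 0.0
    return statistics.pstdev(expert_loads) / mean


balanced = load_cv([25, 25, 25, 25])  # 0.0: every expert gets the same load
skewed = load_cv([50, 25, 25, 0])     # larger value: load is unevenly spread
```

Under this definition, lower values indicate better balance, with 0 meaning perfectly uniform load.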

**Returns:**

(per\_layer\_metrics, per\_layer\_utilizations) where:

```python
nemo_automodel.components.moe.load_balance_metrics._compute_utilization_aggregates(
    per_layer_utilizations: list[torch.Tensor],
    per_layer: bool = False
) -> dict[str, float]
```

Compute aggregate utilization stats across all experts globally.

**Parameters:**

* `per_layer_utilizations`: List of `Tensor[n_experts]` utilization ratio per layer.
* `per_layer`: If `True`, also emit per-layer utilization means.

**Returns:** `dict[str, float]`

Dict with `moe/expert_utilization_{p25,median,p75,min,max}` and optionally

```python
nemo_automodel.components.moe.load_balance_metrics.collect_expert_loads(
    model: torch.nn.Module,
    dp_group: torch.distributed.ProcessGroup | None = None
) -> dict[str, dict]
```

Collect the most recent expert load data from all Gate modules.

When `dp_group` is provided, expert loads are all-reduced across the
data-parallel group so the metrics reflect global token routing rather
than a single rank's view.  This is important when DP > 1 or EP > 1
because each rank only routes its local shard of tokens through the
(replicated) gate.

**Parameters:**

* `model`: The model (or model part) to collect from.
* `dp_group`: Optional DP (or DP+CP) process group for all-reducing
  expert loads.  Pass `None` to skip reduction (rank-local view).

**Returns:** `dict[str, dict]`

Dictionary mapping layer names to dicts with keys:

```python
nemo_automodel.components.moe.load_balance_metrics.compute_brief_metrics(
    layer_loads: dict[str, dict],
    top_k: int = 5
) -> dict[str, float]
```

Compute brief load-balance metrics: aggregated scalars + top-K/bottom-K utilization.

Metrics produced:

* `moe/cv_{mean,median,min,max}`: CV aggregated across all MoE layers.
* `moe/expert_utilization_{p25,median,p75,min,max}`: utilization ratio stats
  across all experts globally (1.0 = ideal).
* `moe/aux_loss_mean`: aux loss averaged across layers (when available).
* `moe_expert_utilization/layer_{i}_expert_{j}`: top-K highest and bottom-K
  lowest utilization experts globally.

**Parameters:**

* `layer_loads`: Output of `collect_expert_loads`.
* `top_k`: Number of top/bottom experts to emit globally.

**Returns:** `dict[str, float]`

Flat dictionary suitable for `wandb.log()`.

```python
nemo_automodel.components.moe.load_balance_metrics.compute_detailed_metrics(
    layer_loads: dict[str, dict],
    top_k: int = 5
) -> dict[str, float]
```

Compute detailed load-balance metrics: per-layer scalars + aggregates + utilization.

Includes everything from `compute_brief_metrics` plus per-layer
breakdowns:

* `moe/layer_{i}/cv`, `moe/layer_{i}/aux_loss`
* `moe/layer_{i}/utilization_mean`: per-layer mean utilization.

**Parameters:**

* `layer_loads`: Output of `collect_expert_loads`.
* `top_k`: Number of top/bottom experts to emit globally.

**Returns:** `dict[str, float]`

Flat dictionary suitable for `wandb.log()`.

```python
nemo_automodel.components.moe.load_balance_metrics.enable_load_balance_tracking(
    model: torch.nn.Module
) -> None
```

Enable load balance tracking on all Gate modules in the model.

Sets `_track_load_balance = True` on every Gate instance found via
`model.modules()`.  This causes each Gate to store its most recent
`expert_load` tensor after every forward pass with negligible overhead
(one `.detach()` copy per layer).

**Parameters:**

* `model`: The model (or model part) to enable tracking on.
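Putting the public functions together, one plausible wiring is sketched below, based only on the signatures documented on this page. The wrapper `log_moe_load_balance` and its `log_fn` and `detailed` parameters are hypothetical; the module import is deferred so the sketch can be defined without the package installed.

```python
from typing import Any, Callable


def log_moe_load_balance(
    model: Any,
    log_fn: Callable[[dict], Any],
    dp_group: Any = None,
    detailed: bool = False,
    top_k: int = 5,
) -> dict:
    """Collect expert loads after a forward pass and log the metrics (sketch)."""
    # Deferred import: requires nemo_automodel only when the function is called.
    from nemo_automodel.components.moe import load_balance_metrics as lbm

    layer_loads = lbm.collect_expert_loads(model, dp_group=dp_group)
    compute = lbm.compute_detailed_metrics if detailed else lbm.compute_brief_metrics
    metrics = compute(layer_loads, top_k=top_k)
    log_fn(metrics)  # e.g. wandb.log
    return metrics


# Assumed usage inside a training loop:
#   lbm.enable_load_balance_tracking(model)   # once, before the first forward pass
#   ... run a forward pass ...
#   log_moe_load_balance(model, wandb.log)    # rank-local view when dp_group is None
```

Tracking has to be enabled before the forward pass whose routing should be measured, since each Gate stores only its most recent `expert_load`.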