nemo_automodel.components.moe.load_balance_metrics

MoE load balance metrics utilities.

Provides functions to enable load balance tracking on Gate modules, collect per-layer expert load data, and compute brief/detailed metrics suitable for wandb logging.

Expert utilization is the ratio current_load / ideal_load, where ideal_load = total_tokens / n_experts. A value of 1.0 means the expert receives exactly its fair share; >1 means overloaded, <1 underloaded, and 0 a dead expert that received no tokens.
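
The following is a small worked example of this ratio. It is a pure-Python sketch that uses plain lists instead of the module's tensors. It assumes total_tokens is the sum of one layer's expert_load. It also reports every expert as 0.0 when a layer routed no tokens; that behavior is a choice made for this sketch, not documented behavior.

```python
def utilization_ratios(expert_load: list[float]) -> list[float]:
    """Return current_load / ideal_load for every expert in one layer.

    Sketch assumption: total_tokens == sum(expert_load).
    """
    n_experts = len(expert_load)
    total_tokens = sum(expert_load)
    if total_tokens == 0:
        # Sketch choice: nothing was routed, so report every expert as idle.
        return [0.0] * n_experts
    ideal_load = total_tokens / n_experts
    return [load / ideal_load for load in expert_load]


# 80 tokens over 4 experts -> ideal_load = 20
print(utilization_ratios([30, 10, 0, 40]))  # [1.5, 0.5, 0.0, 2.0]
```

In this example, expert 3 is overloaded (2.0), expert 1 is underloaded (0.5) and expert 2 is dead (0.0).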

Modes:

  • brief: Aggregated scalars (mean/median/min/max of CV across layers, p25/median/p75/min/max of expert utilization across all experts, and mean aux loss when available) plus the top-K/bottom-K individual expert utilization ratios.

  • detailed: Everything in brief, plus per-layer breakdowns (moe/layer_{i}/cv, moe/layer_{i}/utilization_mean, etc.).

Module Contents

Functions

enable_load_balance_tracking

Enable load balance tracking on all Gate modules in the model.

collect_expert_loads

Collect the most recent expert load data from all Gate modules.

_compute_per_layer_stats

Compute per-layer CV, aux_loss and per-layer utilization ratios.

_aggregate_stats

Compute mean, median, min, max for a list of per-layer values.

_compute_expert_utilization

Compute top-K and bottom-K expert utilization ratios globally.

_compute_utilization_aggregates

Compute aggregate utilization stats across all experts globally.

compute_brief_metrics

Compute brief load-balance metrics: aggregated scalars + top-K/bottom-K utilization.

compute_detailed_metrics

Compute detailed load-balance metrics: per-layer scalars + aggregates + utilization.

API

nemo_automodel.components.moe.load_balance_metrics.enable_load_balance_tracking(model: torch.nn.Module) -> None

Enable load balance tracking on all Gate modules in the model.

Sets _track_load_balance = True on every Gate instance found via model.modules(). This causes each Gate to store its most recent expert_load tensor after every forward pass with negligible overhead (one .detach() copy per layer).

Parameters:

model – The model (or model part) to enable tracking on.
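
The tracking switch follows a common pattern:

1. Walk model.modules().
2. Flip a flag on every router.
3. Have each router keep a detached snapshot of its load on every forward pass.

The sketch below mimics that pattern without torch. Module, Gate, forward, last_expert_load and enable_tracking_sketch are hypothetical stand-ins; only the _track_load_balance flag name comes from the description above.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module with a modules() walker."""

    def __init__(self, *children: "Module") -> None:
        self._children = list(children)

    def modules(self):
        # Like nn.Module.modules(): yield self first, then every descendant.
        yield self
        for child in self._children:
            yield from child.modules()


class Gate(Module):
    """Hypothetical router that counts how many token slots pick each expert."""

    def __init__(self, n_experts: int) -> None:
        super().__init__()
        self.n_experts = n_experts
        self._track_load_balance = False
        self.last_expert_load = None  # hypothetical attribute name

    def forward(self, expert_ids: list[int]) -> list[int]:
        load = [0] * self.n_experts
        for e in expert_ids:
            load[e] += 1
        if self._track_load_balance:
            # Stand-in for storing expert_load.detach(): keep a private copy.
            self.last_expert_load = list(load)
        return load


def enable_tracking_sketch(model: Module) -> None:
    """Set _track_load_balance = True on every Gate found via model.modules()."""
    for module in model.modules():
        if isinstance(module, Gate):
            module._track_load_balance = True


model = Module(Module(Gate(4)), Module(Gate(4)))
enable_tracking_sketch(model)
gates = [m for m in model.modules() if isinstance(m, Gate)]
gates[0].forward([0, 0, 3, 1])
print(gates[0].last_expert_load)  # [2, 1, 0, 1]
```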

nemo_automodel.components.moe.load_balance_metrics.collect_expert_loads(
model: torch.nn.Module,
dp_group: torch.distributed.ProcessGroup | None = None,
) -> dict[str, dict]

Collect the most recent expert load data from all Gate modules.

When dp_group is provided, expert loads are all-reduced across the data-parallel group so the metrics reflect global token routing rather than a single rank’s view. This is important when DP > 1 or EP > 1 because each rank only routes its local shard of tokens through the (replicated) gate.

Parameters:
  • model – The model (or model part) to collect from.

  • dp_group – Optional DP (or DP+CP) process group for all-reducing expert loads. Pass None to skip reduction (rank-local view).

Returns:

Dictionary mapping layer names to dicts with keys:

  • "expert_load": Tensor[n_experts] with token counts per expert.

  • "aux_loss": Optional[Tensor] scalar aux loss (if computed).

  • "n_experts": int number of routed experts.
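
To show why the DP all-reduce matters, the sketch below simulates it with plain lists. It assumes the reduction is an element-wise sum of each rank's per-expert token counts. In torch, this would typically be torch.distributed.all_reduce with ReduceOp.SUM over dp_group. The exact reduction the module performs is an assumption here, and simulated_all_reduce_sum is a hypothetical helper.

```python
def simulated_all_reduce_sum(per_rank_loads: list[list[int]]) -> list[list[int]]:
    """Every rank ends up holding the element-wise sum of all ranks' loads."""
    n_experts = len(per_rank_loads[0])
    total = [sum(rank_load[e] for rank_load in per_rank_loads) for e in range(n_experts)]
    return [list(total) for _ in per_rank_loads]


# Rank-local views: each rank appears to leave two experts completely idle.
rank_loads = [[4, 4, 0, 0], [0, 0, 4, 4]]
global_loads = simulated_all_reduce_sum(rank_loads)
print(global_loads[0])  # [4, 4, 4, 4]
```

Without the reduction, rank 0 would report experts 2 and 3 as dead, even though routing is perfectly balanced across the DP group.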

nemo_automodel.components.moe.load_balance_metrics._compute_per_layer_stats(layer_loads: dict[str, dict])

Compute per-layer CV, aux_loss and per-layer utilization ratios.

Returns:

(per_layer_metrics, per_layer_utilizations), where:

  • per_layer_metrics: list of dicts with keys cv, aux_loss

  • per_layer_utilizations: list of Tensor[n_experts] with utilization ratio per layer (1.0 = ideal, >1 = overloaded, <1 = underloaded)
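
Here is a minimal sketch of the per-layer pass. It uses lists instead of tensors and the dict layout returned by collect_expert_loads. It rests on three assumptions, none confirmed by the docstring:

  • CV is the coefficient of variation of a layer's expert_load (population standard deviation divided by mean).

  • total_tokens is the sum of expert_load.

  • Layers are reported in the mapping's iteration order.

per_layer_stats_sketch is a hypothetical name.

```python
import statistics


def per_layer_stats_sketch(layer_loads: dict[str, dict]):
    per_layer_metrics, per_layer_utilizations = [], []
    for info in layer_loads.values():
        load = info["expert_load"]
        mean = statistics.fmean(load)
        # Coefficient of variation: 0.0 for a perfectly balanced (or empty) layer.
        cv = statistics.pstdev(load) / mean if mean else 0.0
        per_layer_metrics.append({"cv": cv, "aux_loss": info.get("aux_loss")})
        ideal = sum(load) / len(load)
        per_layer_utilizations.append([x / ideal if ideal else 0.0 for x in load])
    return per_layer_metrics, per_layer_utilizations


loads = {
    "layers.0.gate": {"expert_load": [30, 10, 0, 40], "aux_loss": None, "n_experts": 4},
    "layers.1.gate": {"expert_load": [5, 5, 5, 5], "aux_loss": None, "n_experts": 4},
}
metrics, utils = per_layer_stats_sketch(loads)
print(round(metrics[0]["cv"], 4), metrics[1]["cv"])  # 0.7906 0.0
print(utils[0])  # [1.5, 0.5, 0.0, 2.0]
```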

nemo_automodel.components.moe.load_balance_metrics._aggregate_stats(values: list[float], prefix: str) -> dict[str, float]

Compute mean, median, min, max for a list of per-layer values.

Parameters:
  • values – List of per-layer scalar values.

  • prefix – Key prefix, e.g. "moe/cv".

Returns:

Dict like {"moe/cv_mean": .., "moe/cv_median": .., "moe/cv_min": .., "moe/cv_max": ..}.
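
The aggregation itself is plain descriptive statistics. Below is a pure-Python sketch using the standard statistics module; aggregate_stats_sketch is a hypothetical name. It assumes at least one value, because statistics.fmean raises StatisticsError on an empty list.

```python
import statistics


def aggregate_stats_sketch(values: list[float], prefix: str) -> dict[str, float]:
    """Mean, median, min and max of per-layer values under '<prefix>_<stat>' keys."""
    return {
        f"{prefix}_mean": statistics.fmean(values),
        f"{prefix}_median": statistics.median(values),
        f"{prefix}_min": min(values),
        f"{prefix}_max": max(values),
    }


print(aggregate_stats_sketch([0.2, 0.5, 0.3], "moe/cv"))
# {'moe/cv_mean': 0.3333333333333333, 'moe/cv_median': 0.3, 'moe/cv_min': 0.2, 'moe/cv_max': 0.5}
```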

nemo_automodel.components.moe.load_balance_metrics._compute_expert_utilization(
per_layer_utilizations: list[torch.Tensor],
top_k: int = 5,
) -> dict[str, float]

Compute top-K and bottom-K expert utilization ratios globally.

Flattens utilization across all layers and experts, then emits only the top_k highest and top_k lowest entries. This keeps the total number of wandb keys to at most 2 * top_k regardless of model size.

Values are ratios relative to ideal load: 1.0 = perfect balance, >1 = overloaded, <1 = underloaded, 0 = dead expert.

All keys share the moe_expert_utilization/ prefix so wandb renders them on a single chart.

Parameters:
  • per_layer_utilizations – List of Tensor[n_experts] utilization ratio per layer.

  • top_k – Number of top (highest) and bottom (lowest) experts to emit.

Returns:

Dict like {"moe_expert_utilization/layer_0_expert_5": 1.23, ...}.
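
The selection can be sketched as follows: flatten every (layer, expert) ratio, sort once, and keep both ends. This is a pure-Python illustration with two assumptions:

  • Ties are broken by ratio, then layer index, then expert index.

  • Each layer's list position is used as its index.

expert_utilization_sketch is a hypothetical name. When the model has fewer than 2 * top_k experts in total, the two ends overlap and the dict simply deduplicates them.

```python
def expert_utilization_sketch(
    per_layer_utilizations: list[list[float]], top_k: int = 5
) -> dict[str, float]:
    # (ratio, layer index, expert index) for every expert in every layer.
    flat = sorted(
        (ratio, layer_idx, expert_idx)
        for layer_idx, ratios in enumerate(per_layer_utilizations)
        for expert_idx, ratio in enumerate(ratios)
    )
    k = min(top_k, len(flat))
    # Bottom-K lowest followed by top-K highest; overlapping entries collapse into one key.
    chosen = flat[:k] + flat[len(flat) - k:]
    return {
        f"moe_expert_utilization/layer_{layer}_expert_{expert}": ratio
        for ratio, layer, expert in chosen
    }


utils = [[1.5, 0.5, 0.0, 2.0], [1.0, 1.0, 1.0, 1.0]]
print(expert_utilization_sketch(utils, top_k=1))
# {'moe_expert_utilization/layer_0_expert_2': 0.0, 'moe_expert_utilization/layer_0_expert_3': 2.0}
```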

nemo_automodel.components.moe.load_balance_metrics._compute_utilization_aggregates(
per_layer_utilizations: list[torch.Tensor],
per_layer: bool = False,
) -> dict[str, float]

Compute aggregate utilization stats across all experts globally.

Parameters:
  • per_layer_utilizations – List of Tensor[n_experts] utilization ratio per layer.

  • per_layer – If True, also emit per-layer utilization means.

Returns:

Dict with moe/expert_utilization_{p25,median,p75,min,max} and optionally moe/layer_{i}/utilization_mean when per_layer=True.
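
Below is a pure-Python sketch of these aggregates; utilization_aggregates_sketch is a hypothetical name. The percentile convention is an assumption: the sketch uses statistics.quantiles with method="inclusive", i.e. linear interpolation between closest ranks. That method needs at least two experts in total. Layer indices come from list position.

```python
import statistics


def utilization_aggregates_sketch(
    per_layer_utilizations: list[list[float]], per_layer: bool = False
) -> dict[str, float]:
    flat = sorted(r for ratios in per_layer_utilizations for r in ratios)
    # Linear interpolation between closest ranks (requires at least two values).
    p25, median, p75 = statistics.quantiles(flat, n=4, method="inclusive")
    out = {
        "moe/expert_utilization_p25": p25,
        "moe/expert_utilization_median": median,
        "moe/expert_utilization_p75": p75,
        "moe/expert_utilization_min": flat[0],
        "moe/expert_utilization_max": flat[-1],
    }
    if per_layer:
        for i, ratios in enumerate(per_layer_utilizations):
            out[f"moe/layer_{i}/utilization_mean"] = statistics.fmean(ratios)
    return out


utils = [[1.5, 0.5, 0.0, 2.0], [1.0, 1.0, 1.0, 1.0]]
stats = utilization_aggregates_sketch(utils, per_layer=True)
print(stats["moe/expert_utilization_p25"], stats["moe/expert_utilization_p75"])  # 0.875 1.125
```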

nemo_automodel.components.moe.load_balance_metrics.compute_brief_metrics(
layer_loads: dict[str, dict],
top_k: int = 5,
) -> dict[str, float]

Compute brief load-balance metrics: aggregated scalars + top-K/bottom-K utilization.

Metrics produced:

  • moe/cv_{mean,median,min,max}: CV aggregated across all MoE layers.

  • moe/expert_utilization_{p25,median,p75,min,max}: utilization ratio stats across all experts globally (1.0 = ideal).

  • moe/aux_loss_mean: aux loss averaged across layers (when available).

  • moe_expert_utilization/layer_{i}_expert_{j}: top-K highest and bottom-K lowest utilization experts globally.

Parameters:
  • layer_loads – Output of collect_expert_loads().

  • top_k – Number of top/bottom experts to emit globally.

Returns:

Flat dictionary suitable for wandb.log().
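
One possible shape for the brief pipeline is sketched below in pure Python. It illustrates the documented key set, not the module's implementation. brief_metrics_sketch, _cv and _utilization are hypothetical names. The sketch assumes:

  • CV is the population standard deviation of expert_load divided by its mean.

  • total_tokens is the sum of expert_load.

  • Percentiles use linear interpolation.

  • List position is the layer index.

  • moe/aux_loss_mean averages only the layers whose aux_loss is not None.

```python
import statistics


def _cv(load: list[float]) -> float:
    mean = statistics.fmean(load)
    return statistics.pstdev(load) / mean if mean else 0.0


def _utilization(load: list[float]) -> list[float]:
    ideal = sum(load) / len(load)
    return [x / ideal if ideal else 0.0 for x in load]


def brief_metrics_sketch(layer_loads: dict[str, dict], top_k: int = 5) -> dict[str, float]:
    if not layer_loads:
        return {}
    infos = list(layer_loads.values())
    cvs = [_cv(info["expert_load"]) for info in infos]
    utils = [_utilization(info["expert_load"]) for info in infos]

    # moe/cv_{mean,median,min,max}
    out = {
        "moe/cv_mean": statistics.fmean(cvs),
        "moe/cv_median": statistics.median(cvs),
        "moe/cv_min": min(cvs),
        "moe/cv_max": max(cvs),
    }

    # moe/expert_utilization_{p25,median,p75,min,max} over all experts in all layers
    # (statistics.quantiles needs at least two experts in total).
    flat = sorted(r for layer in utils for r in layer)
    p25, median, p75 = statistics.quantiles(flat, n=4, method="inclusive")
    out.update({
        "moe/expert_utilization_p25": p25,
        "moe/expert_utilization_median": median,
        "moe/expert_utilization_p75": p75,
        "moe/expert_utilization_min": flat[0],
        "moe/expert_utilization_max": flat[-1],
    })

    # moe/aux_loss_mean only when at least one layer reported an aux loss.
    aux = [float(info["aux_loss"]) for info in infos if info.get("aux_loss") is not None]
    if aux:
        out["moe/aux_loss_mean"] = statistics.fmean(aux)

    # moe_expert_utilization/layer_{i}_expert_{j}: bottom-K and top-K globally.
    ranked = sorted((r, i, j) for i, layer in enumerate(utils) for j, r in enumerate(layer))
    k = min(top_k, len(ranked))
    for r, i, j in ranked[:k] + ranked[len(ranked) - k:]:
        out[f"moe_expert_utilization/layer_{i}_expert_{j}"] = r
    return out


layer_loads = {
    "layers.0.gate": {"expert_load": [30, 10, 0, 40], "aux_loss": 0.1, "n_experts": 4},
    "layers.1.gate": {"expert_load": [5, 5, 5, 5], "aux_loss": 0.3, "n_experts": 4},
}
metrics = brief_metrics_sketch(layer_loads, top_k=1)
print(metrics["moe/expert_utilization_median"], metrics["moe/cv_min"])  # 1.0 0.0
```

With top_k=1, the output has 12 keys regardless of how many experts each layer has: 4 CV aggregates, 5 utilization aggregates, 1 aux loss mean and 2 individual experts.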

nemo_automodel.components.moe.load_balance_metrics.compute_detailed_metrics(
layer_loads: dict[str, dict],
top_k: int = 5,
) -> dict[str, float]

Compute detailed load-balance metrics: per-layer scalars + aggregates + utilization.

Includes everything from compute_brief_metrics() plus per-layer breakdowns:

  • moe/layer_{i}/cv, moe/layer_{i}/aux_loss

  • moe/layer_{i}/utilization_mean: per-layer mean utilization.

Parameters:
  • layer_loads – Output of collect_expert_loads().

  • top_k – Number of top/bottom experts to emit globally.

Returns:

Flat dictionary suitable for wandb.log().
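
The extra per-layer keys can be sketched as a separate pass; the detailed result would then be the brief dictionary updated with its output. This is a pure-Python illustration, and per_layer_breakdown_sketch is a hypothetical name. It assumes:

  • CV is the population standard deviation of expert_load divided by its mean.

  • total_tokens is the sum of expert_load.

  • List position is the layer index.

  • moe/layer_{i}/aux_loss is emitted only when that layer reported an aux loss.

```python
import statistics


def per_layer_breakdown_sketch(layer_loads: dict[str, dict]) -> dict[str, float]:
    out = {}
    for i, info in enumerate(layer_loads.values()):
        load = info["expert_load"]
        mean = statistics.fmean(load)
        # moe/layer_{i}/cv: coefficient of variation of this layer's expert load.
        out[f"moe/layer_{i}/cv"] = statistics.pstdev(load) / mean if mean else 0.0
        if info.get("aux_loss") is not None:
            out[f"moe/layer_{i}/aux_loss"] = float(info["aux_loss"])
        # moe/layer_{i}/utilization_mean: mean of load / ideal_load over the layer's experts.
        ideal = sum(load) / len(load)
        ratios = [x / ideal if ideal else 0.0 for x in load]
        out[f"moe/layer_{i}/utilization_mean"] = statistics.fmean(ratios)
    return out


layer_loads = {
    "layers.0.gate": {"expert_load": [30, 10, 0, 40], "aux_loss": 0.1, "n_experts": 4},
    "layers.1.gate": {"expert_load": [0, 0, 0, 0], "aux_loss": None, "n_experts": 4},
}
out = per_layer_breakdown_sketch(layer_loads)
print(out["moe/layer_0/utilization_mean"])  # 1.0
```

In this sketch, moe/layer_{i}/utilization_mean equals 1.0 (up to floating-point rounding) for every layer that routed any tokens. This follows from the definition: each ratio is divided by the layer's own average load.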