core.transformer.moe.moe_utils#

Module Contents#

Classes#

MoEAuxLossAutoScaler

An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss.

RandomSTE

Straight-Through Estimator (STE) function that returns random values, using a different seed for each rank.

RouterGatingLinearFunction

Autograd function for router gating linear.

Functions#

switch_load_balancing_loss_func

Calculate the auxiliary loss for load balancing. Refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) and the Global Load Balancing Loss paper (https://arxiv.org/abs/2501.11873) for details.

z_loss_func

Encourages the router’s logits to remain small to enhance stability. Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

sinkhorn

Sinkhorn-based MoE routing function.

get_capacity

Calculate the capacity of each expert.

permute

Permute the tokens and probs based on the mask. Tokens with the same designated expert are grouped together. The mask has shape [num_tokens, num_experts] and indicates which experts were selected by each token.

unpermute

Restore the original order of tokens after permutation. If probs are provided, they are also applied to the tokens before the order is restored.

sort_chunks_by_idxs

Split and sort the input tensor based on the split_sizes and sorted indices.

group_limited_topk

Perform top-k routing on a subset of expert groups.

pad_routing_map

Pad the routing map to ensure each expert has a multiple of pad_multiple tokens.

topk_routing_with_score_function

Compute the routing probabilities and map for top-k selection with score function.

compute_routing_scores_for_aux_loss

Compute routing scores based on the score function.

apply_router_token_dropping

Apply token dropping to top-k expert selection.

save_to_aux_losses_tracker

Save the auxiliary loss for logging.

clear_aux_losses_tracker

Clear the auxiliary losses.

reduce_aux_losses_tracker_across_ranks

Collect and reduce the auxiliary losses across ranks.

track_moe_metrics

Track the MoE metrics for logging.

get_updated_expert_bias

Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1 for details.

maybe_move_tensor_to_cpu

Move a tensor to CPU if it is on GPU.

get_moe_layer_wise_logging_tracker

Return the MoE layer-wise logging tracker.

apply_random_logits

Apply the RandomSTE function to the logits.

router_gating_linear

Customized linear layer for router gating. This linear layer accepts bfloat16 input and weight and can return output in router_dtype. It reduces memory usage by avoiding saving intermediate high-precision tensors.

get_align_size_for_quantization

Get the alignment size for quantization.

get_default_pg_collection

Get the default process groups for MoE.

Data#

API#

core.transformer.moe.moe_utils._MOE_LAYER_WISE_LOGGING_TRACKER#

None

core.transformer.moe.moe_utils.switch_load_balancing_loss_func(
probs: torch.Tensor,
tokens_per_expert: torch.Tensor,
total_num_tokens: int,
topk: int,
num_experts: int,
moe_aux_loss_coeff: float,
fused: bool = False,
)#

Calculate the auxiliary loss for load balancing. Refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) and the Global Load Balancing Loss paper (https://arxiv.org/abs/2501.11873) for details.

Detailed explanation of the auxiliary loss#

The formula for the auxiliary loss is:

    loss = E * Σ_{i=1}^{E} (f_i * P_i)

where:

  • f_i = 1 / (T * topk) * Σ_{x∈B} routing_map(x, i) (the fraction of tokens dispatched to expert i)

  • P_i = 1 / T * Σ_{x∈B} probs(x, i) (the averaged router probability allocated to expert i)

  • E is the number of experts

  • T is the total number of tokens in the batch B

For distributed training with sequence or context parallelism, each rank can process a subset of the batch, and the loss decomposes across ranks:

    loss = E * Σ_{i=1}^{E} (f_i * Σ_{j=1}^{N} P_ij)
         = E * Σ_{i=1}^{E} Σ_{j=1}^{N} (f_i * P_ij)
         = Σ_{j=1}^{N} E * (Σ_{i=1}^{E} f_i * P_ij)

where:

  • f_i = 1 / (T * topk) * Σ_{x∈B} routing_map(x, i) (the fraction of tokens dispatched to expert i in the global batch)

  • P_ij = 1 / T * Σ_{x∈B_j} probs(x, i) (the averaged router probability allocated to expert i in the local batch of the j-th rank)

  • N is the number of ranks

  • B_j is the batch of tokens on the j-th rank

  • T is the total number of tokens in the global batch B

Note: To calculate the auxiliary loss at different levels (micro-batch or global batch):

  • probs: Should always be from the local batch being processed

  • tokens_per_expert: Should represent token counts at the desired level (either micro-batch or global batch)

  • total_num_tokens: Should match the total token count at the same level as tokens_per_expert
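
The micro-batch form of this loss can be written directly in PyTorch. The following is a minimal, illustrative sketch of the math above (not the library's fused path); all names are placeholders:

    import torch

    def aux_loss_sketch(probs, tokens_per_expert, total_num_tokens, topk,
                        num_experts, moe_aux_loss_coeff):
        # f_i: fraction of token assignments dispatched to each expert.
        f = tokens_per_expert.float() / (total_num_tokens * topk)
        # P_i: router probability mass per expert, averaged over tokens.
        p = probs.float().sum(dim=0) / total_num_tokens
        # loss = E * sum_i (f_i * P_i), scaled by the aux loss coefficient.
        return num_experts * torch.sum(f * p) * moe_aux_loss_coeff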


Parameters:
  • probs (torch.Tensor) – Softmax probabilities output by the router for each token. Shape: [num_tokens, num_experts].

  • tokens_per_expert (torch.Tensor) – Number of tokens assigned to each expert in the batch. Shape: [num_experts].

  • total_num_tokens (int) – Total number of tokens in the batch.

  • topk (int) – The number of experts selected for each token.

  • num_experts (int) – The number of experts.

  • moe_aux_loss_coeff (float) – The coefficient for the auxiliary loss.

Returns:

The auxiliary loss for load balancing.

Return type:

torch.Tensor

core.transformer.moe.moe_utils.z_loss_func(logits, z_loss_coeff)#

Encourages the router’s logits to remain small to enhance stability. Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

Parameters:
  • logits (torch.Tensor) – The logits of the router.

  • z_loss_coeff (float) – The coefficient for the z-loss.

Returns:

The z-loss computed from the logits.

Return type:

torch.Tensor
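
For reference, the ST-MoE z-loss is the mean squared log-partition function of the router logits. A minimal sketch, assuming logits of shape [num_tokens, num_experts]:

    import torch

    def z_loss_sketch(logits, z_loss_coeff):
        # Penalize large logits via the squared log-sum-exp over experts.
        return z_loss_coeff * torch.mean(
            torch.square(torch.logsumexp(logits, dim=-1)))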

core.transformer.moe.moe_utils.sinkhorn(cost: torch.Tensor, tol: float = 0.0001)#

Sinkhorn-based MoE routing function.
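
A minimal sketch of Sinkhorn normalization, alternately rescaling rows (tokens) and columns (experts) of exp(cost) until the column scaling converges. This mirrors the standard algorithm and is illustrative only:

    import torch

    def sinkhorn_sketch(cost, tol=1e-4):
        # Work on the exponentiated cost so all entries are positive.
        cost = torch.exp(cost)
        d0 = torch.ones(cost.size(0), device=cost.device)
        d1 = torch.ones(cost.size(1), device=cost.device)
        eps = 1e-8
        error = 1e9
        d1_old = d1
        while error > tol:
            # Rescale rows, then columns, toward uniform marginals.
            d0 = (1.0 / d0.size(0)) / ((d1 * cost).sum(dim=1) + eps)
            d1 = (1.0 / d1.size(0)) / ((d0.unsqueeze(1) * cost).sum(dim=0) + eps)
            error = torch.mean(torch.abs(d1_old - d1))
            d1_old = d1
        return d1 * cost * d0.unsqueeze(1)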

core.transformer.moe.moe_utils.get_capacity(
num_tokens: int,
num_experts: int,
capacity_factor: float,
min_capacity=None,
)#

Calculate the capacity of each expert.

Parameters:
  • num_tokens (int) – Number of input tokens.

  • num_experts (int) – Number of experts.

  • capacity_factor (float) – Capacity factor.

  • min_capacity (int, optional) – Minimum capacity. Defaults to None.

Returns:

Capacity of each expert.

Return type:

Tensor
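
A sketch of the computation under the usual definition (capacity = ceil(tokens per expert × capacity factor), optionally clamped from below); illustrative only:

    import math

    def get_capacity_sketch(num_tokens, num_experts, capacity_factor,
                            min_capacity=None):
        # Evenly divide tokens across experts, scaled by the capacity factor.
        capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
        if min_capacity is not None and capacity < min_capacity:
            capacity = min_capacity
        return capacity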

class core.transformer.moe.moe_utils.MoEAuxLossAutoScaler#

Bases: torch.autograd.Function

An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss.

main_loss_backward_scale: torch.Tensor#

None

static forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor)#

Preserve the aux_loss by storing it in the context to avoid garbage collection.

Parameters:
  • output (torch.Tensor) – The output tensor.

  • aux_loss (torch.Tensor) – The auxiliary loss tensor.

Returns:

The output tensor.

Return type:

torch.Tensor

static backward(ctx, grad_output: torch.Tensor)#

Compute and scale the gradient for the auxiliary loss.

Parameters:

grad_output (torch.Tensor) – The gradient of the output.

Returns:

The gradient of the output and the scaled gradient of the auxiliary loss.

Return type:

Tuple[torch.Tensor, torch.Tensor]

static set_loss_scale(scale: torch.Tensor)#

Set the scale of the aux loss.

Parameters:

scale (torch.Tensor) – The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
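
Typical usage, sketched with placeholder tensors: the auxiliary loss is attached to an activation in the forward pass, so that the main loss's backward pass also produces a gradient for it.

    import torch

    # Placeholder tensors standing in for the MoE layer output and the
    # router's auxiliary loss.
    hidden_states = torch.randn(16, 1024, requires_grad=True)
    aux_loss = hidden_states.sum() * 0.01

    # Attach aux_loss to the activation; forward returns hidden_states unchanged.
    hidden_states = MoEAuxLossAutoScaler.apply(hidden_states, aux_loss)
    # A later main_loss.backward() then also back-propagates aux_loss,
    # scaled by the value set via MoEAuxLossAutoScaler.set_loss_scale(...).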

core.transformer.moe.moe_utils.permute(
tokens,
routing_map,
probs: Optional[torch.Tensor] = None,
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
)#

Permute the tokens and probs based on the mask. Tokens with the same designated expert are grouped together. The mask has shape [num_tokens, num_experts] and indicates which experts were selected by each token.

When drop_and_pad=True, the number of non-zeros in each column of routing_map equals the expert capacity. This function exploits this feature to use ops that support CUDA graphs.

Parameters:
  • tokens (torch.Tensor) – The input token tensor, [num_tokens, hidden].

  • routing_map (torch.Tensor) – The sparse token to expert mapping, [num_tokens, num_experts].

  • probs (torch.Tensor, optional) – The probs tensor, [num_tokens, num_experts].

  • num_out_tokens (int, optional) – The number of output tokens. If None, it’s set to the number of input tokens.

  • fused (bool, optional) – Whether to use the fused permute function.

  • drop_and_pad (bool, optional) – Whether or not the token dispatcher uses token-drop and pads the number of tokens to the expert capacity. If set to true, routing_map has a fixed number of non-zeros in each column.
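
A toy illustration of the grouping semantics (conceptual only, not the fused kernel path):

    import torch

    tokens = torch.randn(4, 8)  # [num_tokens, hidden]
    routing_map = torch.tensor([[1, 0],   # token 0 -> expert 0
                                [0, 1],   # token 1 -> expert 1
                                [1, 0],   # token 2 -> expert 0
                                [0, 1]],  # token 3 -> expert 1
                               dtype=torch.bool)
    # Walk the experts column by column and gather the tokens routed to each,
    # so rows for expert 0 come first, then rows for expert 1.
    sorted_indices = routing_map.T.contiguous().nonzero(as_tuple=True)[1]
    permuted = tokens[sorted_indices]  # rows ordered [0, 2, 1, 3]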

core.transformer.moe.moe_utils.unpermute(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
restore_shape: torch.Size,
probs: torch.Tensor = None,
routing_map: torch.Tensor = None,
fused: bool = False,
drop_and_pad: bool = False,
)#

Restore the original order of tokens after permutation. If probs are provided, they are also applied to the tokens before the order is restored.

When drop_and_pad=True, the tensors have the following properties:

  • In routing_map, the number of non-zeros in each column equals the expert capacity.

  • The size of sorted_indices equals num_experts * capacity; each split of capacity contains the indices of tokens routed to one expert.

This function exploits these features to use ops that support CUDA graphs.

Parameters:
  • permuted_tokens (torch.Tensor) – The permuted token tensor.

  • sorted_indices (torch.Tensor) – The indices used to sort the tokens.

  • restore_shape (torch.Size) – The shape of the unpermuted tensor.

  • probs (torch.Tensor, optional) – The unpermuted probs tensor.

  • routing_map (torch.Tensor, optional) – Token to expert mapping, shape [num_tokens, num_experts].

  • fused (bool, optional) – Whether to use the fused unpermute function.

  • drop_and_pad (bool, optional) – Whether or not the token dispatcher uses token-drop and pads the number of tokens to the expert capacity.

Returns:

The tokens restored to their original order.

Return type:

torch.Tensor
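
A conceptual sketch of the restore step: each permuted row is accumulated back at the original token position it came from, which merges a token's top-k expert copies. Toy shapes and placeholder names, not the fused path:

    import torch

    num_tokens, hidden = 3, 4
    permuted_tokens = torch.randn(6, hidden)      # 3 tokens x topk=2 copies
    origin = torch.tensor([0, 2, 1, 0, 2, 1])     # original token of each row
    weights = torch.full((6, 1), 0.5)             # per-copy routing weight
    restored = torch.zeros(num_tokens, hidden)
    # index_add_ accumulates rows that map to the same original token.
    restored.index_add_(0, origin, permuted_tokens * weights)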

core.transformer.moe.moe_utils.sort_chunks_by_idxs(
input: torch.Tensor,
split_sizes: torch.Tensor,
sorted_idxs: torch.Tensor,
probs: Optional[torch.Tensor] = None,
fused: bool = False,
)#

Split and sort the input tensor based on the split_sizes and sorted indices.

core.transformer.moe.moe_utils.group_limited_topk(
scores: torch.Tensor,
topk: int,
num_tokens: int,
num_experts: int,
num_groups: int,
group_topk: int,
)#

Perform top-k routing on a subset of expert groups.

When using group-limited routing:

  1. Experts are divided into ‘moe_router_num_groups’ equal-sized groups

  2. For each token, ‘moe_router_group_topk’ groups are selected based on routing scores (specifically, the sum of top-2 expert scores within each group)

  3. From these selected groups, ‘moe_router_topk’ individual experts are chosen

Two common use cases:

  • Device-limited routing: Set ‘moe_router_num_groups’ equal to expert parallel size (EP) to limit each token to experts on a subset of devices (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)

  • Node-limited routing: Set ‘moe_router_num_groups’ equal to number of nodes in EP group to limit each token to experts on a subset of nodes (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)

Parameters:
  • scores (torch.Tensor) – Softmax scores generated by the router.

  • topk (int) – The number of experts to select for each token.

  • num_tokens (int) – The number of tokens.

  • num_experts (int) – The number of experts.

  • num_groups (int) – Number of groups for routed experts.

  • group_topk (int) – Number of groups selected for each token.

Returns:

Probs and indices tensor.

Return type:

Tuple[torch.Tensor, torch.Tensor]
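
A sketch of the three steps, assuming num_experts is divisible by num_groups and each group has at least 2 experts (illustrative; the library version differs in details):

    import torch

    def group_limited_topk_sketch(scores, topk, num_groups, group_topk):
        num_tokens, num_experts = scores.shape
        # 1. Score each group by the sum of its top-2 expert scores.
        group_scores = (scores.view(num_tokens, num_groups, -1)
                              .topk(2, dim=-1)[0].sum(dim=-1))
        # 2. Build a mask keeping only the best `group_topk` groups per token.
        group_idx = torch.topk(group_scores, k=group_topk, dim=-1)[1]
        group_mask = torch.zeros_like(group_scores, dtype=torch.bool)
        group_mask.scatter_(1, group_idx, True)
        expert_mask = group_mask.unsqueeze(-1).expand(
            num_tokens, num_groups, num_experts // num_groups
        ).reshape(num_tokens, -1)
        # 3. Run ordinary top-k over the surviving experts only.
        masked = scores.masked_fill(~expert_mask, float('-inf'))
        probs, indices = torch.topk(masked, k=topk, dim=-1)
        return probs, indices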

core.transformer.moe.moe_utils.pad_routing_map(
routing_map: torch.Tensor,
pad_multiple: int,
) torch.Tensor#

Pad the routing map to ensure each expert has a multiple of pad_multiple tokens.

This function ensures that each expert receives a number of tokens that is a multiple of pad_multiple by converting some 0s to 1s in the routing map. For each expert, the first N unrouted tokens (zero entries in that expert's column) are selected, where N is the number needed to reach the next multiple of pad_multiple.

Parameters:
  • routing_map (torch.Tensor) – A boolean or integer tensor of shape [num_tokens, num_experts] indicating which tokens are routed to which experts.

  • pad_multiple (int) – The multiple to pad each expert’s token count to.

Returns:

The padded routing map of shape [num_tokens, num_experts].

Return type:

torch.Tensor
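
A small, hypothetical example of the effect:

    import torch

    routing_map = torch.tensor([[1, 0],
                                [1, 1],
                                [0, 0]], dtype=torch.int32)
    # Expert 0 already has 2 tokens (a multiple of 2). Expert 1 has 1 token, so
    # pad_routing_map(routing_map, pad_multiple=2) flips one 0 in expert 1's
    # column (the first available, token 0) to bring its count to 2.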

core.transformer.moe.moe_utils.topk_routing_with_score_function(
logits: torch.Tensor,
topk: int,
use_pre_softmax: bool = False,
num_groups: Optional[int] = None,
group_topk: Optional[int] = None,
scaling_factor: Optional[float] = None,
score_function: str = 'softmax',
expert_bias: Optional[torch.Tensor] = None,
fused: bool = False,
)#

Compute the routing probabilities and map for top-k selection with score function.

Parameters:
  • logits (torch.Tensor) – Logits tensor.

  • topk (int) – The number of experts to select for each token.

  • use_pre_softmax (bool) – Whether to apply softmax or sigmoid before top-k selection.

  • num_groups (int) – Number of groups for routed experts.

  • group_topk (int) – Number of selected groups for each token.

  • scaling_factor (float) – Scaling factor of routing score in top-k selection.

  • score_function (str) – The score function to use. Can be either “softmax” or “sigmoid”.

  • expert_bias (torch.Tensor) – The bias added to logits for expert routing.

Returns:

  • routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing the routing probabilities for each token to each expert.

  • routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts] indicating which experts were selected for each token. True values represent the selected experts.

Return type:

Tuple[torch.Tensor, torch.Tensor]
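
A hypothetical call with 8 experts and top-2 softmax routing, assuming the two-tensor return described above:

    import torch

    logits = torch.randn(16, 8)  # [num_tokens, num_experts]
    probs, routing_map = topk_routing_with_score_function(
        logits, topk=2, score_function='softmax')
    # probs: [16, 8] routing weights; routing_map: [16, 8] boolean mask.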

core.transformer.moe.moe_utils.compute_routing_scores_for_aux_loss(
logits: torch.Tensor,
topk: int,
score_function: str,
fused: bool = False,
)#

Compute routing scores based on the score function.

Parameters:
  • logits (torch.Tensor) – The logits tensor after gating, shape: [num_tokens, num_experts].

  • topk (int) – The number of experts selected for each token.

  • score_function (str) – The score function to use. Can be either “softmax” or “sigmoid”.

  • fused (bool) – Whether to use the fused implementation.

Returns:

The normalized routing scores.

Return type:

torch.Tensor

core.transformer.moe.moe_utils.apply_router_token_dropping(
routing_probs: torch.Tensor,
routing_map: torch.Tensor,
router_topk: int,
capacity_factor: float,
drop_policy: str = 'probs',
pad_to_capacity: bool = False,
)#

Apply token dropping to top-k expert selection.

This function enforces expert capacity limits by dropping tokens that exceed the capacity and optionally padding to capacity.

Parameters:
  • routing_probs (torch.Tensor) – Tensor of shape [num_tokens, num_experts] containing the routing probabilities for selected experts.

  • routing_map (torch.Tensor) – Boolean tensor of shape [num_tokens, num_experts] indicating which experts were selected for each token.

  • router_topk (int) – Number of experts selected per token.

  • capacity_factor (float) – The capacity factor of each expert.

  • drop_policy (str) – Policy for dropping tokens, either “probs” or “position”.

  • pad_to_capacity (bool) – Whether to pad to capacity.

Returns:

  • final_probs: Routing probabilities after applying capacity constraints

  • final_map: Boolean mask after applying capacity constraints

Return type:

Tuple[torch.Tensor, torch.Tensor]
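
A sketch of the “probs” drop policy, keeping at most `capacity` highest-probability tokens per expert (illustrative shapes and names; assumes capacity <= num_tokens and omits the pad-to-capacity path):

    import torch

    def drop_by_probs_sketch(routing_probs, routing_map, capacity):
        # Rank tokens per expert by probability; unselected entries rank last.
        ranked = routing_probs.masked_fill(~routing_map, -1.0)
        top_idx = torch.topk(ranked, k=capacity, dim=0)[1]
        capacity_mask = torch.zeros_like(routing_map)
        capacity_mask.scatter_(0, top_idx, True)
        # Keep only tokens that were both selected and within capacity.
        final_map = routing_map & capacity_mask
        final_probs = routing_probs * final_map
        return final_probs, final_map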

core.transformer.moe.moe_utils.save_to_aux_losses_tracker(
name: str,
loss: torch.Tensor,
layer_number: int,
num_layers: int,
reduce_group: torch.distributed.ProcessGroup = None,
avg_group: torch.distributed.ProcessGroup = None,
reduce_group_has_dp: bool = False,
)#

Save the auxiliary loss for logging.

Parameters:
  • name (str) – The name of the loss.

  • loss (torch.Tensor) – The loss tensor.

  • layer_number (int) – Layer index of the loss.

  • num_layers (int) – The number of total layers.

  • reduce_group (torch.distributed.ProcessGroup) – The group for reducing the loss.

  • avg_group (torch.distributed.ProcessGroup) – The group for averaging the loss.

  • reduce_group_has_dp (bool) – Whether the reduce group contains data parallel ranks. Set this to True if it does; the flag is used to ensure correct reduction in aux loss tracking.

core.transformer.moe.moe_utils.clear_aux_losses_tracker()#

Clear the auxiliary losses.

core.transformer.moe.moe_utils.reduce_aux_losses_tracker_across_ranks(
track_names: Optional[List[str]] = None,
)#

Collect and reduce the auxiliary losses across ranks.

core.transformer.moe.moe_utils.track_moe_metrics(
loss_scale: float,
iteration: int,
writer,
wandb_writer=None,
total_loss_dict=None,
per_layer_logging=False,
force_initialize: bool = False,
track_names: Optional[List[str]] = None,
num_layers: Optional[int] = None,
moe_layer_freq: Optional[Union[int, List[int]]] = None,
mtp_num_layers: Optional[int] = None,
)#

Track the MoE metrics for logging.

core.transformer.moe.moe_utils.get_updated_expert_bias(
tokens_per_expert,
expert_bias,
expert_bias_update_rate,
)#

Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1 for details.

Parameters:
  • tokens_per_expert (torch.Tensor) – The number of tokens assigned to each expert.

  • expert_bias (torch.Tensor) – The bias for each expert.

  • expert_bias_update_rate (float) – The update rate for the expert bias.
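
A minimal sketch of the sign-based update rule from the paper (the library version may additionally aggregate token counts across ranks; that detail is omitted here):

    import torch

    def updated_expert_bias_sketch(tokens_per_expert, expert_bias, update_rate):
        # Under-loaded experts (count below average) get a positive nudge,
        # over-loaded experts a negative one.
        avg_count = tokens_per_expert.float().mean()
        return expert_bias + update_rate * torch.sign(avg_count - tokens_per_expert)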

core.transformer.moe.moe_utils.maybe_move_tensor_to_cpu(tensor, as_numpy=False, record_stream=False)#

Move a tensor to CPU if it is on GPU.

Parameters:
  • tensor (torch.Tensor or None) – The tensor to move to CPU.

  • as_numpy (bool) – Whether to convert the tensor to a numpy array.

  • record_stream (bool) – Whether to record the stream of the tensor, to prevent a memory leak when the DtoH data transfer happens on a side stream.

core.transformer.moe.moe_utils.get_moe_layer_wise_logging_tracker()#

Return the MoE layer-wise logging tracker.

class core.transformer.moe.moe_utils.RandomSTE#

Bases: torch.autograd.Function

Straight-Through Estimator (STE) function that returns random values, using a different seed for each rank.

This is used to generate random router logits for load-balanced benchmarking.

generator#

None

random_logits#

None

static forward(ctx, logits)#

Forward pass returns random logits with rank-specific seed.

static backward(ctx, grad_output)#

Backward pass propagates the gradient for logits.

core.transformer.moe.moe_utils.apply_random_logits(logits)#

Apply the RandomSTE function to the logits.

class core.transformer.moe.moe_utils.RouterGatingLinearFunction#

Bases: torch.autograd.Function

Autograd function for router gating linear.

static forward(
ctx,
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
router_dtype: torch.dtype,
)#

Forward pass of the RouterGatingLinearFunction function.

static backward(ctx, grad_output: torch.Tensor)#

Backward pass of the RouterGatingLinearFunction function.

core.transformer.moe.moe_utils.router_gating_linear(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
router_dtype: torch.dtype,
)#

Customized linear layer for router gating. This linear layer accepts bfloat16 input and weight and can return output in router_dtype. It reduces memory usage by avoiding saving intermediate high-precision tensors.
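
Functionally, the layer behaves like a plain linear projection with a dtype cast. A conceptual equivalent without the memory optimization (illustrative only):

    import torch
    import torch.nn.functional as F

    def router_gating_linear_sketch(inp, weight, bias, router_dtype):
        # Upcast to router_dtype and apply the affine transform. The custom
        # autograd.Function avoids saving these upcast copies for backward.
        inp = inp.to(router_dtype)
        weight = weight.to(router_dtype)
        bias = bias.to(router_dtype) if bias is not None else None
        return F.linear(inp, weight, bias)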

core.transformer.moe.moe_utils.get_align_size_for_quantization(
config: megatron.core.transformer.transformer_config.TransformerConfig,
)#

Get the alignment size for quantization.

core.transformer.moe.moe_utils.get_default_pg_collection()#

Get the default process groups for MoE.

Returns:

The default process groups for MoE.

Return type:

ProcessGroupCollection