core.transformer.moe.moe_utils#
Module Contents#
Classes#
- MoEAuxLossAutoScaler – An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss.
- RandomSTE – Straight-Through Estimator (STE) function that returns random values with a different seed for each rank.
- RouterGatingLinearFunction – Autograd function for router gating linear.
Functions#
- switch_load_balancing_loss_func – Calculate the auxiliary loss for load balancing. Refer to the Switch Transformer (https://arxiv.org/abs/2101.03961) and Global Load Balancing Loss (https://arxiv.org/abs/2501.11873) for details.
- z_loss_func – Encourages the router's logits to remain small to enhance stability. Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
- sinkhorn – Sinkhorn-based MoE routing function.
- get_capacity – Calculate the capacity of each expert.
- permute – Permute the tokens and probs based on the mask. Tokens with the same designated expert will be grouped together.
- unpermute – Restore the original order of tokens after permutation. If probs are provided, they are also applied to the tokens before restoring the order.
- sort_chunks_by_idxs – Split and sort the input tensor based on the split_sizes and sorted indices.
- group_limited_topk – Perform top-k routing on a subset of expert groups.
- pad_routing_map – Pad the routing map to ensure each expert has a multiple of pad_multiple tokens.
- topk_routing_with_score_function – Compute the routing probabilities and map for top-k selection with a score function.
- compute_routing_scores_for_aux_loss – Compute routing scores based on the score function.
- apply_router_token_dropping – Apply token dropping to top-k expert selection.
- save_to_aux_losses_tracker – Save the auxiliary loss for logging.
- clear_aux_losses_tracker – Clear the auxiliary losses.
- reduce_aux_losses_tracker_across_ranks – Collect and reduce the auxiliary losses across ranks.
- track_moe_metrics – Track the MoE metrics for logging.
- get_updated_expert_bias – Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1
- maybe_move_tensor_to_cpu – Move a tensor to CPU if it is on GPU.
- get_moe_layer_wise_logging_tracker – Return the MoE layer-wise tracker.
- apply_random_logits – Apply the RandomSTE function to the logits.
- router_gating_linear – Customized linear layer for router gating.
- get_align_size_for_quantization – Get the alignment size for quantization.
- get_default_pg_collection – Get the default process groups for MoE.
Data#
API#
- core.transformer.moe.moe_utils._MOE_LAYER_WISE_LOGGING_TRACKER#
None
- core.transformer.moe.moe_utils.switch_load_balancing_loss_func(
- probs: torch.Tensor,
- tokens_per_expert: torch.Tensor,
- total_num_tokens: int,
- topk: int,
- num_experts: int,
- moe_aux_loss_coeff: float,
- fused: bool = False,
Calculate the auxiliary loss for load balancing. Refer to the Switch Transformer (https://arxiv.org/abs/2101.03961) and Global Load Balancing Loss(https://arxiv.org/abs/2501.11873) for details.
Detailed explanation of the auxiliary loss#
The formula for the auxiliary loss is:

loss = E * Σ_{i=1}^{E} (f_i * P_i)

where:
- f_i = 1 / (T * topk) * Σ_{x∈B} routing_map(x, i) is the fraction of tokens dispatched to expert i
- P_i = 1 / T * Σ_{x∈B} probs(x, i) is the averaged router probability allocated to expert i
- E is the number of experts
- T is the total number of tokens in the batch B

For distributed training with sequence or context parallelism, each rank can process a subset of the batch:

loss = E * Σ_{i=1}^{E} (f_i * Σ_{j=1}^{N} P_ij)
     = E * Σ_{i=1}^{E} Σ_{j=1}^{N} (f_i * P_ij)
     = Σ_{j=1}^{N} E * (Σ_{i=1}^{E} f_i * P_ij)

where:
- f_i = 1 / (T * topk) * Σ_{x∈B} routing_map(x, i) is the fraction of tokens dispatched to expert i in the global batch
- P_ij = 1 / T * Σ_{x∈B_j} probs(x, i) is the averaged router probability allocated to expert i in the local batch of the j-th rank
- N is the number of ranks
- B_j is the batch of tokens on the j-th rank
- T is the total number of tokens in the global batch B
Note: To calculate the auxiliary loss at different levels (micro-batch or global batch):
- probs: should always be from the local batch being processed
- tokens_per_expert: should represent token counts at the desired level (either micro-batch or global batch)
- total_num_tokens: should match the total token count at the same level as tokens_per_expert
- Parameters:
probs (torch.Tensor) – Softmax probabilities output by the router for each token. Shape in [num_tokens, num_experts].
tokens_per_expert (torch.Tensor) – Number of tokens assigned to each expert in the batch. Shape in [num_experts].
total_num_tokens (int) – Total number of tokens in the batch.
topk (int) – The number of experts selected for each token.
num_experts (int) – The number of experts.
moe_aux_loss_coeff (float) – The coefficient for the auxiliary loss.
- Returns:
The auxiliary loss for load balancing.
- Return type:
torch.Tensor
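As a reference for the formula above, here is a minimal, unfused sketch in plain PyTorch. It is illustrative only: the helper name and toy inputs are assumptions, and the in-tree implementation additionally handles the fused kernel path and distributed reduction.

```python
import torch

def aux_loss_sketch(probs, tokens_per_expert, total_num_tokens, topk,
                    num_experts, moe_aux_loss_coeff):
    # f_i: fraction of tokens dispatched to expert i
    f = tokens_per_expert.float() / (total_num_tokens * topk)
    # P_i: mean router probability allocated to expert i over the batch
    p = probs.sum(dim=0) / total_num_tokens
    # loss = E * sum_i f_i * P_i, scaled by the aux-loss coefficient
    return num_experts * torch.sum(f * p) * moe_aux_loss_coeff

# Toy example: 16 tokens, 4 experts, top-2 routing.
probs = torch.softmax(torch.randn(16, 4), dim=-1)
top2 = torch.topk(probs, k=2, dim=-1).indices
tokens_per_expert = torch.bincount(top2.flatten(), minlength=4)
loss = aux_loss_sketch(probs, tokens_per_expert, total_num_tokens=16,
                       topk=2, num_experts=4, moe_aux_loss_coeff=1e-2)
```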
- core.transformer.moe.moe_utils.z_loss_func(logits, z_loss_coeff)#
Encourages the router’s logits to remain small to enhance stability. Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
- Parameters:
logits (torch.Tensor) – The logits of the router.
z_loss_coeff (float) – The coefficient for the z-loss.
- Returns:
The z-loss computed from the logits.
- Return type:
torch.Tensor
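The z-loss in the ST-MoE paper is the mean squared log-sum-exp of the router logits. A minimal sketch of that formula (assumed to match the intent of this function, not its exact code):

```python
import torch

def z_loss_sketch(logits, z_loss_coeff):
    # Penalizes large router logits: coeff * mean over tokens of logsumexp(logits)^2.
    return z_loss_coeff * torch.logsumexp(logits, dim=-1).square().mean()
```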
- core.transformer.moe.moe_utils.sinkhorn(cost: torch.Tensor, tol: float = 0.0001)#
Sinkhorn-based MoE routing function.
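A rough sketch of the standard Sinkhorn-Knopp normalization this function is based on; the iteration count, epsilon, and convergence check below are illustrative assumptions:

```python
import torch

def sinkhorn_sketch(cost: torch.Tensor, tol: float = 1e-4, num_iters: int = 100):
    # Alternately rescale rows and columns of exp(cost) until the row scaling
    # converges, producing an approximately doubly-normalized assignment matrix.
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device)
    d1 = torch.ones(cost.size(1), device=cost.device)
    eps = 1e-8
    for _ in range(num_iters):
        d0_prev = d0
        d0 = 1.0 / (cost.size(0) * torch.sum(cost * d1.unsqueeze(0), dim=1) + eps)
        d1 = 1.0 / (cost.size(1) * torch.sum(cost * d0.unsqueeze(1), dim=0) + eps)
        if torch.mean(torch.abs(d0_prev - d0)) < tol:
            break
    return cost * d1.unsqueeze(0) * d0.unsqueeze(1)
```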
- core.transformer.moe.moe_utils.get_capacity(
- num_tokens: int,
- num_experts: int,
- capacity_factor: float,
- min_capacity=None,
Calculate the capacity of each expert.
- Parameters:
num_tokens (int) – Number of input tokens.
num_experts (int) – Number of experts.
capacity_factor (float) – Capacity factor.
min_capacity (int, optional) – Minimum capacity. Defaults to None.
- Returns:
Capacity of each expert.
- Return type:
Tensor
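Conceptually, the capacity is each expert's fair share of tokens scaled by the capacity factor. A sketch of the arithmetic (the rounding and min_capacity handling here are assumptions):

```python
import math

def get_capacity_sketch(num_tokens, num_experts, capacity_factor, min_capacity=None):
    # Fair share per expert, scaled by the capacity factor and rounded up.
    capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
    if min_capacity is not None:
        capacity = max(capacity, min_capacity)
    return capacity

# e.g. 1024 tokens over 8 experts with capacity_factor=1.25 -> 160 tokens per expert
print(get_capacity_sketch(1024, 8, 1.25))
```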
- class core.transformer.moe.moe_utils.MoEAuxLossAutoScaler#
Bases:
torch.autograd.Function

An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss.
- main_loss_backward_scale: torch.Tensor#
None
- static forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor)#
Preserve the aux_loss by storing it in the context to avoid garbage collection.
- Parameters:
output (torch.Tensor) – The output tensor.
aux_loss (torch.Tensor) – The auxiliary loss tensor.
- Returns:
The output tensor.
- Return type:
torch.Tensor
- static backward(ctx, grad_output: torch.Tensor)#
Compute and scale the gradient for the auxiliary loss.
- Parameters:
grad_output (torch.Tensor) – The gradient of the output.
- Returns:
The gradient of the output and the scaled gradient of the auxiliary loss.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- static set_loss_scale(scale: torch.Tensor)#
Set the scale of the aux loss.
- Parameters:
scale (torch.Tensor) – The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
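A minimal usage sketch of the autoscaler; the shapes and values are illustrative, and the import path is an assumption based on this page's module name:

```python
import torch
from megatron.core.transformer.moe.moe_utils import MoEAuxLossAutoScaler

hidden = torch.randn(16, 32, requires_grad=True)  # activations flowing to the main loss
aux_loss = torch.tensor(0.5, requires_grad=True)  # auxiliary load-balancing loss

# Forward is an identity on `hidden`; backward additionally injects the
# gradient of `aux_loss`, scaled to match the main-loss scale set below.
MoEAuxLossAutoScaler.set_loss_scale(torch.tensor(1.0))
hidden = MoEAuxLossAutoScaler.apply(hidden, aux_loss)
hidden.sum().backward()
```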
- core.transformer.moe.moe_utils.permute(
- tokens,
- routing_map,
- probs: Optional[torch.Tensor] = None,
- num_out_tokens: Optional[int] = None,
- fused: bool = False,
- drop_and_pad: bool = False,
Permute the tokens and probs based on the mask. Tokens with the same designated expert will be grouped together. The shape of mask is [tokens, num_experts], it indicates which experts were selected by each token.
When drop_and_pad=True, the number of non-zeros in each column of routing_map equals the expert capacity. This function exploits this feature to use ops that support CUDA graphs.
- Parameters:
tokens (torch.Tensor) – The input token tensor, [num_tokens, hidden].
routing_map (torch.Tensor) – The sparse token to expert mapping, [num_tokens, num_experts].
probs (torch.Tensor, optional) – The probs tensor, [num_tokens, num_experts].
num_out_tokens (int, optional) – The number of output tokens. If None, it’s set to the number of input tokens.
fused (bool, optional) – Whether to use the fused permute function.
drop_and_pad (bool, optional) – Whether or not the token dispatcher uses token-drop and pads the number of tokens to the expert capacity. If set to true, routing_map has a fixed number of non-zeros in each column.
- core.transformer.moe.moe_utils.unpermute(
- permuted_tokens: torch.Tensor,
- sorted_indices: torch.Tensor,
- restore_shape: torch.Size,
- probs: torch.Tensor = None,
- routing_map: torch.Tensor = None,
- fused: bool = False,
- drop_and_pad: bool = False,
Restore the original order of tokens after permutation. If probs are provided, it will also apply them to the tokens before restoring the order.
When drop_and_pad=True, the tensors have the following properties:
- In routing_map, the number of non-zeros in each column equals the expert capacity.
- The size of sorted_indices equals num_experts * capacity; each split of size capacity contains the indices of tokens routed to one expert.
This function exploits these features to use ops that support CUDA graphs.
- Parameters:
permuted_tokens (torch.Tensor) – The permuted token tensor.
sorted_indices (torch.Tensor) – The indices used to sort the tokens.
restore_shape (torch.Size) – The shape of the unpermuted tensor.
probs (torch.Tensor, optional) – The unpermuted probs tensor.
routing_map (torch.Tensor, optional) – Token to expert mapping, shape [num_tokens, num_experts].
fused (bool, optional) – Whether to use the fused unpermute function.
drop_and_pad (bool, optional) – Whether or not the token dispatcher uses token-drop and pads the number of tokens to the expert capacity.
- Returns:
The tokens restored to their original order.
- Return type:
torch.Tensor
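The grouping idea behind permute/unpermute can be illustrated with a tiny top-1 example in plain PyTorch. This is a conceptual sketch, not the in-tree implementation, which also handles top-k dispatch, probs, drop-and-pad, and fused paths:

```python
import torch

num_tokens, hidden, num_experts = 8, 16, 4
tokens = torch.randn(num_tokens, hidden)
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool)
routing_map[torch.arange(num_tokens), torch.randint(0, num_experts, (num_tokens,))] = True

# Walk the routing map expert by expert (column by column) so that tokens
# routed to the same expert become contiguous after the gather.
flat_positions = torch.where(routing_map.T.flatten())[0]
sorted_token_indices = flat_positions % num_tokens
permuted_tokens = tokens[sorted_token_indices]

# For top-1 routing every token appears exactly once, so scattering the
# permuted rows back to their original indices restores the input.
restored = torch.empty_like(tokens)
restored[sorted_token_indices] = permuted_tokens
assert torch.equal(restored, tokens)
```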
- core.transformer.moe.moe_utils.sort_chunks_by_idxs(
- input: torch.Tensor,
- split_sizes: torch.Tensor,
- sorted_idxs: torch.Tensor,
- probs: Optional[torch.Tensor] = None,
- fused: bool = False,
Split and sort the input tensor based on the split_sizes and sorted indices.
- core.transformer.moe.moe_utils.group_limited_topk(
- scores: torch.Tensor,
- topk: int,
- num_tokens: int,
- num_experts: int,
- num_groups: int,
- group_topk: int,
Perform top-k routing on a subset of expert groups.
When using group-limited routing:
- Experts are divided into 'moe_router_num_groups' equal-sized groups.
- For each token, 'moe_router_group_topk' groups are selected based on routing scores (specifically, the sum of the top-2 expert scores within each group).
- From these selected groups, 'moe_router_topk' individual experts are chosen.

Two common use cases:
- Device-limited routing: Set 'moe_router_num_groups' equal to the expert parallel size (EP) to limit each token to experts on a subset of devices (see DeepSeek-V2: https://arxiv.org/pdf/2405.04434).
- Node-limited routing: Set 'moe_router_num_groups' equal to the number of nodes in the EP group to limit each token to experts on a subset of nodes (see DeepSeek-V3: https://arxiv.org/pdf/2412.19437).
- Parameters:
scores (torch.Tensor) – Softmax scores generated by the router.
topk (int) – The number of experts to select for each token.
num_tokens (int) – The number of tokens.
num_experts (int) – The number of experts.
num_groups (int) – Number of groups for routed experts.
group_topk (int) – Number of groups selected for each token.
- Returns:
Probs and indices tensor.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
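A plain-PyTorch sketch of the group-limited selection described above, simplified for illustration: no score scaling or expert bias, and tie-breaking may differ from the actual implementation.

```python
import torch

def group_limited_topk_sketch(scores, topk, num_groups, group_topk):
    num_tokens, num_experts = scores.shape
    experts_per_group = num_experts // num_groups

    # Rank each group by the sum of its top-2 expert scores.
    group_scores = (
        scores.view(num_tokens, num_groups, experts_per_group)
        .topk(2, dim=-1).values.sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=group_topk, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0).bool()

    # Keep only experts in the selected groups, then take the final top-k.
    expert_mask = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group)
    expert_mask = expert_mask.reshape(num_tokens, num_experts)
    masked_scores = scores.masked_fill(~expert_mask, float('-inf'))
    probs, indices = torch.topk(masked_scores, k=topk, dim=-1)
    return probs, indices
```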
- core.transformer.moe.moe_utils.pad_routing_map(
- routing_map: torch.Tensor,
- pad_multiple: int,
Pad the routing map to ensure each expert has a multiple of pad_multiple tokens.
This function ensures that each expert has a number of tokens that is a multiple of pad_multiple by converting some 0s to 1s in the routing map. The padding is done by selecting the first N zero elements in each row, where N is the number needed to reach the next multiple of pad_multiple.
- Parameters:
routing_map (torch.Tensor) – A boolean or integer tensor of shape [num_tokens, num_experts] indicating which tokens are routed to which experts.
pad_multiple (int) – The multiple to pad each expert’s token count to.
- Returns:
The padded routing map of shape [num_tokens, num_experts].
- Return type:
torch.Tensor
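A rough per-expert sketch of the padding logic, operating on columns of the [num_tokens, num_experts] map; the loop-based selection here is an illustrative assumption rather than the vectorized implementation:

```python
import torch

def pad_routing_map_sketch(routing_map: torch.Tensor, pad_multiple: int) -> torch.Tensor:
    routing_map = routing_map.clone().bool()
    tokens_per_expert = routing_map.sum(dim=0)
    # How many extra tokens each expert needs to reach the next multiple.
    num_to_pad = (-tokens_per_expert) % pad_multiple
    for expert in range(routing_map.size(1)):
        if num_to_pad[expert] > 0:
            # Flip the first zeros in this expert's column to ones.
            zero_rows = torch.where(~routing_map[:, expert])[0][: num_to_pad[expert]]
            routing_map[zero_rows, expert] = True
    return routing_map
```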
- core.transformer.moe.moe_utils.topk_routing_with_score_function(
- logits: torch.Tensor,
- topk: int,
- use_pre_softmax: bool = False,
- num_groups: Optional[int] = None,
- group_topk: Optional[int] = None,
- scaling_factor: Optional[float] = None,
- score_function: str = 'softmax',
- expert_bias: Optional[torch.Tensor] = None,
- fused: bool = False,
Compute the routing probabilities and map for top-k selection with score function.
- Parameters:
logits (torch.Tensor) – Logits tensor.
topk (int) – The number of experts to select for each token.
use_pre_softmax (bool) – Whether to apply softmax or sigmoid before top-k selection.
num_groups (int) – Number of groups for routed experts.
group_topk (int) – Number of selected groups for each token.
scaling_factor (float) – Scaling factor of routing score in top-k selection.
score_function (str) – The score function to use. Can be either “softmax” or “sigmoid”.
expert_bias (torch.Tensor) – The bias added to logits for expert routing.
- Returns:
- routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing the routing probabilities for each token to each expert.
- routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts] indicating which experts were selected for each token. True values represent the selected experts.
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
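A simplified sketch of the softmax score-function path, without grouping, expert bias, sigmoid scoring, or the fused kernel; the helper name and return layout are illustrative assumptions:

```python
import torch

def topk_softmax_routing_sketch(logits, topk, use_pre_softmax=False, scaling_factor=None):
    num_tokens, num_experts = logits.shape
    if use_pre_softmax:
        # Softmax over all experts first, then select the top-k probabilities.
        scores = torch.softmax(logits, dim=-1)
        probs, top_indices = torch.topk(scores, k=topk, dim=-1)
    else:
        # Select top-k logits first, then normalize only over the selected experts.
        top_logits, top_indices = torch.topk(logits, k=topk, dim=-1)
        probs = torch.softmax(top_logits, dim=-1)
    if scaling_factor is not None:
        probs = probs * scaling_factor
    routing_map = torch.zeros_like(logits).scatter_(1, top_indices, 1.0).bool()
    routing_probs = torch.zeros_like(logits).scatter_(1, top_indices, probs)
    return routing_probs, routing_map
```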
- core.transformer.moe.moe_utils.compute_routing_scores_for_aux_loss(
- logits: torch.Tensor,
- topk: int,
- score_function: str,
- fused: bool = False,
Compute routing scores based on the score function.
- Parameters:
logits (torch.Tensor) – The logits tensor after gating, shape: [num_tokens, num_experts].
- Returns:
The normalized routing scores.
- Return type:
torch.Tensor
- core.transformer.moe.moe_utils.apply_router_token_dropping(
- routing_probs: torch.Tensor,
- routing_map: torch.Tensor,
- router_topk: int,
- capacity_factor: float,
- drop_policy: str = 'probs',
- pad_to_capacity: bool = False,
Apply token dropping to top-k expert selection.
This function enforces expert capacity limits by dropping tokens that exceed the capacity and optionally padding to capacity.
- Parameters:
routing_probs (torch.Tensor) – Tensor of shape [num_tokens, num_experts] containing the routing probabilities for selected experts.
routing_map (torch.Tensor) – Boolean tensor of shape [num_tokens, num_experts] indicating which experts were selected for each token.
router_topk (int) – Number of experts selected per token.
capacity_factor (float) – The capacity factor of each expert.
drop_policy (str) – Policy to drop tokens - “probs” or “position”.
pad_to_capacity (bool) – Whether to pad to capacity.
- Returns:
final_probs: Routing probabilities after applying capacity constraints
final_map: Boolean mask after applying capacity constraints
- Return type:
Tuple[torch.Tensor, torch.Tensor]
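A sketch of the "probs" drop policy without pad_to_capacity; the capacity arithmetic and tie-breaking below are illustrative assumptions:

```python
import torch

def token_drop_by_probs_sketch(routing_probs, routing_map, router_topk, capacity_factor):
    num_tokens, num_experts = routing_map.shape
    # Expert capacity: fair share of routed token copies times the capacity factor.
    capacity = int(num_tokens * router_topk / num_experts * capacity_factor)
    capacity = max(min(capacity, num_tokens), 1)

    # For each expert, keep only the `capacity` highest-probability tokens.
    ranked = routing_probs.masked_fill(~routing_map, -1.0)
    top_rows = torch.topk(ranked, k=capacity, dim=0).indices
    capacity_mask = torch.zeros_like(routing_probs).scatter_(0, top_rows, 1.0).bool()

    final_map = routing_map & capacity_mask
    final_probs = routing_probs * final_map
    return final_probs, final_map
```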
- core.transformer.moe.moe_utils.save_to_aux_losses_tracker(
- name: str,
- loss: torch.Tensor,
- layer_number: int,
- num_layers: int,
- reduce_group: torch.distributed.ProcessGroup = None,
- avg_group: torch.distributed.ProcessGroup = None,
- reduce_group_has_dp: bool = False,
Save the auxiliary loss for logging.
- Parameters:
name (str) – The name of the loss.
loss (torch.Tensor) – The loss tensor.
layer_number (int) – Layer index of the loss.
num_layers (int) – The number of total layers.
reduce_group (torch.distributed.ProcessGroup) – The group for reducing the loss.
avg_group (torch.distributed.ProcessGroup) – The group for averaging the loss.
reduce_group_has_dp (bool) – Whether the reduce group contains data parallel ranks. This flag is used to ensure the correct reduction in aux loss tracking.
- core.transformer.moe.moe_utils.clear_aux_losses_tracker()#
Clear the auxiliary losses.
- core.transformer.moe.moe_utils.reduce_aux_losses_tracker_across_ranks(
- track_names: Optional[List[str]] = None,
Collect and reduce the auxiliary losses across ranks.
- core.transformer.moe.moe_utils.track_moe_metrics(
- loss_scale: float,
- iteration: int,
- writer,
- wandb_writer=None,
- total_loss_dict=None,
- per_layer_logging=False,
- force_initialize: bool = False,
- track_names: Optional[List[str]] = None,
- num_layers: Optional[int] = None,
- moe_layer_freq: Optional[Union[int, List[int]]] = None,
- mtp_num_layers: Optional[int] = None,
Track the MoE metrics for logging.
- core.transformer.moe.moe_utils.get_updated_expert_bias(
- tokens_per_expert,
- expert_bias,
- expert_bias_update_rate,
Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1
- Parameters:
tokens_per_expert (torch.Tensor) – The number of tokens assigned to each expert.
expert_bias (torch.Tensor) – The bias for each expert.
expert_bias_update_rate (float) – The update rate for the expert bias.
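The aux-loss-free balancing update nudges under-loaded experts' biases up and over-loaded experts' biases down. A minimal sketch: the sign-based rule follows the cited paper, while the in-tree function also reduces the token counts across ranks and runs without gradient tracking.

```python
import torch

def updated_expert_bias_sketch(tokens_per_expert, expert_bias, expert_bias_update_rate):
    # Positive error -> expert is under-loaded -> raise its bias, and vice versa.
    average_load = tokens_per_expert.float().mean()
    error = average_load - tokens_per_expert.float()
    return expert_bias + expert_bias_update_rate * torch.sign(error)
```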
- core.transformer.moe.moe_utils.maybe_move_tensor_to_cpu(tensor, as_numpy=False, record_stream=False)#
Move a tensor to CPU if it is on GPU.
- Parameters:
tensor (torch.Tensor or None) – The tensor to move to CPU.
as_numpy (bool) – Whether to convert the tensor to a numpy array.
record_stream (bool) – Whether to record the stream of the tensor, to prevent memory leak when the DtoH data transfer is on a side stream.
- core.transformer.moe.moe_utils.get_moe_layer_wise_logging_tracker()#
Return the moe layer wise tracker.
- class core.transformer.moe.moe_utils.RandomSTE#
Bases:
torch.autograd.Function

Straight-Through Estimator (STE) function that returns random values with a different seed for each rank.
This is used to generate random router logits for load-balanced benchmarking.
- generator#
None
- random_logits#
None
- static forward(ctx, logits)#
Forward pass returns random logits with rank-specific seed.
- static backward(ctx, grad_output)#
Backward pass propagates the gradient for logits.
- core.transformer.moe.moe_utils.apply_random_logits(logits)#
Apply the RandomSTE function to the logits.
- class core.transformer.moe.moe_utils.RouterGatingLinearFunction#
Bases:
torch.autograd.Function

Autograd function for router gating linear.
- static forward(
- ctx,
- inp: torch.Tensor,
- weight: torch.Tensor,
- bias: torch.Tensor,
- router_dtype: torch.dtype,
Forward pass of the RouterGatingLinearFunction function.
- static backward(ctx, grad_output: torch.Tensor)#
Backward pass of the RouterGatingLinearFunction function.
- core.transformer.moe.moe_utils.router_gating_linear(
- inp: torch.Tensor,
- weight: torch.Tensor,
- bias: torch.Tensor,
- router_dtype: torch.dtype,
Customized linear layer for router gating. This linear layer accepts bfloat16 input and weight, and can return output with router_dtype. It can reduce the memory usage by avoiding saving the intermediate high precision tensors.
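Numerically, the forward pass is a linear layer computed in router_dtype. A rough functional reference based on the description above (assumed equivalent in value; the custom autograd function differs in that it avoids saving the high-precision casts for the backward pass):

```python
import torch
import torch.nn.functional as F

def router_gating_linear_reference(inp, weight, bias, router_dtype=torch.float32):
    # Same math as the custom function's forward; memory behaviour differs.
    bias = bias.to(router_dtype) if bias is not None else None
    return F.linear(inp.to(router_dtype), weight.to(router_dtype), bias)
```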
- core.transformer.moe.moe_utils.get_align_size_for_quantization(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
Get the alignment size for quantization.
- core.transformer.moe.moe_utils.get_default_pg_collection()#
Get the default process groups for MoE.
- Returns:
The default process groups for MoE.
- Return type: