core.transformer.moe.router#

Module Contents#

Classes#

Router

Base Router class

TopKRouter

Route each token to the top-k experts.

API#

class core.transformer.moe.router.Router(
config: megatron.core.transformer.transformer_config.TransformerConfig,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: abc.ABC, megatron.core.transformer.module.MegatronModule

Base Router class

Initialization

Initialize the Router module.

Parameters:
reset_parameters()#

Reset the router parameters.

gating(input: torch.Tensor)#

Forward pass of the router gate.

Parameters:

input (torch.Tensor) – Input tensor.

Returns:

Logits tensor.

Return type:

torch.Tensor

abstractmethod routing(logits: torch.Tensor)#

Routing function.

Parameters:

logits (torch.Tensor) – Logits tensor.

Returns:

A tuple containing token assignment probabilities and mapping.

Return type:

Tuple[torch.Tensor, torch.Tensor]

abstractmethod forward(input: torch.Tensor)#

Forward pass of the router.

Parameters:

input (torch.Tensor) – Input tensor.

set_layer_number(layer_number: int)#

Set the layer number for the router.
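The split of responsibilities in the base class (gating produces logits, routing turns logits into assignments, forward ties them together) can be sketched in plain Python. This is only an illustration under stated assumptions, not the Megatron code: `SketchRouter`, `ArgmaxRouter`, the `weight` layout, and the return value of `forward` are all hypothetical, and nested lists stand in for tensors.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

Matrix = List[List[float]]
Mask = List[List[bool]]


class SketchRouter(ABC):
    """Hypothetical stand-in for the base class; not the Megatron implementation."""

    def __init__(self, weight: Matrix) -> None:
        # Assumption of this sketch: weight has shape [num_experts, hidden_size].
        self.weight = weight
        self.layer_number: Optional[int] = None

    def gating(self, inputs: Matrix) -> Matrix:
        # One logit per (token, expert): dot product of the token with each expert row.
        return [
            [sum(x * w for x, w in zip(token, row)) for row in self.weight]
            for token in inputs
        ]

    @abstractmethod
    def routing(self, logits: Matrix) -> Tuple[Matrix, Mask]:
        """Turn logits into (probs, routing_map)."""

    @abstractmethod
    def forward(self, inputs: Matrix) -> Tuple[Matrix, Mask]:
        """Run the full router on a batch of tokens."""

    def set_layer_number(self, layer_number: int) -> None:
        self.layer_number = layer_number


class ArgmaxRouter(SketchRouter):
    """Toy subclass: send every token to its single highest-logit expert."""

    def routing(self, logits: Matrix) -> Tuple[Matrix, Mask]:
        probs, routing_map = [], []
        for row in logits:
            best = max(range(len(row)), key=row.__getitem__)
            probs.append([1.0 if j == best else 0.0 for j in range(len(row))])
            routing_map.append([j == best for j in range(len(row))])
        return probs, routing_map

    def forward(self, inputs: Matrix) -> Tuple[Matrix, Mask]:
        # Assumption: forward simply chains gating and routing.
        return self.routing(self.gating(inputs))


router = ArgmaxRouter(weight=[[1.0, 0.0], [0.0, 1.0]])
router.set_layer_number(3)
probs, routing_map = router.forward([[2.0, 1.0], [0.5, 3.0]])
```

Keeping `routing` and `forward` abstract mirrors the documented signatures: a concrete router only has to decide how logits become assignments.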

class core.transformer.moe.router.TopKRouter(
config: megatron.core.transformer.transformer_config.TransformerConfig,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: core.transformer.moe.router.Router

Route each token to the top-k experts.

The workflow of TopKRouter is as follows:

(1) Calculate the logits with the router gating network.

(2) Calculate the routing probabilities and the routing map for top-k selection using the score function.

(3) [Optional] Apply token dropping to the top-k expert selection.

(4) [Optional] Apply the auxiliary load balancing loss for the given scores and routing map.

Naming convention:

  • logits: The output logits of the router gating network.

  • scores: The scores after the score function, used to select the experts and to calculate the aux loss.

  • probs: The top-k weights used to combine the experts’ outputs.

  • routing_map: The masked routing map between tokens and experts.
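Steps (1) and (2) and the naming convention above can be illustrated with a small plain-Python sketch. It assumes a softmax score function and renormalizes the kept scores so that each token's probs sum to 1. Both are choices made for this example rather than documented behavior, and `topk_route` is a hypothetical helper, not part of the module.

```python
import math
from typing import List, Tuple


def softmax(row: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def topk_route(
    logits: List[List[float]], k: int
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Return (probs, routing_map), each of shape [num_tokens, num_experts]."""
    probs, routing_map = [], []
    for row in logits:
        scores = softmax(row)  # "scores" in the naming convention
        order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
        chosen = set(order[:k])
        kept = sum(scores[j] for j in chosen)
        # Assumption: renormalize over the selected experts only.
        probs.append(
            [scores[j] / kept if j in chosen else 0.0 for j in range(len(scores))]
        )
        routing_map.append([j in chosen for j in range(len(scores))])
    return probs, routing_map


example_probs, example_map = topk_route([[1.0, 3.0, 2.0]], k=2)
```

For the single token above, experts 1 and 2 are selected and expert 0 receives a zero weight.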

Initialization

Initialize the zero token dropping router.

Parameters:
_maintain_float32_expert_bias()#

Maintain the expert bias in float32.

When using bf16/fp16, the expert bias gets converted to lower precision in Float16Module. We keep it in float32 to avoid routing errors when updating the expert_bias.

sinkhorn_load_balancing(logits: torch.Tensor)#

Apply sinkhorn routing to the logits tensor.

Parameters:

logits (torch.Tensor) – The logits tensor.

Returns:

A tuple containing token assignment probabilities and mask.

Return type:

Tuple[torch.Tensor, torch.Tensor]
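As a rough illustration of what Sinkhorn-style balancing does, the sketch below alternately rescales the rows and columns of exp(logits) so that every expert ends up with the same total weight. It is a generic Sinkhorn-Knopp iteration written for this page: the fixed iteration count, the column target, and the helper name `sinkhorn_balance` are assumptions, and the library's own scaling and stopping rule may differ.

```python
import math
from typing import List


def sinkhorn_balance(logits: List[List[float]], num_iters: int = 50) -> List[List[float]]:
    """Alternately normalize rows and columns of exp(logits)."""
    n_tokens, n_experts = len(logits), len(logits[0])
    shift = max(max(row) for row in logits)  # for numerical stability
    mat = [[math.exp(x - shift) for x in row] for row in logits]
    col_target = n_tokens / n_experts  # equal share of tokens per expert
    for _ in range(num_iters):
        # Row step: each token distributes a total weight of 1.
        row_sums = [sum(row) for row in mat]
        mat = [[v / row_sums[i] for v in row] for i, row in enumerate(mat)]
        # Column step: each expert receives a total weight of col_target.
        col_sums = [sum(mat[i][j] for i in range(n_tokens)) for j in range(n_experts)]
        mat = [
            [mat[i][j] * col_target / col_sums[j] for j in range(n_experts)]
            for i in range(n_tokens)
        ]
    return mat


# Every token prefers expert 0 before balancing.
balanced = sinkhorn_balance([[2.0, 0.0], [2.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
column_sums = [sum(row[j] for row in balanced) for j in range(2)]
row_sums = [sum(row) for row in balanced]
```

After balancing, a per-token argmax over `balanced` spreads the tokens across experts even though the raw logits all favor expert 0.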

get_aux_loss_coeff(aux_loss_type: str) float#

Return the auxiliary loss coefficient for the given auxiliary loss type. If the auxiliary loss type is not found, return 0.0.

is_aux_loss_enabled() bool#

Check if the auxiliary loss is enabled.

_apply_aux_loss(
probs: torch.Tensor,
scores_for_aux_loss: torch.Tensor,
routing_map: torch.Tensor,
)#

Apply the auxiliary loss for the given scores and routing map.

_apply_seq_aux_loss(
probs: torch.Tensor,
scores_for_aux_loss: torch.Tensor,
routing_map: torch.Tensor,
seq_length: int,
bsz: int,
)#

Apply the sequence-level auxiliary loss for the given scores and routing map.

To calculate the sequence-level aux loss, the batch dimension is folded into the experts dimension. The loss returned by switch_load_balancing_loss_func then equals the sum of the aux losses of the individual sequences in the batch. Dividing that sum by the batch size gives the average aux loss per sequence.
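The averaging described above can be made concrete with a small sketch that computes a Switch-style load balancing loss separately for each sequence and then averages over the batch. It loops over sequences instead of reshaping. The helper names, the omission of the loss coefficient, and the assumption that tokens are stored sequence by sequence are all choices made for this example, not documented behavior.

```python
from typing import List


def load_balancing_loss(
    scores: List[List[float]], routing_map: List[List[bool]], topk: int
) -> float:
    """Switch-style loss: num_experts * sum_e(mean score_e * fraction routed to e)."""
    num_tokens, num_experts = len(scores), len(scores[0])
    total = 0.0
    for e in range(num_experts):
        score_sum = sum(scores[t][e] for t in range(num_tokens))
        token_count = sum(1 for t in range(num_tokens) if routing_map[t][e])
        total += score_sum * token_count
    return num_experts * total / (num_tokens * num_tokens * topk)


def sequence_aux_loss(
    scores: List[List[float]],
    routing_map: List[List[bool]],
    seq_length: int,
    bsz: int,
    topk: int,
) -> float:
    """Average of the per-sequence losses over the batch."""
    # Assumption: rows [b * seq_length, (b + 1) * seq_length) belong to sequence b.
    losses = []
    for b in range(bsz):
        rows = slice(b * seq_length, (b + 1) * seq_length)
        losses.append(load_balancing_loss(scores[rows], routing_map[rows], topk))
    return sum(losses) / bsz


scores = [[0.5, 0.5], [0.5, 0.5], [0.9, 0.1], [0.9, 0.1]]
routing_map = [[True, False], [False, True], [True, False], [True, False]]
loss = sequence_aux_loss(scores, routing_map, seq_length=2, bsz=2, topk=1)
```

The first sequence is perfectly balanced and contributes 1.0; the second sends both tokens to expert 0 and contributes 1.8, so the averaged loss is 1.4.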

_apply_global_aux_loss(
probs: torch.Tensor,
scores_for_aux_loss: torch.Tensor,
routing_map: torch.Tensor,
)#

Apply the global auxiliary loss for the given scores and routing map.

attach_and_log_load_balancing_loss(
activation: torch.Tensor,
aux_loss_coeff: float,
aux_loss: torch.Tensor,
aux_loss_name: str,
reduce_group: torch.distributed.ProcessGroup,
reduce_group_has_dp: bool = False,
)#

Attach aux loss function to activation and add to logging.

Parameters:
  • activation (torch.Tensor) – The activation tensor to attach the loss to.

  • aux_loss_coeff (float) – The coefficient for the auxiliary loss.

  • aux_loss (torch.Tensor) – The auxiliary loss tensor.

  • aux_loss_name (str) – The name of the auxiliary loss for logging.

  • reduce_group (torch.distributed.ProcessGroup) – The group for reducing the loss.

  • reduce_group_has_dp (bool) – Whether the reduce group contains data parallel ranks. Set this to True if it does; the flag ensures that the aux loss is reduced correctly for tracking.

apply_z_loss(logits)#

Encourages the router’s logits to remain small to enhance stability. Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

Parameters:

logits (torch.Tensor) – The logits of the router.

Returns:

The logits after applying the z-loss.

Return type:

torch.Tensor
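For reference, the z-loss from the ST-MoE paper is the mean over tokens of the squared log-sum-exp of the router logits, scaled by a coefficient. The sketch below computes only that scalar in plain Python. How the loss is attached to the logits for backpropagation is not shown, and the helper names and the `coeff` argument are assumptions of this example.

```python
import math
from typing import List


def logsumexp(row: List[float]) -> float:
    # Shift by the max so the exponentials cannot overflow.
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))


def z_loss(logits: List[List[float]], coeff: float) -> float:
    """coeff * mean over tokens of logsumexp(logits)^2."""
    squared = [logsumexp(row) ** 2 for row in logits]
    return coeff * sum(squared) / len(squared)


small = z_loss([[0.0, 0.0]], coeff=1.0)    # log(2)^2
large = z_loss([[10.0, 10.0]], coeff=1.0)  # (10 + log(2))^2
```

Larger logits give a larger penalty, which is what pushes the router towards small logits.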

apply_input_jitter(input: torch.Tensor)#

Add noise to the input tensor. Refer to https://arxiv.org/abs/2101.03961.

Parameters:

input (Tensor) – Input tensor.

Returns:

Jittered input.

Return type:

Tensor
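A minimal sketch of input jitter in the spirit of the referenced paper: every input element is multiplied by noise drawn uniformly from [1 - eps, 1 + eps]. The helper name `jitter_input`, the `eps` argument, and the explicit random generator are assumptions of this example; the actual configuration option and noise source are not documented here.

```python
import random
from typing import List


def jitter_input(
    inputs: List[List[float]], eps: float, rng: random.Random
) -> List[List[float]]:
    """Multiply each element by an independent sample from U(1 - eps, 1 + eps)."""
    low, high = 1.0 - eps, 1.0 + eps
    return [[x * rng.uniform(low, high) for x in row] for row in inputs]


rng = random.Random(0)  # seeded so the example is reproducible
jittered = jitter_input([[1.0, 2.0]], eps=0.01, rng=rng)
unchanged = jitter_input([[1.0, 2.0]], eps=0.0, rng=rng)
```

Multiplicative noise leaves zero entries at zero and perturbs each value in proportion to its magnitude.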

_apply_expert_bias(routing_map: torch.Tensor)#

Update expert bias and tokens_per_expert. Prevent extra local token accumulation during evaluation or activation recomputation.

routing(logits: torch.Tensor)#

Top-k routing function

Parameters:

logits (torch.Tensor) – Logits tensor after gating.

Returns:

  • probs (torch.Tensor) – The probabilities of the token-to-expert assignment.

  • routing_map (torch.Tensor) – The mapping of the token-to-expert assignment, with shape [num_tokens, num_experts].

Return type:

Tuple[torch.Tensor, torch.Tensor]

reset_global_aux_loss_tracker()#

Reset the global aux loss tracker.

forward(input: torch.Tensor)#

Forward pass of the router.

Parameters:

input (torch.Tensor) – Input tensor.

_load_from_state_dict(*args, **kwargs)#

Load the state dict of the router.

_save_to_state_dict(*args, **kwargs)#

Save the state dict of the router.