core.transformer.moe.token_dispatcher#
Module Contents#
Classes#
- MoETokenDispatcher: MoE Token Dispatcher.
- MoEAllGatherTokenDispatcher: AllGather-based token dispatcher. Note that this AllGather spans the communication domain of TP*EP.
- MoEAlltoAllTokenDispatcher: AlltoAll-based token dispatcher.
- _DispatchManager: A manager class to handle dispatch and combine processes for MoE models.
- _HybridEPManager: A manager class to handle fused all-to-all communication processes for MoE models using the HybridEP backend. See https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep for more details.
- _DeepepManager: A manager class to handle fused all-to-all communication processes for MoE models using the DeepEP backend. See https://github.com/deepseek-ai/deepep for more details.
- MoEFlexTokenDispatcher: A flexible token dispatcher that abstracts the underlying tensor and expert parallelism. It uses a single communication group over all TP and EP ranks, making the dispatch logic independent of the specific parallelism strategy.
Data#
API#
- core.transformer.moe.token_dispatcher.logger#
'getLogger(...)'
- class core.transformer.moe.token_dispatcher.MoETokenDispatcher(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
MoE Token Dispatcher
Initialization
Initialize the MoE Token Dispatcher.
- Parameters:
config (TransformerConfig) – Configuration for the MoE layer.
pg_collection (ProcessGroupCollection, optional) – Process groups for MoE operations.
- abstractmethod dispatch_preprocess(
- tokens: torch.Tensor,
- routing_map: torch.Tensor,
- probs: torch.Tensor,
Prepares tokens for dispatch without inter-device communication.
This method should handle all local computations like tensor rearrangement and metadata extraction before the main communication step.
.. note::
Try to avoid any communication here to enable optimal computation-communication overlap when communication overlap is enabled, since communications issued on the same stream run sequentially and may become exposed.
- Parameters:
tokens (torch.Tensor) – Input tokens.
routing_map (torch.Tensor) – Token to expert mapping tensor.
probs (torch.Tensor) – The routing probability tensor, [num_tokens, num_experts].
- Returns:
A tuple of preprocessed tokens and probabilities.
- abstractmethod token_dispatch(hidden_states: torch.Tensor, probs: torch.Tensor)#
Dispatches tokens to expert devices using communication.
This method performs the main communication (e.g., All-to-All) to send tokens to the devices where their assigned experts reside.
- Parameters:
hidden_states (torch.Tensor) – Preprocessed hidden states to be dispatched.
probs (torch.Tensor) – Preprocessed probabilities for each token-expert pair.
- Returns:
A tuple of dispatched tokens and probabilities.
- abstractmethod dispatch_postprocess(hidden_states: torch.Tensor, probs: torch.Tensor)#
Performs local processing after token dispatch communication.
This method handles post-communication tasks like token reordering and preparing metadata for the expert forward pass.
.. note::
Try to avoid any communication here to enable optimal computation-communication overlap when communication overlap is enabled, since communications issued on the same stream run sequentially and may become exposed.
- Parameters:
hidden_states (torch.Tensor) – Dispatched hidden states.
probs (torch.Tensor) – Dispatched probabilities.
- Returns:
A tuple containing the permuted tokens for experts, the number of tokens per expert, and the permuted probabilities.
- abstractmethod combine_preprocess(hidden_states)#
Prepares expert outputs for the combine step.
This method performs local computations on expert outputs before the communication step for combining them.
.. note::
Try to avoid any communication here to enable optimal computation-communication overlap when communication overlap is enabled, since communications issued on the same stream run sequentially and may become exposed.
- Parameters:
hidden_states (torch.Tensor) – The output tensor from the experts.
- Returns:
The preprocessed expert output.
- abstractmethod token_combine(hidden_states)#
Combines expert outputs across devices using communication.
This method aggregates expert outputs from different devices via communication (e.g., All-to-All or Reduce-Scatter).
- Parameters:
hidden_states (torch.Tensor) – Preprocessed output from experts.
- Returns:
The combined expert outputs.
- abstractmethod combine_postprocess(hidden_states)#
Performs local processing after token combine.
This method handles post-communication tasks like unpermuting and reshaping to restore the original tensor structure.
.. note::
Try to avoid any communication here to enable optimal computation-communication overlap when communication overlap is enabled, since communications issued on the same stream run sequentially and may become exposed.
- Parameters:
hidden_states (torch.Tensor) – Combined hidden states from token combination
- Returns:
The final output tensor.
Sets the shared expert for the dispatcher.
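The six abstract methods above form a fixed call sequence inside the MoE layer's forward pass. Below is a minimal sketch of how a concrete dispatcher might be driven; the `experts` callable is hypothetical and its argument list is an assumption, not part of this API.

```python
import torch

def moe_forward_with_dispatcher(dispatcher, hidden_states, routing_map, probs, experts):
    """Hedged sketch of the dispatch/combine pipeline documented above."""
    # Local pre-processing (no communication).
    tokens, probs = dispatcher.dispatch_preprocess(hidden_states, routing_map, probs)
    # Main dispatch communication (e.g. All-to-All or AllGather).
    tokens, probs = dispatcher.token_dispatch(tokens, probs)
    # Local post-processing: per-expert permutation and metadata.
    tokens, tokens_per_expert, probs = dispatcher.dispatch_postprocess(tokens, probs)
    # Expert computation on the permuted tokens (hypothetical callable).
    expert_output = experts(tokens, tokens_per_expert, probs)
    # Local pre-processing before the combine communication.
    expert_output = dispatcher.combine_preprocess(expert_output)
    # Combine communication (e.g. All-to-All or Reduce-Scatter).
    expert_output = dispatcher.token_combine(expert_output)
    # Final unpermute/reshape back to the input layout.
    return dispatcher.combine_postprocess(expert_output)
```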
- class core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher(
- num_local_experts: int,
- local_expert_indices: List[int],
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
Bases: core.transformer.moe.token_dispatcher.MoETokenDispatcher
AllGather-based token dispatcher. Note that this AllGather spans the communication domain of TP*EP.
Initialization
Initialize the AllGather based token dispatcher.
- Parameters:
num_local_experts (int) – Number of local experts.
local_expert_indices (List[int]) – Indices of local experts.
config (TransformerConfig) – Configuration for the MoE layer.
pg_collection (ProcessGroupCollection, optional) – Process groups for MoE operations.
- dispatch_preprocess(
- hidden_states: torch.Tensor,
- routing_map: torch.Tensor,
- probs: torch.Tensor,
Reshapes hidden states and caches the routing map.
- token_dispatch(hidden_states, probs)#
Gathers tokens from all TP*EP ranks using AllGather.
- dispatch_postprocess(hidden_states, probs)#
After gathering in token_dispatch, this method identifies tokens for local experts and permutes them for expert processing.
- combine_preprocess(hidden_states)#
Reverses token permutation to restore original ordering before reduction operations.
This method unpermutes the expert outputs using the cached permutation mapping from the dispatch phase. The unpermutation operation restores tokens to their original sequence positions, preparing them for the subsequent reduction scatter operation that will aggregate results across ranks.
- token_combine(hidden_states)#
Combines expert outputs using Reduce-Scatter.
This method performs the ReduceScatter communication operation to collect expert outputs from their processing ranks and redistribute tokens back to the ranks that originally held them. This completes the expert processing communication pattern and prepares tokens for final unpermutation.
- combine_postprocess(hidden_states)#
Restores the original tensor shape.
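A hedged sketch of the AllGather dispatch / Reduce-Scatter combine pattern this class implements, written with plain torch.distributed collectives over the TP*EP group. The `local_expert_mask` argument and the function as a whole are illustrative only; probability weighting and token permutation are omitted for brevity.

```python
import torch
import torch.distributed as dist

def allgather_dispatch_combine_sketch(hidden_states, routing_map, local_expert_mask, group):
    """Gather all tokens, keep those routed to this rank's experts, and
    Reduce-Scatter the (zero-filled elsewhere) expert outputs back."""
    world = dist.get_world_size(group)
    # Dispatch: gather tokens and routing info from every rank in the TP*EP group.
    gathered_tokens = [torch.empty_like(hidden_states) for _ in range(world)]
    dist.all_gather(gathered_tokens, hidden_states, group=group)
    global_tokens = torch.cat(gathered_tokens, dim=0)          # [world * T, H]
    gathered_maps = [torch.empty_like(routing_map) for _ in range(world)]
    dist.all_gather(gathered_maps, routing_map, group=group)
    global_map = torch.cat(gathered_maps, dim=0)               # [world * T, num_experts]
    # Keep only tokens routed to one of this rank's local experts.
    local_hits = global_map[:, local_expert_mask].any(dim=-1)
    expert_input = global_tokens[local_hits]
    # ... run the local experts on expert_input to obtain expert_output ...
    expert_output = expert_input                                # placeholder
    # Combine: scatter expert outputs into the global layout, then Reduce-Scatter
    # so each rank receives (and sums) the outputs for its own tokens.
    full_output = torch.zeros_like(global_tokens)
    full_output[local_hits] = expert_output
    combined = torch.empty_like(hidden_states)
    dist.reduce_scatter_tensor(combined, full_output, group=group)
    return combined
```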
- class core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher(
- num_local_experts: int,
- local_expert_indices: List[int],
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
Bases: core.transformer.moe.token_dispatcher.MoETokenDispatcher
AlltoAll-based token dispatcher.
The workflow of the AlltoAll token dispatcher is as follows:
1. preprocess: calculate the metadata necessary for communication and permutation
2. dispatch preprocess: permute tokens
3. token dispatch: A2A(EP)
4. dispatch postprocess: AG(TP) -> sort_chunk (if num_local_experts > 1)
5. combine preprocess: sort_chunk (if num_local_experts > 1) -> RS(TP)
6. token combine: A2A(EP)
7. combine postprocess: unpermute tokens
Initialization
Initialize the AlltoAll token dispatcher.
- Parameters:
num_local_experts (int) – Number of local experts on the current device.
local_expert_indices (List[int]) – Indices of local experts on the current device.
config (TransformerConfig) – Configuration for the transformer model.
pg_collection (ProcessGroupCollection, optional) – Process groups for MoE operations.
- cuda_dtoh_stream#
None
- preprocess(routing_map: torch.Tensor) torch.Tensor#
Preprocesses the token routing map for All-to-All communication and token permutation.
This method computes the number of tokens assigned to each expert based on the routing_map. It also initializes necessary data structures for All-to-All communication, such as input and output splits, and the mapping between global tokens and local experts. This method should not perform any DtoH data copies, for performance reasons; the necessary DtoH copies are made on self.cuda_dtoh_stream at self.cuda_dtoh_point.
- Parameters:
routing_map (torch.Tensor) – The mapping of tokens to experts.
- Returns:
A tensor with the number of tokens for each local expert.
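A minimal sketch of the kind of metadata this preprocess step computes, assuming a [num_tokens, num_experts] boolean routing_map and that experts are assigned to EP ranks in contiguous blocks of num_local_experts. This is an illustration, not the library's implementation.

```python
import torch

def compute_a2a_splits_sketch(routing_map: torch.Tensor, ep_size: int, num_local_experts: int):
    """Per-expert token counts and the per-rank input splits for all_to_all_single."""
    # routing_map: [num_tokens, num_experts] bool, num_experts == ep_size * num_local_experts
    tokens_per_expert = routing_map.sum(dim=0)                        # [num_experts]
    # Tokens this rank sends to each EP rank: sum over that rank's local experts.
    input_splits = tokens_per_expert.view(ep_size, num_local_experts).sum(dim=1)
    # Output splits (tokens received from each rank) require exchanging the per-rank
    # counts across the EP group, e.g. via an all-to-all on input_splits; omitted here.
    return tokens_per_expert, input_splits
```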
- dispatch_preprocess(
- hidden_states: torch.Tensor,
- routing_map: torch.Tensor,
- probs: torch.Tensor,
Prepares hidden states and probabilities for dispatch.
This method reshapes the hidden states, computes communication metadata, and permutes the tokens and probabilities before the All-to-All communication.
- Parameters:
hidden_states (torch.Tensor) – Input token embeddings.
routing_map (torch.Tensor) – The mapping of tokens to experts.
probs (torch.Tensor) – Routing probabilities.
- Returns:
A tuple of permuted hidden states and probabilities.
- token_dispatch(permutated_local_input_tokens, permuted_probs)#
Perform all-to-all communication for dispatching tokens.
This method performs the all-to-all communication step to dispatch tokens across expert parallel ranks. It synchronizes metadata at the appropriate point before performing the communication.
- Parameters:
permutated_local_input_tokens (torch.Tensor) – Pre-permuted input tokens.
permuted_probs (torch.Tensor) – Pre-permuted probabilities.
- Returns:
A tuple of tokens and probabilities after All-to-All.
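A hedged sketch of the dispatch communication itself, using torch.distributed.all_to_all_single with variable split sizes. The split arguments are assumed to be host-side Python lists produced during preprocess; the function name and shapes are illustrative.

```python
import torch
import torch.distributed as dist

def a2a_dispatch_sketch(permuted_tokens, input_splits, output_splits, ep_group):
    """Variable-size all-to-all over the expert-parallel group."""
    out_tokens = permuted_tokens.new_empty(
        (int(sum(output_splits)), permuted_tokens.shape[-1]))
    dist.all_to_all_single(
        out_tokens,
        permuted_tokens,
        output_split_sizes=list(output_splits),
        input_split_sizes=list(input_splits),
        group=ep_group,
    )
    return out_tokens
```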
- dispatch_postprocess(global_input_tokens, global_probs)#
Post-processes tokens after All-to-All communication.
This involves an All-Gather in the tensor parallel dimension and sorting tokens by expert if there are multiple local experts.
- Parameters:
global_input_tokens (torch.Tensor) – Tokens after All-to-All.
global_probs (torch.Tensor) – Probabilities after All-to-All.
- Returns:
A tuple of processed tokens, token counts per expert, and processed probabilities.
- combine_preprocess(hidden_states)#
Prepares hidden states for token combination after expert computations.
This may involve un-sorting tokens and a Reduce-Scatter in the tensor parallel dimension.
- token_combine(
- hidden_states: torch.Tensor,
- async_finish: bool = True,
- allocate_on_comm_stream: bool = True,
Combines expert outputs using All-to-All communication.
This method performs the inverse AlltoAll communication operation to collect expert outputs from their processing ranks and redistribute them back to the ranks that originally held the corresponding tokens. This completes the expert processing communication pattern and prepares tokens for final unpermutation.
- Parameters:
hidden_states (torch.Tensor) – Expert outputs ready for combination
async_finish (bool) – Whether to use asynchronous communication completion
allocate_on_comm_stream (bool) – Whether to allocate buffers on communication stream
- Returns:
Tokens after the All-to-All communication for combining.
- combine_postprocess(permutated_local_input_tokens)#
Finalizes token reconstruction with un-permutation and reshaping.
This method un-permutes the tokens back to their original order, reshapes the tensor to its original shape, and adds the shared expert output if enabled.
- Parameters:
permutated_local_input_tokens (torch.Tensor) – Permuted hidden states from token combine.
- Returns:
The final MoE layer output reshaped to its original dimensions.
- _maybe_update_cuda_sync_point(point: str)#
Update the CUDA sync point if the new point has a higher priority than the current one, i.e. the new point is reached earlier in execution than the current sync point.
- _maybe_dtoh_and_synchronize(
- point: str,
- tokens_per_expert: torch.Tensor = None,
Move all eligible GPU tensors to the CPU and perform a synchronization at the expected point.
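An illustrative sketch of the "earliest sync point wins" logic these two helpers imply; the point names and their ordering below are placeholders, not the dispatcher's internal values.

```python
# Illustrative only: real point names and ordering are internal to the dispatcher.
_SYNC_POINT_ORDER = ("before_permutation", "before_ep_alltoall", "before_finish", "no_sync")

def maybe_update_sync_point(current: str, new: str) -> str:
    """Keep whichever sync point occurs earlier in pipeline order (higher priority)."""
    if _SYNC_POINT_ORDER.index(new) < _SYNC_POINT_ORDER.index(current):
        return new
    return current
```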
- class core.transformer.moe.token_dispatcher._DispatchManager#
Bases: abc.ABC
A manager class to handle dispatch and combine processes for MoE models.
DispatcherManager handles token dispatching according to the routing_map of format [num_local_tokens, world_size, num_instances]. The routing_map is a 3D tensor where each element indicates whether a token should be sent to a specific rank.
num_instances is the maximum number of token instances dispatched to a target rank; it can be the number of local experts or the size of the sub_group.
- abstractmethod setup_metadata(routing_map: torch.Tensor, probs: torch.Tensor)#
Set up metadata of routing_map and probs.
- abstractmethod dispatch(hidden_states: torch.Tensor) torch.Tensor#
Dispatch the hidden_states according to the routing_map.
- abstractmethod combine(hidden_states: torch.Tensor) torch.Tensor#
Combine the hidden_states after expert processing.
- abstractmethod get_permuted_hidden_states_by_instances(
- hidden_states: torch.Tensor,
Get the permuted hidden states by instances.
- abstractmethod get_restored_hidden_states_by_instances(
- hidden_states: torch.Tensor,
Get the restored hidden states by instances.
- class core.transformer.moe.token_dispatcher._HybridEPManager(
- group: torch.distributed.ProcessGroup,
- num_local_experts: int,
- num_experts: int,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
Bases: core.transformer.moe.token_dispatcher._DispatchManager
A manager class to handle fused all-to-all communication processes for MoE models using the HybridEP backend. See https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep for more details.
The workflow of the HybridEP dispatcher is:
1. setup_metadata(): process the routing map and probabilities to prepare dispatch metadata
2. dispatch(): permute tokens for communication, perform all-to-all communication, and permute tokens for experts in a single step
3. combine(): unpermute tokens for communication, perform all-to-all communication, and unpermute tokens for attention in a single step
Initialization
Initialize the HybridEP dispatcher.
- Parameters:
group (torch.distributed.ProcessGroup) – The process group to use for communication. This should be the ETPxEP group.
num_local_experts (int) – The number of local experts.
num_experts (int) – The total number of experts in the group.
config (TransformerConfig) – The configuration for the transformer model.
- setup_metadata(routing_map: torch.Tensor, probs: torch.Tensor)#
- dispatch(
- hidden_states: torch.Tensor,
- async_finish: bool = True,
- allocate_on_comm_stream: bool = True,
- combine(
- hidden_states: torch.Tensor,
- async_finish: bool = True,
- allocate_on_comm_stream: bool = True,
- get_permuted_hidden_states_by_instances(
- hidden_states: torch.Tensor,
- get_restored_hidden_states_by_instances(
- hidden_states: torch.Tensor,
- get_number_of_tokens_per_expert() torch.Tensor#
Get the number of tokens per expert.
- class core.transformer.moe.token_dispatcher._DeepepManager(
- group: torch.distributed.ProcessGroup,
- num_local_experts: int,
- router_topk: int,
- num_experts: int,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
Bases: core.transformer.moe.token_dispatcher._DispatchManager
A manager class to handle fused all-to-all communication processes for MoE models using the DeepEP backend. See https://github.com/deepseek-ai/deepep for more details.
The workflow of the DeepEP dispatcher is:
1. setup_metadata(): process the routing map and probabilities to prepare dispatch metadata
2. dispatch(): use a fused kernel to permute tokens and perform all-to-all communication in a single step
3. get_permuted_hidden_states_by_instances(): convert the routing map and probabilities to multihot format and permute tokens using a fused kernel
4. get_restored_hidden_states_by_instances(): reverse the permutation using a fused kernel
5. combine(): reverse the process using a fused kernel to unpermute and perform all-to-all communication in a single step
This implementation uses fused communication kernels (fused_dispatch/fused_combine) that combine permutation and communication operations for improved efficiency compared to separate permute+alltoall steps.
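A hedged sketch of the call order listed in the workflow above. The `experts` callable is hypothetical, and the exact argument lists and return values of the manager methods may differ from this sketch.

```python
def deepep_dispatch_combine_sketch(comm_manager, hidden_states, routing_map, probs, experts):
    """Illustrative driver for the manager methods documented in this class."""
    comm_manager.setup_metadata(routing_map, probs)
    dispatched = comm_manager.dispatch(hidden_states)            # fused permute + all-to-all
    permuted = comm_manager.get_permuted_hidden_states_by_instances(dispatched)
    tokens_per_expert = comm_manager.get_number_of_tokens_per_expert()
    expert_output = experts(permuted, tokens_per_expert)         # hypothetical expert callable
    restored = comm_manager.get_restored_hidden_states_by_instances(expert_output)
    return comm_manager.combine(restored)                        # fused unpermute + all-to-all
```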
Initialization
Initialize the DeepEP dispatcher.
- Parameters:
group (torch.distributed.ProcessGroup) – The process group to use for communication. This should be the ETPxEP group.
num_local_experts (int) – The number of local experts.
router_topk (int) – The number of experts for each token to select.
num_experts (int) – The total number of experts in the group.
config (TransformerConfig) – The configuration for the transformer model.
- setup_metadata(routing_map: torch.Tensor, probs: torch.Tensor)#
- dispatch(
- hidden_states: torch.Tensor,
- async_finish: bool = False,
- allocate_on_comm_stream: bool = False,
- _indices_to_multihot(indices, probs)#
Converts a tensor of indices to a multihot vector.
- Parameters:
indices (torch.Tensor) – [num_tokens, topk] token indices, where -1 means masked out.
probs (torch.Tensor) – [num_tokens, topk] token probabilities.
- Returns:
A tuple of (routing_map, probs), where routing_map is the multihot vector and probs is the multihot probabilities.
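A minimal sketch of one way such an indices-to-multihot conversion can be expressed in plain PyTorch (the actual implementation may use a fused kernel). It assumes the indices refer to this rank's local experts; shapes and the function name are illustrative.

```python
import torch

def indices_to_multihot_sketch(indices: torch.Tensor, probs: torch.Tensor, num_local_experts: int):
    """Convert [num_tokens, topk] indices (-1 = masked) into a multihot routing map
    and probability tensor of shape [num_tokens, num_local_experts]."""
    num_tokens = indices.shape[0]
    routing_map = torch.zeros(num_tokens, num_local_experts,
                              dtype=torch.bool, device=indices.device)
    multihot_probs = torch.zeros(num_tokens, num_local_experts,
                                 dtype=probs.dtype, device=probs.device)
    mask = indices != -1                                    # valid (unmasked) slots
    row = torch.arange(num_tokens, device=indices.device).unsqueeze(1).expand_as(indices)
    routing_map[row[mask], indices[mask]] = True
    multihot_probs[row[mask], indices[mask]] = probs[mask]
    return routing_map, multihot_probs
```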
- get_number_of_tokens_per_expert() torch.Tensor#
Get the number of tokens per expert.
- combine(
- hidden_states: torch.Tensor,
- async_finish: bool = False,
- allocate_on_comm_stream: bool = False,
- _pad_routing_map(
- routing_map: torch.Tensor,
- tokens_per_expert: torch.Tensor,
Pad the routing map to the nearest multiple of the pad_multiple.
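A hedged sketch of one plausible padding scheme, assuming a [num_tokens, num_experts] boolean routing_map: additional, originally un-routed tokens are marked as routed until every expert's count is a multiple of pad_multiple. This is not necessarily the library's implementation.

```python
import torch

def pad_routing_map_sketch(routing_map: torch.Tensor, pad_multiple: int) -> torch.Tensor:
    """Return a routing map whose per-expert token counts are multiples of pad_multiple."""
    tokens_per_expert = routing_map.sum(dim=0)
    pad = (-tokens_per_expert) % pad_multiple              # padding slots needed per expert
    unrouted = ~routing_map
    # 1-based rank of each un-routed token within its expert column.
    rank = torch.cumsum(unrouted.int(), dim=0)
    extra = unrouted & (rank <= pad.unsqueeze(0))          # first pad[e] un-routed tokens
    return routing_map | extra
```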
- get_permuted_hidden_states_by_instances(
- hidden_states: torch.Tensor,
- get_restored_hidden_states_by_instances(
- hidden_states: torch.Tensor,
- class core.transformer.moe.token_dispatcher.MoEFlexTokenDispatcher(
- num_local_experts: int,
- local_expert_indices: List[int],
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
Bases: core.transformer.moe.token_dispatcher.MoETokenDispatcher
A flexible token dispatcher that abstracts the underlying tensor and expert parallelism. It uses a single communication group over all TP and EP ranks, making the dispatch logic independent of the specific parallelism strategy.
Initialization
Initialize the Flex token dispatcher.
- Parameters:
num_local_experts (int) – Number of local experts on the current device.
local_expert_indices (List[int]) – Indices of local experts on the current device.
config (TransformerConfig) – Configuration for the transformer model.
pg_collection (ProcessGroupCollection, optional) – Process groups for MoE operations.
- _initialize_metadata(
- routing_map: torch.Tensor,
- probs: torch.Tensor,
Initialize the routing map and probs to a unified format covering the TPxEP group. This design decouples the communication group from the underlying model parallelism groups, so that the token communication strategy can be agnostic to TP size and EP size.
This function expands the routing_map from shape [num_local_tokens, num_experts] to [num_local_tokens, world_size, num_local_experts]. Each element in the routing_map indicates whether a token should be sent to a specific rank. Specifically, the routing_map is replicated across the TP group, since each TP rank in a TP group should receive the same tokens.
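A minimal sketch of this expansion, assuming TP ranks are laid out contiguously within each EP rank's block of the TPxEP group (the real rank ordering may differ), and that experts are assigned to EP ranks in contiguous blocks of num_local_experts.

```python
import torch

def expand_routing_metadata_sketch(routing_map, probs, tp_size, ep_size, num_local_experts):
    """Expand [num_local_tokens, num_experts] metadata to
    [num_local_tokens, world_size, num_local_experts], replicated across TP ranks."""
    num_local_tokens = routing_map.shape[0]
    world_size = tp_size * ep_size
    # Group expert columns by the EP rank that owns them.
    routing_map = routing_map.view(num_local_tokens, ep_size, num_local_experts)
    probs = probs.view(num_local_tokens, ep_size, num_local_experts)
    # Replicate across TP ranks and flatten (EP, TP) into the world_size dimension.
    routing_map = (routing_map.unsqueeze(2)
                   .expand(-1, -1, tp_size, -1)
                   .reshape(num_local_tokens, world_size, num_local_experts))
    probs = (probs.unsqueeze(2)
             .expand(-1, -1, tp_size, -1)
             .reshape(num_local_tokens, world_size, num_local_experts))
    return routing_map, probs
```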
- dispatch_preprocess(
- hidden_states: torch.Tensor,
- routing_map: torch.Tensor,
- probs: torch.Tensor,
Initializes routing metadata and prepares tensors for fused dispatch.
This method reshapes input tensors and processes routing information into a unified format, where the routing map is expanded to cover the TPxEP communication domain, enabling the token dispatch logic to be agnostic to parallelism strategies.
- Parameters:
hidden_states (torch.Tensor) – Input hidden states to be processed
routing_map (torch.Tensor) – Map indicating which expert each token should be routed to
probs (torch.Tensor) – Routing probabilities for each token-expert pair
- Returns:
A tuple of reshaped hidden states and token probabilities.
- token_dispatch(
- hidden_states: torch.Tensor,
- probs: torch.Tensor = None,
- async_finish: bool = True,
- allocate_on_comm_stream: bool = True,
Execute fused permutation and AlltoAll communication.
This method currently leverages DeepEP’s fused dispatch kernel, which combines token permutation and AlltoAll communication into a single optimized operation. The fused approach reduces memory bandwidth requirements and enables better overlap between computation and communication operations.
- Parameters:
hidden_states (torch.Tensor) – Preprocessed hidden states to be dispatched
probs (torch.Tensor) – Routing probabilities (unused in current implementation)
async_finish (bool) – Whether to use asynchronous communication completion
allocate_on_comm_stream (bool) – Whether to allocate buffers on communication stream
- Returns:
A tuple of dispatched tokens and probabilities.
- dispatch_postprocess(hidden_states: torch.Tensor, probs: torch.Tensor)#
Converts dispatched tokens to a per-expert format for expert processing.
This method transforms the output of the fused dispatch into the tensor organization required for the expert computation.
- Parameters:
hidden_states (torch.Tensor) – Hidden states after fused dispatch
probs (torch.Tensor) – Routing probabilities after fused dispatch
- Returns:
A tuple of permuted tokens, token counts per expert, and permuted probabilities.
- combine_preprocess(hidden_states: torch.Tensor)#
Pre-processes hidden states before combining them after expert processing.
This method restores the hidden states to their original ordering before expert processing by using the communication manager’s restoration function.
- token_combine(
- hidden_states: torch.Tensor,
- async_finish: bool = True,
- allocate_on_comm_stream: bool = True,
Executes fused un-permutation and communication using DeepEP kernels.
This is the inverse of the token_dispatch operation.
- Parameters:
hidden_states (torch.Tensor) – Expert outputs ready for combination
async_finish (bool) – Whether to use asynchronous communication completion
allocate_on_comm_stream (bool) – Whether to allocate buffers on communication stream
- Returns:
Combined tokens after fused un-permutation and communication.
- combine_postprocess(hidden_states: torch.Tensor)#
Restores the original tensor shape and finalizes the MoE layer output.
This method performs the final step of the MoE token processing pipeline by reshaping the combined tokens back to their original input dimensions.
- Parameters:
hidden_states (torch.Tensor) – Combined tokens.
- Returns:
The final MoE layer output reshaped to its original dimensions.