nemo_automodel.components.moe.megatron.token_dispatcher
Module Contents
Classes
API
Flex token dispatcher supporting DeepEP, HybridEP, and UCCL-EP backends.
Initialize the routing map and probs to a unified format covering the TPxEP group. This design decouples the communication group from the underlying model-parallel groups, so the token communication strategy is independent of the TP and EP sizes.
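A minimal, framework-free sketch of what a "unified TPxEP" routing map could look like. It assumes (not confirmed by this page) that global experts are laid out contiguously as ep_size * num_local_experts, and that a token routed to an expert on a given EP rank is sent to every TP rank in that EP rank's group, with columns ordered as (ep_rank, tp_rank, local_expert). The function name expand_to_tp_ep is hypothetical.

```python
from typing import List


def expand_to_tp_ep(
    routing_map: List[List[bool]],
    ep_size: int,
    tp_size: int,
    num_local_experts: int,
) -> List[List[bool]]:
    """Expand [num_tokens, ep_size * num_local_experts] to
    [num_tokens, ep_size * tp_size * num_local_experts].

    Hypothetical helper: assumes columns are ordered (ep_rank, tp_rank, local_expert).
    """
    expanded = []
    for row in routing_map:
        if len(row) != ep_size * num_local_experts:
            raise ValueError("row width must equal ep_size * num_local_experts")
        new_row: List[bool] = []
        for ep_rank in range(ep_size):
            # This token's routing decisions for the experts owned by one EP rank.
            block = row[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts]
            # Replicate that slice once per TP rank in the EP rank's group.
            for _ in range(tp_size):
                new_row.extend(block)
        expanded.append(new_row)
    return expanded


# One token routed to global experts 0 and 3 (ep_size=2, num_local_experts=2, tp_size=2).
print(expand_to_tp_ep([[True, False, False, True]], ep_size=2, tp_size=2, num_local_experts=2))
```

Because the expanded map is indexed by destination rank in the combined group, the communication code only needs the total group size, not TP and EP separately.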
Performs all-to-all communication to combine tokens after expert processing.
Post-processes the combined hidden states after all-to-all communication.
This method reshapes the combined hidden states to match the original input shape.
Pre-processes the hidden states before combining them after expert processing.
This method restores the hidden states to their original ordering before expert processing by using the communication manager’s restoration function.
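The restoration step can be pictured as a scatter-add: each expert output row is multiplied by its routing probability and accumulated into the position of the token it came from. The sketch below is a plain-Python illustration under that assumption; restore_order, source_token, and the probability weighting at this stage are hypothetical and may differ from what the communication manager's restoration function actually does.

```python
from typing import List, Sequence


def restore_order(
    expert_outputs: Sequence[Sequence[float]],
    source_token: Sequence[int],
    probs: Sequence[float],
    num_tokens: int,
) -> List[List[float]]:
    """Scatter-add probability-weighted expert outputs back into token order."""
    hidden = len(expert_outputs[0]) if expert_outputs else 0
    restored = [[0.0] * hidden for _ in range(num_tokens)]
    for row, tok, p in zip(expert_outputs, source_token, probs):
        # A token routed to k experts receives k weighted contributions.
        for j, value in enumerate(row):
            restored[tok][j] += p * value
    return restored


# Rows 0 and 2 came from token 1 (two experts); row 1 came from token 0.
print(restore_order([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [1, 0, 1], [0.5, 1.0, 0.5], 2))
```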
Performs all-to-all communication to dispatch tokens across expert parallel ranks.
Post-processes the dispatched hidden states after all-to-all communication.
This method retrieves the permuted hidden states by experts, calculates the number of tokens per expert, and returns the processed data ready for expert processing.
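A plain-Python sketch of "permuted by experts" plus the per-expert token counts, assuming a boolean [num_tokens, num_experts] routing map. Rows are grouped expert-major (all tokens for expert 0, then expert 1, and so on), which is the layout a grouped expert GEMM typically expects. The function name permute_by_expert is hypothetical.

```python
from typing import List, Sequence, Tuple


def permute_by_expert(
    hidden_states: Sequence[Sequence[float]],
    routing_map: Sequence[Sequence[bool]],
) -> Tuple[List[List[float]], List[int], List[int]]:
    """Return (permuted rows, tokens_per_expert, source token index per row)."""
    num_experts = len(routing_map[0]) if routing_map else 0
    permuted: List[List[float]] = []
    source_token: List[int] = []
    tokens_per_expert = [0] * num_experts
    # Expert-major loop: a token routed to several experts is copied once per expert.
    for e in range(num_experts):
        for t, row in enumerate(routing_map):
            if row[e]:
                permuted.append(list(hidden_states[t]))
                source_token.append(t)
                tokens_per_expert[e] += 1
    return permuted, tokens_per_expert, source_token


rows, counts, src = permute_by_expert(
    [[0.0], [1.0], [2.0]],
    [[True, False], [True, True], [False, True]],
)
print(rows, counts, src)
```

The source_token list is what a later restoration step needs to put results back in the original order.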
Preprocesses the hidden states and routing information before dispatching tokens to experts.
Args:
hidden_states (torch.Tensor): Input hidden states to be processed.
num_local_tokens (int): Number of tokens to be processed.
probs (torch.Tensor): Routing probabilities for each token-expert pair.
Returns:
Tuple containing:
Preprocesses the hidden states and routing information before dispatching tokens to experts.
For the DeepEP backend: uses token_indices and token_probs directly. For the HybridEP backend: converts token_indices to a routing_map (multihot format).
Permutes tokens according to probs and dispatches them to experts.
This method implements the token permutation process in three steps:
- Preprocess the hidden states
- Perform all-to-all communication to dispatch tokens
- Post-process the dispatched tokens for expert processing
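The three dispatch steps can be simulated in a single process to show how data moves. This sketch assumes top-1 routing for brevity and that expert e lives on rank e // num_local_experts; simulate_all_to_all stands in for a real collective (such as a DeepEP or torch.distributed all-to-all) and, like dispatch, is a hypothetical name rather than part of this module.

```python
from typing import Any, List, Sequence, Tuple


def simulate_all_to_all(send: Sequence[Sequence[List[Any]]]) -> List[List[Any]]:
    """send[src][dst] holds the items rank `src` sends to rank `dst`.

    Returns recv, where recv[dst] concatenates items from every source rank
    in source-rank order (a single-process stand-in for all-to-all).
    """
    world_size = len(send)
    return [
        [item for src in range(world_size) for item in send[src][dst]]
        for dst in range(world_size)
    ]


def dispatch(
    local_tokens: Sequence[Sequence[str]],
    local_experts: Sequence[Sequence[int]],
    num_local_experts: int,
) -> Tuple[List[List[str]], List[List[int]]]:
    """Route each token (top-1) to the rank that owns its expert."""
    world_size = len(local_tokens)

    # Step 1 (preprocess): bucket each rank's tokens by destination rank,
    # remembering which local expert on that rank the token targets.
    send = []
    for tokens, experts in zip(local_tokens, local_experts):
        buckets: List[List[Tuple[str, int]]] = [[] for _ in range(world_size)]
        for tok, expert in zip(tokens, experts):
            dst_rank, local_expert = divmod(expert, num_local_experts)
            buckets[dst_rank].append((tok, local_expert))
        send.append(buckets)

    # Step 2 (communication): exchange the buckets between all ranks.
    recv = simulate_all_to_all(send)

    # Step 3 (postprocess): group received tokens by local expert and count them.
    dispatched, tokens_per_expert = [], []
    for items in recv:
        items = sorted(items, key=lambda pair: pair[1])  # stable: keeps source order
        dispatched.append([tok for tok, _ in items])
        counts = [0] * num_local_experts
        for _, local_expert in items:
            counts[local_expert] += 1
        tokens_per_expert.append(counts)
    return dispatched, tokens_per_expert


# Two ranks, two local experts each (four global experts).
tokens, counts = dispatch([["a", "b", "c"], ["d", "e"]], [[3, 0, 1], [2, 0]], num_local_experts=2)
# tokens == [["b", "e", "c"], ["d", "a"]], counts == [[2, 1], [1, 1]]
```

The fused DeepEP and HybridEP kernels described on this page merge some of these steps; the sketch keeps them separate only to make each stage visible.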
Reverses the token permutation process to restore the original token order.
This method implements the token unpermutation process in three steps:
- Pre-process the hidden states to restore their original ordering
- Perform all-to-all communication to combine tokens
- Post-process the combined tokens to match the original input shape
Configuration for MoE token dispatch and combine backends.
Use asynchronous DeepEP/UCCL-EP dispatch and allocate dispatched tensors on the communication stream.
Number of SMs to use for the DeepEP backend.
Enable DeepEP for efficient token dispatch and combine in MoE models.
moe_expert_capacity_factor (float): The capacity factor for each expert. None means no tokens are dropped. The default is None.
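To make the capacity factor concrete, a commonly used formula sets each expert's capacity to ceil(num_tokens * topk / num_experts * capacity_factor) and drops tokens beyond it. Whether this module uses exactly this formula is an assumption; the helpers expert_capacity and tokens_kept are hypothetical.

```python
import math
from typing import Optional, Sequence


def expert_capacity(
    num_tokens: int, topk: int, num_experts: int, capacity_factor: Optional[float]
) -> Optional[int]:
    """Per-expert token budget; None means unlimited (no tokens dropped)."""
    if capacity_factor is None:
        return None
    # Average assignments per expert, scaled by the capacity factor.
    return math.ceil(num_tokens * topk / num_experts * capacity_factor)


def tokens_kept(tokens_per_expert: Sequence[int], capacity: Optional[int]) -> int:
    """Token-expert assignments that survive after dropping overflow."""
    if capacity is None:
        return sum(tokens_per_expert)
    return sum(min(n, capacity) for n in tokens_per_expert)


cap = expert_capacity(num_tokens=1024, topk=2, num_experts=8, capacity_factor=1.25)
print(cap)                                # 320
print(tokens_kept([400, 100, 320], cap))  # 740: 80 assignments to the first expert are dropped
```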
Backend for the flex token dispatcher. Options: ‘deepep’, ‘hybridep’, or ‘uccl_ep’.
Number of SMs to use for HybridEP dispatch and combine APIs.
Fuse token rearrangement ops during token dispatching.
Data type used for routing and for the weighted averaging of expert outputs. Using fp32 or fp64 can improve numerical stability, especially when the number of experts is large (e.g., fine-grained MoE). None leaves the dtype unchanged.
Pad the number of tokens assigned to each expert up to a multiple of this value.
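The padding rule is simple integer arithmetic: each per-expert count is rounded up to the next multiple of the configured value, which keeps per-expert GEMM shapes aligned. The sketch below assumes an expert with zero tokens stays at zero; the real implementation may treat empty experts differently, and pad_tokens_per_expert is a hypothetical name.

```python
from typing import List, Sequence


def pad_tokens_per_expert(tokens_per_expert: Sequence[int], multiple: int) -> List[int]:
    """Round each expert's token count up to a multiple of `multiple`."""
    if multiple <= 1:
        return list(tokens_per_expert)  # nothing to align
    # Ceiling division, then scale back up.
    return [((n + multiple - 1) // multiple) * multiple for n in tokens_per_expert]


print(pad_tokens_per_expert([5, 16, 0, 17], 16))  # [16, 16, 0, 32]
```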
Number of experts to route to for each token.
Share one communication manager instance across MoE layers for the configured backend.
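Sharing one manager per backend amounts to caching the instance the first time it is built and handing the same object to every MoE layer. This is a minimal sketch of that pattern; the cache, get_comm_manager, and its parameters are hypothetical and do not reflect how this module stores its managers.

```python
from typing import Callable, Dict

# Hypothetical process-wide cache keyed by backend name.
_SHARED_MANAGERS: Dict[str, object] = {}


def get_comm_manager(backend: str, factory: Callable[[], object], shared: bool) -> object:
    """Return a manager for `backend`, reusing a cached instance when `shared`."""
    if not shared:
        return factory()  # each MoE layer gets its own instance
    if backend not in _SHARED_MANAGERS:
        _SHARED_MANAGERS[backend] = factory()
    return _SHARED_MANAGERS[backend]


layer0 = get_comm_manager("deepep", object, shared=True)
layer1 = get_comm_manager("deepep", object, shared=True)
print(layer0 is layer1)  # True: both layers use the same instance
```

A shared instance avoids allocating communication buffers once per layer, at the cost of layers no longer being independent of each other.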
Number of experts in the MoE layer. When set, the MLP is replaced with an MoE layer. Set to None to disable MoE.
Bases: _DispatchManager
A manager class to handle fused all-to-all communication processes for MoE models using the DeepEP backend. See https://github.com/deepseek-ai/deepep for more details.
The workflow of the DeepEP dispatcher is:
(1) setup_metadata(): Process routing map and probabilities to prepare dispatch metadata.
(2) dispatch():
- Use a fused kernel to permute tokens and perform all-to-all communication in a single step.
(3) get_permuted_hidden_states_by_instances():
- Convert routing map and probabilities to multihot format.
- Permute tokens using a fused kernel.
(4) get_restored_hidden_states_by_instances():
- Reverse the permutation using a fused kernel.
(5) combine():
- Reverse the process using a fused kernel to unpermute and perform all-to-all communication in a single step.
This implementation uses fused communication kernels (fused_dispatch/fused_combine) that combine permutation and communication operations for improved efficiency compared to separate permute+alltoall steps.
Converts a tensor of indices to a multihot vector.
Parameters:
[num_tokens, topk] token indices, where -1 means masked out.
[num_tokens, topk] token probabilities.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- routing_map: Multihot vector.
- probs: Multihot probabilities.
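A plain-Python sketch of the indices-to-multihot conversion described above: each [topk] row of indices becomes a boolean row over experts, and the matching probabilities are scattered into a dense row, with -1 slots skipped. The num_experts argument (the number of output columns, for example the number of local experts) and the function name are assumptions for illustration.

```python
from typing import List, Sequence, Tuple


def indices_to_multihot(
    indices: Sequence[Sequence[int]],
    probs: Sequence[Sequence[float]],
    num_experts: int,
) -> Tuple[List[List[bool]], List[List[float]]]:
    """Convert [num_tokens, topk] indices/probs to [num_tokens, num_experts]."""
    routing_map: List[List[bool]] = []
    multihot_probs: List[List[float]] = []
    for idx_row, prob_row in zip(indices, probs):
        mask = [False] * num_experts
        dense = [0.0] * num_experts
        for expert, prob in zip(idx_row, prob_row):
            if expert == -1:  # -1 marks a masked-out slot
                continue
            mask[expert] = True
            dense[expert] = prob
        routing_map.append(mask)
        multihot_probs.append(dense)
    return routing_map, multihot_probs


rmap, p = indices_to_multihot([[0, 2], [1, -1]], [[0.7, 0.3], [1.0, 0.5]], num_experts=3)
print(rmap)  # [[True, False, True], [False, True, False]]
print(p)     # [[0.7, 0.0, 0.3], [0.0, 1.0, 0.0]]
```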
Reverse the process using a fused kernel to unpermute and perform all-to-all communication in a single step.
Dispatch the hidden_states
Get the number of tokens per expert.
- Convert routing map and probabilities to multihot format
- Permute tokens using fused kernel
Restore the hidden states to their original ordering before expert processing
Process routing map and probabilities to prepare dispatch metadata
A manager class to handle dispatch and combine processes for MoE models.
DispatcherManager handles token dispatching according to the routing_map of format [num_local_tokens, world_size, num_instances]. The routing_map is a 3D tensor where each element indicates whether a token should be sent to a specific rank.
num_instances is the maximum number of token instances dispatched to a target rank; it can be the number of local experts or the size of the sub_group.
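As an illustration of the [num_local_tokens, world_size, num_instances] layout, the sketch below takes num_instances to be the number of local experts and assumes expert e lives on rank e // num_local_experts (a contiguous layout, not confirmed here). Both helper names are hypothetical.

```python
from typing import List, Sequence


def build_rank_routing_map(
    topk_experts: Sequence[Sequence[int]],
    world_size: int,
    num_local_experts: int,
) -> List[List[List[bool]]]:
    """routing[t][rank][i] is True if token t goes to local expert i on `rank`."""
    routing = [
        [[False] * num_local_experts for _ in range(world_size)]
        for _ in topk_experts
    ]
    for t, experts in enumerate(topk_experts):
        for e in experts:
            if e < 0:  # skip masked-out slots
                continue
            rank, local = divmod(e, num_local_experts)
            routing[t][rank][local] = True
    return routing


def instances_per_rank(routing: List[List[List[bool]]]) -> List[int]:
    """Count token instances each rank receives (one per selected instance)."""
    world_size = len(routing[0]) if routing else 0
    return [sum(sum(tok[r]) for tok in routing) for r in range(world_size)]


rmap = build_rank_routing_map([[0, 3], [1, 2]], world_size=2, num_local_experts=2)
print(rmap)                      # [[[True, False], [False, True]], [[False, True], [True, False]]]
print(instances_per_rank(rmap))  # [2, 2]
```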
Combine the hidden_states after expert processing.
Dispatch the hidden_states according to the routing_map.
Get the metadata of the dispatched hidden_states.
Get the permuted hidden states by instances.
Get the restored hidden states by instances.
Set up metadata of routing_map and probs.
Bases: _DispatchManager
A manager class to handle fused all-to-all communication processes for MoE models using the HybridEP backend. See https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep for more details.
The workflow of the HybridEP dispatcher is:
(1) setup_metadata(): Process routing map and probabilities to prepare dispatch metadata.
(2) dispatch():
- Permute tokens for communication, perform all-to-all communication, and permute tokens for experts in a single step.
(3) combine():
- Unpermute tokens for communication, perform all-to-all communication, and unpermute tokens for attention in a single step.
Converts a tensor of indices to a multihot vector.
Process routing map and probabilities to prepare dispatch metadata.
Convert from topk indices format to multihot routing_map format.