nemo_automodel.components.moe.megatron.token_dispatcher#
Module Contents#
Classes#
- _DispatchManager – A manager class to handle dispatch and combine processes for MoE models.
- _DeepepManager – A manager class to handle fused all-to-all communication processes for MoE models using the DeepEP backend. See https://github.com/deepseek-ai/deepep for more details.
- MoEFlexTokenDispatcher – Flex token dispatcher using DeepEP.
Data#
- SHARING_DEEPEP_MANAGER
API#
- nemo_automodel.components.moe.megatron.token_dispatcher.SHARING_DEEPEP_MANAGER#
True
We use the following notation throughout this file:
- H: hidden size
- B: micro batch size
- S: sequence length
- TP: tensor model parallel size
- EP: expert model parallel size
- num_local_tokens: S/TP*B
- num_global_tokens: num_local_tokens*TP*EP
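As a concrete reading of this notation, with hypothetical sizes (the values below are purely illustrative):

```python
# Hypothetical sizes, chosen only to illustrate the notation above.
H = 4096   # hidden size
B = 2      # micro batch size
S = 8192   # sequence length
TP = 2     # tensor model parallel size
EP = 8     # expert model parallel size

num_local_tokens = S // TP * B                  # 8192 // 2 * 2 = 8192 tokens per rank
num_global_tokens = num_local_tokens * TP * EP  # 8192 * 2 * 8 = 131072 tokens across the TPxEP group
```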
- class nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager#
Bases:
abc.ABC
A manager class to handle dispatch and combine processes for MoE models.
_DispatchManager handles token dispatching according to the routing_map of shape [num_local_tokens, world_size, num_instances]. The routing_map is a 3D tensor where each element indicates whether a token should be sent to a specific rank.
num_instances is the maximum number of token instances dispatched to a target rank; it can be the number of local experts or the size of a sub_group.
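A minimal sketch of this routing_map layout, with hypothetical sizes and a hand-built map rather than one produced by a router:

```python
import torch

num_local_tokens = 4   # hypothetical
world_size = 2         # hypothetical number of target ranks
num_instances = 2      # e.g. the number of local experts per rank

# routing_map[t, r, i] == True means token t is sent to instance i on rank r.
routing_map = torch.zeros(num_local_tokens, world_size, num_instances, dtype=torch.bool)
routing_map[0, 0, 1] = True   # token 0 -> rank 0, local expert 1
routing_map[0, 1, 0] = True   # token 0 -> rank 1, local expert 0 (top-2 routing)
routing_map[1, 1, 1] = True   # token 1 -> rank 1, local expert 1
```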
- abstractmethod setup_metadata(routing_map: torch.Tensor, probs: torch.Tensor)#
Set up metadata of routing_map and probs.
- abstractmethod dispatch(hidden_states: torch.Tensor) torch.Tensor #
Dispatch the hidden_states according to the routing_map.
- abstractmethod combine(hidden_states: torch.Tensor) torch.Tensor #
Combine the hidden_states after expert processing.
- abstractmethod get_dispached_metadata() torch.Tensor #
Get the metadata of the dispatched hidden_states.
- abstractmethod get_permuted_hidden_states_by_instances(hidden_states: torch.Tensor) torch.Tensor #
Get the permuted hidden states by instances.
- abstractmethod get_restored_hidden_states_by_instances(hidden_states: torch.Tensor) torch.Tensor #
Get the restored hidden states by instances.
- class nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager(
- group: torch.distributed.ProcessGroup,
- router_topk: int,
- permute_fusion: bool = False,
- capacity_factor: Optional[float] = None,
- num_experts: Optional[int] = None,
- num_local_experts: Optional[int] = None,
- router_dtype: Optional[str] = None,
- moe_router_expert_pad_multiple: Optional[int] = None,
)#
Bases:
nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager
A manager class to handle fused all-to-all communication processes for MoE models using DeepEP backend. See https://github.com/deepseek-ai/deepep for more details.
The workflow of the DeepEP dispatcher is:
(1) setup_metadata(): Process the routing map and probabilities to prepare dispatch metadata.
(2) dispatch(): Use a fused kernel to permute tokens and perform all-to-all communication in a single step.
(3) get_permuted_hidden_states_by_instances(): Convert the routing map and probabilities to multihot format and permute tokens using a fused kernel.
(4) get_restored_hidden_states_by_instances(): Reverse the permutation using a fused kernel.
(5) combine(): Reverse the dispatch using a fused kernel to unpermute tokens and perform all-to-all in a single step.
This implementation uses fused communication kernels (fused_dispatch/fused_combine) that combine permutation and communication operations for improved efficiency compared to separate permute+alltoall steps.
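A sketch of the intended call order against the methods documented for this class; it assumes an already-constructed _DeepepManager inside a running distributed job, so it is illustrative rather than directly runnable:

```python
import torch

def deepep_forward(manager, hidden_states: torch.Tensor,
                   num_local_tokens: int, probs: torch.Tensor) -> torch.Tensor:
    """Illustrative call order for a _DeepepManager (steps (1)-(5) above)."""
    # (1) Prepare dispatch metadata from the routing probabilities.
    manager.setup_metadata(num_local_tokens, probs)
    # (2) Fused permute + all-to-all: tokens travel to their expert ranks.
    dispatched = manager.dispatch(hidden_states)
    # (3) Group the received tokens by local expert instance.
    permuted = manager.get_permuted_hidden_states_by_instances(dispatched)
    expert_output = permuted  # the expert MLPs would run here
    # (4) Undo the per-instance permutation.
    restored = manager.get_restored_hidden_states_by_instances(expert_output)
    # (5) Fused unpermute + all-to-all: tokens return to their source ranks.
    return manager.combine(restored)
```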
Initialization
- setup_metadata(num_local_tokens: int, probs: torch.Tensor)#
Process routing map and probabilities to prepare dispatch metadata
- dispatch(
- hidden_states: torch.Tensor,
- async_finish: bool = False,
- allocate_on_comm_stream: bool = False,
)#
Dispatch the hidden_states.
- _indices_to_multihot(indices, probs)#
Converts a tensor of indices to a multihot vector.
- Parameters:
indices (torch.Tensor) – [num_tokens, topk] token indices, where -1 means masked out.
probs (torch.Tensor) – [num_tokens, topk] token probabilities.
- Returns:
routing_map: Multihot vector.
probs: Multihot probabilities.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
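A dense reference version of this conversion (a hypothetical helper; the actual method may rely on a fused kernel, so this only sketches the semantics):

```python
import torch

def indices_to_multihot(indices: torch.Tensor, probs: torch.Tensor, num_instances: int):
    """Reference conversion: [num_tokens, topk] indices (-1 = masked) -> multihot map/probs."""
    num_tokens = indices.shape[0]
    routing_map = torch.zeros(num_tokens, num_instances, dtype=torch.bool)
    multihot_probs = torch.zeros(num_tokens, num_instances, dtype=probs.dtype)
    valid = indices != -1
    token_ids = torch.arange(num_tokens).unsqueeze(1).expand_as(indices)
    routing_map[token_ids[valid], indices[valid]] = True
    multihot_probs[token_ids[valid], indices[valid]] = probs[valid]
    return routing_map, multihot_probs

indices = torch.tensor([[0, 2], [1, -1]])       # token 0 -> experts 0 and 2; token 1 -> expert 1
probs = torch.tensor([[0.6, 0.4], [1.0, 0.0]])
routing_map, multihot_probs = indices_to_multihot(indices, probs, num_instances=4)
```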
- get_dispached_metadata() torch.Tensor #
- get_number_of_tokens_per_expert() torch.Tensor #
Get the number of tokens per expert.
- combine(
- hidden_states: torch.Tensor,
- async_finish: bool = False,
- allocate_on_comm_stream: bool = False,
)#
Reverse the dispatch using a fused kernel to unpermute tokens and perform all-to-all communication in a single step.
- get_permuted_hidden_states_by_instances(hidden_states: torch.Tensor)#
Convert the routing map and probabilities to multihot format and permute tokens using a fused kernel.
- get_restored_hidden_states_by_instances(hidden_states: torch.Tensor)#
Restore the hidden states to their original ordering before expert processing.
- class nemo_automodel.components.moe.megatron.token_dispatcher.MoEConfig#
- moe_enable_deepep: bool#
True
Enable DeepEP for efficient token dispatching and combine in MoE models.
- moe_permute_fusion: bool#
False
Fuse token rearrangement ops during token dispatching.
- moe_expert_capacity_factor: Optional[float]#
None
The capacity factor for each expert; None means no token will be dropped. The default is None.
- moe_router_topk: int#
2
Number of experts to route to for each token.
- moe_router_expert_pad_multiple: Optional[int]#
None
Pad the number of tokens routed to each expert to a multiple of this value.
- num_moe_experts: int#
64
Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None for no MoE.
- moe_router_dtype: str#
'fp32'
Data type for routing and expert output weighted averaging. Using fp32 or fp64 can improve stability, especially when the number of experts is large (e.g. fine-grained MoE). None means the dtype is left unchanged.
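A configuration sketch using the fields documented above; the values are illustrative, and passing them as keyword arguments assumes MoEConfig is a dataclass-style config:

```python
from nemo_automodel.components.moe.megatron.token_dispatcher import MoEConfig

# Illustrative values; field names match the attributes documented above.
config = MoEConfig(
    moe_enable_deepep=True,
    moe_permute_fusion=True,
    moe_expert_capacity_factor=None,    # None: no tokens are dropped
    moe_router_topk=2,
    moe_router_expert_pad_multiple=None,
    num_moe_experts=64,
    moe_router_dtype="fp32",
)
```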
- class nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher(
- num_local_experts: int,
- local_expert_indices: List[int],
- config: nemo_automodel.components.moe.megatron.token_dispatcher.MoEConfig,
- ep_group: torch.distributed.ProcessGroup,
)#
Flex token dispatcher using DeepEP.
Initialization
Initialize the Flex token dispatcher.
- Parameters:
num_local_experts (int) – Number of local experts on the current device.
local_expert_indices (List[int]) – Indices of local experts on the current device.
config (MoEConfig) – Configuration for the transformer model.
ep_group (torch.distributed.ProcessGroup) – Process group for MoE operations.
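A construction sketch, assuming torch.distributed has already been initialized; the process group, expert counts, and config values below are placeholders:

```python
import torch.distributed as dist

from nemo_automodel.components.moe.megatron.token_dispatcher import (
    MoEConfig,
    MoEFlexTokenDispatcher,
)

# Assumes dist.init_process_group(...) has already run; using the world group
# as the expert-parallel group is a placeholder choice.
ep_group = dist.group.WORLD
config = MoEConfig(moe_enable_deepep=True, moe_router_topk=2, num_moe_experts=64)

dispatcher = MoEFlexTokenDispatcher(
    num_local_experts=2,          # placeholder: experts hosted on this rank
    local_expert_indices=[0, 1],  # placeholder: their global indices
    config=config,
    ep_group=ep_group,
)
```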
- _initialize_metadata(
- num_local_tokens: int,
- probs: torch.Tensor,
)#
Initialize the routing map and probs to a unified format covering the TPxEP group. This design decouples the communication group from underlying model parallelism groups, such that the communication strategy of tokens can be agnostic of TP size and EP size.
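A toy illustration of the decoupling: however TP and EP are sized, a global expert index resolves to a (rank, local expert) slot inside the communication group (num_local_experts below is a hypothetical value):

```python
# Hypothetical layout: 2 experts hosted per expert-parallel rank.
num_local_experts = 2

def expert_slot(global_expert_id: int) -> tuple[int, int]:
    """Map a global expert index to (ep_rank, local_expert_index)."""
    return divmod(global_expert_id, num_local_experts)

assert expert_slot(5) == (2, 1)  # global expert 5 is local expert 1 on EP rank 2
```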
- dispatch_preprocess2(
- hidden_states: torch.Tensor,
- num_local_tokens: int,
- token_probs: torch.Tensor,
- token_indices: torch.Tensor,
)#
Preprocesses the hidden states and routing information before dispatching tokens to experts.
- dispatch_preprocess(
- hidden_states: torch.Tensor,
- num_local_tokens: int,
- probs: torch.Tensor,
)#
Preprocesses the hidden states and routing information before dispatching tokens to experts.
- Parameters:
hidden_states (torch.Tensor) – Input hidden states to be processed.
num_local_tokens (int) – Number of tokens to be processed.
probs (torch.Tensor) – Routing probabilities for each token-expert pair.
- Returns:
A tuple containing:
torch.Tensor: Reshaped hidden states.
torch.Tensor: Token probabilities from the communication manager.
None: Placeholder for compatibility.
- dispatch_all_to_all(
- hidden_states: torch.Tensor,
- probs: torch.Tensor = None,
- async_finish: bool = True,
- allocate_on_comm_stream: bool = True,
)#
Performs all-to-all communication to dispatch tokens across expert parallel ranks.
- dispatch_postprocess(hidden_states: torch.Tensor)#
Post-processes the dispatched hidden states after all-to-all communication.
This method retrieves the permuted hidden states by experts, calculates the number of tokens per expert, and returns the processed data ready for expert processing.
- token_permutation(
- hidden_states: torch.Tensor,
- num_local_tokens: int,
- probs: torch.Tensor,
)#
Permutes tokens according to probs and dispatches them to experts.
This method implements the token permutation process in three steps (composed in the sketch after this list):
1. Preprocess the hidden states.
2. Perform all-to-all communication to dispatch tokens.
3. Post-process the dispatched tokens for expert processing.
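A sketch of how the three steps compose out of the lower-level methods documented above; the intermediate return values (in particular what dispatch_all_to_all returns) are not spelled out on this page, so the unpacking below is an assumption:

```python
import torch

def token_permutation_by_steps(dispatcher, hidden_states: torch.Tensor,
                               num_local_tokens: int, probs: torch.Tensor):
    """Illustrative decomposition of token_permutation into its three steps."""
    # 1. Reshape the hidden states and hand the routing info to the comm manager.
    hidden_states, token_probs, _ = dispatcher.dispatch_preprocess(
        hidden_states, num_local_tokens, probs
    )
    # 2. Fused all-to-all dispatch across expert-parallel ranks
    #    (assumed to return the dispatched hidden states and probabilities).
    hidden_states, token_probs = dispatcher.dispatch_all_to_all(hidden_states, token_probs)
    # 3. Permute the received tokens by local expert for expert processing.
    return dispatcher.dispatch_postprocess(hidden_states)
```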
- token_permutation2(
- hidden_states: torch.Tensor,
- num_local_tokens: int,
- token_probs: torch.Tensor,
- token_indices: torch.Tensor,
)#
Permutes tokens according to probs and dispatches them to experts.
This method implements the token permutation process in three steps:
1. Preprocess the hidden states.
2. Perform all-to-all communication to dispatch tokens.
3. Post-process the dispatched tokens for expert processing.
- combine_preprocess(hidden_states: torch.Tensor)#
Pre-processes the hidden states before combining them after expert processing.
This method restores the hidden states to their original ordering before expert processing by using the communication manager's restoration function.
- combine_all_to_all(
- hidden_states: torch.Tensor,
- async_finish: bool = True,
- allocate_on_comm_stream: bool = True,
)#
Performs all-to-all communication to combine tokens after expert processing.
- combine_postprocess(hidden_states: torch.Tensor)#
Post-processes the combined hidden states after all-to-all communication.
This method reshapes the combined hidden states to match the original input shape.
- token_unpermutation(hidden_states: torch.Tensor) torch.Tensor #
Reverses the token permutation process to restore the original token order.
This method implements the token unpermutation process in three steps (composed in the sketch after this list):
1. Pre-process the hidden states to restore their original ordering.
2. Perform all-to-all communication to combine tokens.
3. Post-process the combined tokens to match the original input shape.
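A matching sketch for the reverse path, again assuming each step returns the (partially) combined hidden states:

```python
import torch

def token_unpermutation_by_steps(dispatcher, hidden_states: torch.Tensor) -> torch.Tensor:
    """Illustrative decomposition of token_unpermutation into its three steps."""
    # 1. Restore the expert outputs to their pre-permutation ordering on this rank.
    hidden_states = dispatcher.combine_preprocess(hidden_states)
    # 2. Fused all-to-all combine back to the tokens' source ranks.
    hidden_states = dispatcher.combine_all_to_all(hidden_states)
    # 3. Reshape to the original input shape.
    return dispatcher.combine_postprocess(hidden_states)
```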