nemo_automodel.components.moe.megatron.token_dispatcher


Module Contents

Classes

Name | Description
MoEFlexTokenDispatcher | Flex token dispatcher supporting DeepEP, HybridEP, and UCCL-EP backends.
TokenDispatcherConfig | Configuration for MoE token dispatch and combine backends.
_DeepepManager | A manager class to handle fused all-to-all communication processes for MoE models using the DeepEP backend.
_DispatchManager | A manager class to handle dispatch and combine processes for MoE models.
_HybridEPManager | A manager class to handle fused all-to-all communication processes for MoE models using the HybridEP backend.

API

class nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher(
num_local_experts: int,
local_expert_indices: typing.List[int],
config: nemo_automodel.components.moe.megatron.token_dispatcher.TokenDispatcherConfig,
ep_group: torch.distributed.ProcessGroup
)

Flex token dispatcher supporting DeepEP, HybridEP, and UCCL-EP backends.

_comm_manager
= MoEFlexTokenDispatcher.shared_uccl_manager
ep_size
= ep_group.size()
shared_deepep_manager
_DeepepManager = None
shared_hybridep_manager
_HybridEPManager = None
shared_uccl_manager
_DeepepManager = None
tp_size
= 1
nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher._initialize_metadata(
num_local_tokens: int,
probs: torch.Tensor
) -> torch.Tensor

Initialize the routing map and probs to a unified format covering the TPxEP group. This design decouples the communication group from underlying model parallelism groups, such that the communication strategy of tokens can be agnostic of TP size and EP size.

nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher.combine_all_to_all(
hidden_states: torch.Tensor,
async_finish: bool = True,
allocate_on_comm_stream: bool = True
)

Performs all-to-all communication to combine tokens after expert processing.

nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher.combine_postprocess(
hidden_states: torch.Tensor
)

Post-processes the combined hidden states after all-to-all communication.

This method reshapes the combined hidden states to match the original input shape.

nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher.combine_preprocess(
hidden_states: torch.Tensor
)

Pre-processes the hidden states before combining them after expert processing.

This method restores the hidden states to their original ordering before expert processing by using the communication manager’s restoration function.

nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher.dispatch_all_to_all(
hidden_states: torch.Tensor,
probs: torch.Tensor = None,
async_finish: bool = True,
allocate_on_comm_stream: bool = True
)

Performs all-to-all communication to dispatch tokens across expert parallel ranks.

nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher.dispatch_postprocess(
hidden_states: torch.Tensor
)

Post-processes the dispatched hidden states after all-to-all communication.

This method retrieves the permuted hidden states by experts, calculates the number of tokens per expert, and returns the processed data ready for expert processing.

nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher.dispatch_preprocess(
hidden_states: torch.Tensor,
num_local_tokens: int,
probs: torch.Tensor
)

Preprocesses the hidden states and routing information before dispatching tokens to experts.

Parameters:

hidden_states
torch.Tensor

Input hidden states to be processed.

num_local_tokens
int

Number of tokens to be processed.

probs
torch.Tensor

Routing probabilities for each token-expert pair.

Returns:

Tuple containing:

nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher.dispatch_preprocess2(
hidden_states: torch.Tensor,
num_local_tokens: int,
token_probs: torch.Tensor,
token_indices: torch.Tensor
)

Preprocesses the hidden states and routing information before dispatching tokens to experts.

For DeepEP backend: uses token_indices and token_probs directly. For HybridEP backend: converts token_indices to routing_map (multihot format).

nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher.token_permutation(
hidden_states: torch.Tensor,
num_local_tokens: int,
probs: torch.Tensor
) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Permutes tokens according to probs and dispatches them to experts.

This method implements the token permutation process in three steps:

  1. Preprocess the hidden states
  2. Perform all-to-all communication to dispatch tokens
  3. Post-process the dispatched tokens for expert processing
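The three steps can be illustrated with a minimal single-rank sketch in which the all-to-all is the identity. It assumes top-k routing expressed as `[num_tokens, topk]` expert indices and probabilities, with `-1` marking masked slots (the convention described for `_indices_to_multihot`). The helper `permute_tokens_local` is hypothetical; the real dispatcher performs these steps with fused DeepEP/HybridEP kernels, not this code.

```python
import torch


def permute_tokens_local(hidden_states, token_indices, token_probs, num_experts):
    """Single-rank sketch of token permutation (the all-to-all is omitted).

    hidden_states: [num_tokens, hidden]
    token_indices: [num_tokens, topk] expert ids, -1 marks a masked slot
    token_probs:   [num_tokens, topk] routing weights
    """
    num_tokens, topk = token_indices.shape

    # 1. Preprocess: flatten the (token, slot) pairs and drop masked slots.
    flat_experts = token_indices.reshape(-1)
    flat_tokens = torch.arange(num_tokens, device=token_indices.device).repeat_interleave(topk)
    flat_probs = token_probs.reshape(-1)
    keep = flat_experts >= 0
    flat_experts, flat_tokens, flat_probs = flat_experts[keep], flat_tokens[keep], flat_probs[keep]

    # 2. Communication: with a single rank, dispatch is the identity.

    # 3. Postprocess: group rows by expert (a stable sort keeps token order
    #    inside each expert's block) and count the tokens each expert receives.
    _, order = torch.sort(flat_experts, stable=True)
    row_tokens = flat_tokens[order]
    permuted = hidden_states[row_tokens]
    tokens_per_expert = torch.bincount(flat_experts, minlength=num_experts)
    return permuted, tokens_per_expert, row_tokens, flat_probs[order]
```

`tokens_per_expert` gives the block sizes an expert implementation would need to split the permuted tensor, and `row_tokens`/`row_probs` are what a reverse step would need to restore the original order.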
nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher.token_permutation2(
hidden_states: torch.Tensor,
num_local_tokens: int,
token_probs: torch.Tensor,
token_indices: torch.Tensor
) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Permutes tokens according to token_indices and token_probs and dispatches them to experts.

This method implements the token permutation process in three steps:

  1. Preprocess the hidden states
  2. Perform all-to-all communication to dispatch tokens
  3. Post-process the dispatched tokens for expert processing
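`token_permutation2` takes routing in top-k form (`token_indices`, `token_probs`) rather than the `probs` tensor used by `token_permutation`. As a hypothetical illustration, assuming `probs` is a dense `[num_tokens, num_experts]` score tensor in which zero means "not routed", the top-k form can be derived with `torch.topk`. The helper `dense_to_topk` is not part of this module.

```python
import torch


def dense_to_topk(probs, topk):
    """Hypothetical helper: dense [num_tokens, num_experts] scores -> top-k form.

    Returns (token_indices, token_probs), each [num_tokens, topk].
    Slots whose score is zero are marked with index -1 (masked out).
    """
    token_probs, token_indices = torch.topk(probs, k=topk, dim=-1)
    # A zero score means the expert was not actually selected for this token.
    masked = torch.full_like(token_indices, -1)
    token_indices = torch.where(token_probs > 0, token_indices, masked)
    return token_indices, token_probs
```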
nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher.token_unpermutation(
hidden_states: torch.Tensor
) -> torch.Tensor

Reverses the token permutation process to restore the original token order.

This method implements the token unpermutation process in three steps:

  1. Pre-process the hidden states to restore their original ordering
  2. Perform all-to-all communication to combine tokens
  3. Post-process the combined tokens to match the original input shape
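The reverse direction can be sketched on a single rank as a weighted scatter-add: each expert output row is scaled by its routing probability and accumulated into the row of the token it came from. This assumes the permutation step saved, for every permuted row, its original token index and routing weight. `unpermute_tokens_local` is a hypothetical helper, and the sketch does not claim this is where the module applies the routing probabilities.

```python
import torch


def unpermute_tokens_local(expert_output, row_tokens, row_probs, num_tokens):
    """Single-rank sketch of token unpermutation (the all-to-all is omitted).

    expert_output: [num_rows, hidden] expert outputs, grouped by expert
    row_tokens:    [num_rows] original token index of each row
    row_probs:     [num_rows] routing weight of each row
    """
    out = torch.zeros(num_tokens, expert_output.shape[-1],
                      dtype=expert_output.dtype, device=expert_output.device)
    # Scale every row by its routing weight, then sum the top-k contributions
    # back into the position of the token they originated from.
    weighted = expert_output * row_probs.unsqueeze(-1).to(expert_output.dtype)
    out.index_add_(0, row_tokens, weighted)
    return out
```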
class nemo_automodel.components.moe.megatron.token_dispatcher.TokenDispatcherConfig(
moe_enable_deepep: bool = True,
moe_permute_fusion: bool = False,
moe_expert_capacity_factor: typing.Optional[float] = None,
moe_router_topk: int = 2,
moe_router_expert_pad_multiple: typing.Optional[int] = None,
num_moe_experts: int = 64,
moe_router_dtype: str = 'fp32',
moe_flex_dispatcher_backend: typing.Literal['deepep', 'hybridep', 'uccl_ep'] = 'deepep',
moe_deepep_num_sms: int = 20,
moe_hybridep_num_sms: int = 24,
moe_share_token_dispatcher: bool = True,
moe_deepep_async_dispatch: bool = False
)
Dataclass

Configuration for MoE token dispatch and combine backends.

moe_deepep_async_dispatch
bool = False

Use asynchronous DeepEP/UCCL-EP dispatch and allocate dispatched tensors on the communication stream.

moe_deepep_num_sms
int = 20

Number of SMs to use for the DeepEP backend.

moe_enable_deepep
bool = True

Enable DeepEP for efficient token dispatch and combine in MoE models.

moe_expert_capacity_factor
Optional[float] = None

The capacity factor for each expert. None means no tokens are dropped. Defaults to None.
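As a worked example, a common convention (used, for instance, in Megatron-Core) sets each expert's capacity to the average number of routed token slots per expert, scaled by the capacity factor. This sketch assumes that convention; the exact formula this module uses is not documented here, and `expert_capacity` is a hypothetical helper.

```python
import math
from typing import Optional


def expert_capacity(num_tokens: int, topk: int, num_experts: int,
                    capacity_factor: Optional[float]) -> Optional[int]:
    """Per-expert token budget under an assumed Megatron-style formula."""
    if capacity_factor is None:
        return None  # no capacity limit, so no tokens are dropped
    # Average number of routed (token, slot) pairs per expert, scaled up.
    return math.ceil(num_tokens * topk / num_experts * capacity_factor)
```

With the config defaults (`moe_router_topk=2`, `num_moe_experts=64`) and 4096 local tokens, a factor of 1.25 gives 160 slots per expert under this convention; tokens routed beyond that budget would be dropped.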

moe_flex_dispatcher_backend
Literal['deepep', 'hybridep', 'uccl_ep'] = 'deepep'

Backend for the flex token dispatcher. Options: ‘deepep’, ‘hybridep’, or ‘uccl_ep’.

moe_hybridep_num_sms
int = 24

Number of SMs to use for HybridEP dispatch and combine APIs.

moe_permute_fusion
bool = False

Fuse token rearrangement ops during token dispatching.

moe_router_dtype
str = 'fp32'

Data type used for routing and for the weighted averaging of expert outputs. Using fp32 or fp64 can improve numerical stability, especially when the number of experts is large (e.g., fine-grained MoE). None leaves the dtype unchanged.
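The weighted averaging that this option controls can be sketched as follows: expert outputs and routing weights are upcast to the router dtype, summed over the top-k slots, and cast back to the activation dtype. The string-to-dtype mapping `ROUTER_DTYPES` and the function `weighted_combine` are illustrative assumptions, not the module's code.

```python
import torch

# Hypothetical mapping from the config string to a torch dtype.
ROUTER_DTYPES = {"fp32": torch.float32, "fp64": torch.float64}


def weighted_combine(expert_outputs, probs, router_dtype="fp32"):
    """Combine top-k expert outputs per token with routing weights (sketch).

    expert_outputs: [num_tokens, topk, hidden], e.g. bfloat16 activations
    probs:          [num_tokens, topk] routing weights
    """
    out_dtype = expert_outputs.dtype
    # None (or an unknown string) keeps the activation dtype unchanged.
    compute_dtype = ROUTER_DTYPES.get(router_dtype, out_dtype)
    weights = probs.to(compute_dtype).unsqueeze(-1)
    acc = (expert_outputs.to(compute_dtype) * weights).sum(dim=1)
    return acc.to(out_dtype)
```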

moe_router_expert_pad_multiple
Optional[int] = None

Pad the number of tokens routed to each expert up to a multiple of this value.
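For example, with a multiple of 16, per-expert token counts would be rounded up as below. This is a sketch assuming plain ceiling-to-multiple rounding, with experts that receive no tokens left at zero; `pad_to_multiple` is a hypothetical helper.

```python
from typing import List, Optional


def pad_to_multiple(tokens_per_expert: List[int], multiple: Optional[int]) -> List[int]:
    """Round each expert's token count up to the next multiple (hypothetical helper)."""
    if multiple is None:
        return list(tokens_per_expert)  # no padding requested
    # Ceiling division via negated floor division, then scale back up.
    return [-(-n // multiple) * multiple for n in tokens_per_expert]
```

For counts `[0, 1, 16, 17]` and a multiple of 16 this yields `[0, 16, 16, 32]`.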

moe_router_topk
int = 2

Number of experts to route to for each token.

moe_share_token_dispatcher
bool = True

Share one communication manager instance across MoE layers for the configured backend.
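The class attributes `shared_deepep_manager`, `shared_hybridep_manager`, and `shared_uccl_manager` (all initially `None`) suggest a class-level cache. Below is a minimal sketch of that pattern, assuming the first layer creates the manager and later layers reuse it when sharing is enabled; `SketchDispatcher` and `_StubManager` are hypothetical stand-ins, not the module's classes.

```python
from typing import Optional


class _StubManager:
    """Hypothetical stand-in for a communication manager."""

    def __init__(self, backend: str):
        self.backend = backend


class SketchDispatcher:
    # Class-level slot shared by every instance, like shared_deepep_manager = None.
    shared_manager: Optional[_StubManager] = None

    def __init__(self, backend: str, share: bool = True):
        if share:
            # The first layer builds the manager; later layers reuse it.
            if SketchDispatcher.shared_manager is None:
                SketchDispatcher.shared_manager = _StubManager(backend)
            self._comm_manager = SketchDispatcher.shared_manager
        else:
            # Without sharing, each layer gets its own manager.
            self._comm_manager = _StubManager(backend)
```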

num_moe_experts
int = 64

Number of experts in the MoE layer. When set, the MLP is replaced with an MoE layer. Set to None to disable MoE.

class nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager(
group: torch.distributed.ProcessGroup,
router_topk: int,
permute_fusion: bool = False,
capacity_factor: typing.Optional[float] = None,
num_experts: typing.Optional[int] = None,
num_local_experts: typing.Optional[int] = None,
router_dtype: typing.Optional[str] = None,
moe_router_expert_pad_multiple: typing.Optional[int] = None,
_dispatch_fn = None,
_combine_fn = None
)

Bases: _DispatchManager

A manager class to handle fused all-to-all communication processes for MoE models using DeepEP backend. See https://github.com/deepseek-ai/deepep for more details.

The workflow of the DeepEP dispatcher is:

(1) setup_metadata(): Process routing map and probabilities to prepare dispatch metadata.
(2) dispatch():
  • Use a fused kernel to permute tokens and perform all-to-all communication in a single step.
(3) get_permuted_hidden_states_by_experts():
  • Convert routing map and probabilities to multihot format.
  • Permute tokens using a fused kernel.
(4) get_restored_hidden_states_by_experts():
  • Reverse the permutation using a fused kernel.
(5) combine():
  • Reverse the process using a fused kernel to unpermute tokens and perform all-to-all communication in a single step.

This implementation uses fused communication kernels (fused_dispatch/fused_combine) that combine permutation and communication operations for improved efficiency compared to separate permute+alltoall steps.

_fused_combine
_fused_dispatch
token_indices
Optional[Tensor] = None
token_probs
Optional[Tensor] = None
nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager._indices_to_multihot(
indices,
probs
)

Converts a tensor of indices to a multihot vector.

Parameters:

indices
torch.Tensor

[num_tokens, topk] token indices, where -1 means masked out.

probs
torch.Tensor

[num_tokens, topk] token probabilities.

Returns:

Tuple[torch.Tensor, torch.Tensor]:

  • routing_map: Multihot vector.
  • probs: Multihot probabilities.
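A pure-PyTorch sketch of this conversion is shown below. It assumes the multihot width is passed explicitly; the real method takes only `indices` and `probs`, so its width presumably comes from the manager's configuration. `indices_to_multihot` is hypothetical, not the module's method.

```python
import torch


def indices_to_multihot(indices, probs, num_columns):
    """Sketch: [num_tokens, topk] indices/probs -> [num_tokens, num_columns] multihot.

    Entries equal to -1 in `indices` are treated as masked out.
    """
    num_tokens = indices.shape[0]
    routing_map = torch.zeros(num_tokens, num_columns, dtype=torch.bool, device=indices.device)
    multihot_probs = torch.zeros(num_tokens, num_columns, dtype=probs.dtype, device=probs.device)
    mask = indices != -1
    # Row index of every kept (token, slot) pair, and the expert column it targets.
    rows = torch.arange(num_tokens, device=indices.device).unsqueeze(1).expand_as(indices)[mask]
    cols = indices[mask]
    routing_map[rows, cols] = True
    multihot_probs[rows, cols] = probs[mask]
    return routing_map, multihot_probs
```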
nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager.combine(
hidden_states: torch.Tensor,
async_finish: bool = False,
allocate_on_comm_stream: bool = False
) -> torch.Tensor

Reverses the dispatch: uses a fused kernel to unpermute tokens and perform the all-to-all communication in a single step.

nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager.dispatch(
hidden_states: torch.Tensor,
async_finish: bool = False,
allocate_on_comm_stream: bool = False
) -> torch.Tensor

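Dispatches the hidden states using the fused dispatch kernel, which permutes tokens and performs the all-to-all communication in a single step.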

nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager.get_dispatched_metadata() -> torch.Tensor
nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager.get_number_of_tokens_per_expert() -> torch.Tensor

Get the number of tokens per expert.

nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager.get_permuted_hidden_states_by_experts(
hidden_states: torch.Tensor
) -> torch.Tensor

  • Convert routing map and probabilities to multihot format.
  • Permute tokens using a fused kernel.
nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager.get_restored_hidden_states_by_experts(
hidden_states: torch.Tensor
) -> torch.Tensor

Restore the hidden states to their original ordering before expert processing

nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager.setup_metadata(
num_local_tokens: int,
probs: torch.Tensor
)

Process routing map and probabilities to prepare dispatch metadata

class nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager()
Abstract

A manager class to handle dispatch and combine processes for MoE models.

_DispatchManager handles token dispatching according to a routing_map of shape [num_local_tokens, world_size, num_instances]. The routing_map is a 3D tensor in which each element indicates whether a token should be sent to a specific rank.

num_instances is the maximum number of token instances dispatched to a target rank; it can be the number of local experts or the size of a sub-group.
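When num_instances is the number of local experts, a `[num_local_tokens, num_experts]` multihot routing map can be viewed in this 3D layout with a reshape, assuming experts are assigned to ranks in contiguous blocks. Reducing over the last dimension then gives the token-to-rank mask. `to_rank_routing_map` is a hypothetical helper.

```python
import torch


def to_rank_routing_map(routing_map: torch.Tensor, world_size: int) -> torch.Tensor:
    """[num_local_tokens, num_experts] -> [num_local_tokens, world_size, num_local_experts]."""
    num_tokens, num_experts = routing_map.shape
    if num_experts % world_size != 0:
        raise ValueError("num_experts must be divisible by world_size")
    # Assumes rank r owns experts [r * num_local_experts, (r + 1) * num_local_experts).
    return routing_map.reshape(num_tokens, world_size, num_experts // world_size)
```

For example, `to_rank_routing_map(rm, world_size).any(dim=-1)` marks, for each token, which ranks it must be sent to.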

nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager.combine(
hidden_states: torch.Tensor
) -> torch.Tensor
abstract

Combine the hidden_states after expert processing.

nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager.dispatch(
hidden_states: torch.Tensor
) -> torch.Tensor
abstract

Dispatch the hidden_states according to the routing_map.

nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager.get_dispatched_metadata() -> torch.Tensor
abstract

Get the metadata of the dispatched hidden_states.

nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager.get_permuted_hidden_states_by_experts(
hidden_states: torch.Tensor
) -> torch.Tensor
abstract

Get the permuted hidden states by instances.

nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager.get_restored_hidden_states_by_experts(
hidden_states: torch.Tensor
) -> torch.Tensor
abstract

Get the restored hidden states by instances.

nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager.setup_metadata(
routing_map: torch.Tensor,
probs: torch.Tensor
)
abstract

Set up metadata of routing_map and probs.

class nemo_automodel.components.moe.megatron.token_dispatcher._HybridEPManager(
group: torch.distributed.ProcessGroup,
num_local_experts: int,
num_experts: int,
router_topk: int,
permute_fusion: bool = False,
moe_hybridep_num_sms: int = 24
)

Bases: _DispatchManager

A manager class to handle fused all-to-all communication processes for MoE models using HybridEP backend. See https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep for more details.

The workflow of the HybridEP dispatcher is:

(1) setup_metadata(): Process routing map and probabilities to prepare dispatch metadata.
(2) dispatch():
  • Permute tokens for communication, perform all-to-all communication, and permute tokens for experts in a single step.
(3) combine():
  • Unpermute tokens for communication, perform all-to-all communication, and unpermute tokens for attention in a single step.

routing_map
Optional[Tensor] = None
token_probs
Optional[Tensor] = None
nemo_automodel.components.moe.megatron.token_dispatcher._HybridEPManager._indices_to_multihot(
indices: torch.Tensor,
probs: torch.Tensor
)

Converts a tensor of indices to a multihot vector.

nemo_automodel.components.moe.megatron.token_dispatcher._HybridEPManager.combine(
hidden_states: torch.Tensor,
async_finish: bool = True,
allocate_on_comm_stream: bool = True
) -> torch.Tensor
nemo_automodel.components.moe.megatron.token_dispatcher._HybridEPManager.dispatch(
hidden_states: torch.Tensor,
async_finish: bool = True,
allocate_on_comm_stream: bool = True
) -> torch.Tensor
nemo_automodel.components.moe.megatron.token_dispatcher._HybridEPManager.get_dispatched_metadata() -> torch.Tensor
nemo_automodel.components.moe.megatron.token_dispatcher._HybridEPManager.get_number_of_tokens_per_expert() -> torch.Tensor
nemo_automodel.components.moe.megatron.token_dispatcher._HybridEPManager.get_permuted_hidden_states_by_experts(
hidden_states: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.moe.megatron.token_dispatcher._HybridEPManager.get_restored_hidden_states_by_experts(
hidden_states: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.moe.megatron.token_dispatcher._HybridEPManager.setup_metadata(
routing_map: torch.Tensor,
probs: torch.Tensor
)

Process routing map and probabilities to prepare dispatch metadata.

nemo_automodel.components.moe.megatron.token_dispatcher._HybridEPManager.setup_metadata_from_indices(
token_indices: torch.Tensor,
token_probs: torch.Tensor
)

Convert from topk indices format to multihot routing_map format.