nemo_automodel.components.moe.megatron.token_dispatcher#

Module Contents#

Classes#

_DispatchManager

A manager class to handle dispatch and combine processes for MoE models.

_DeepepManager

A manager class to handle fused all-to-all communication processes for MoE models using DeepEP backend. See https://github.com/deepseek-ai/deepep for more details.

MoEConfig

MoEFlexTokenDispatcher

Flex token dispatcher using DeepEP.

Data#

SHARING_DEEPEP_MANAGER

We use the following notation throughout this file: H: hidden size, B: micro batch size, S: sequence length, TP: tensor model parallel size, EP: expert model parallel size, num_local_tokens: S/TP*B, num_global_tokens: num_local_tokens*TP*EP.

API#

nemo_automodel.components.moe.megatron.token_dispatcher.SHARING_DEEPEP_MANAGER#

True

We use the following notation throughout this file:

  • H: hidden size

  • B: micro batch size

  • S: sequence length

  • TP: tensor model parallel size

  • EP: expert model parallel size

  • num_local_tokens: S/TP*B

  • num_global_tokens: num_local_tokens*TP*EP

class nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager#

Bases: abc.ABC

A manager class to handle dispatch and combine processes for MoE models.

The dispatch manager handles token dispatching according to the routing_map, which has the format [num_local_tokens, world_size, num_instances]. The routing_map is a 3D tensor where each element indicates whether a token should be sent to a specific rank.

num_instances is the maximum number of token instances dispatched to a target rank; it can be the number of local experts or the size of the sub-group.
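
The expected routing_map layout can be pictured with a small, self-contained sketch; the sizes and routed entries below are illustrative only, not taken from the library:

```python
import torch

# Illustrative sizes (not from the library): 4 local tokens, 2 ranks,
# and 2 expert instances per rank.
num_local_tokens, world_size, num_instances = 4, 2, 2

# routing_map[t, r, i] == True means local token t is sent to instance i
# on rank r. A top-2 router marks exactly two entries per token.
routing_map = torch.zeros(num_local_tokens, world_size, num_instances, dtype=torch.bool)
routing_map[0, 0, 1] = True   # token 0 -> rank 0, instance 1
routing_map[0, 1, 0] = True   # token 0 -> rank 1, instance 0

# The matching probabilities share the same shape; entries are zero where a
# token is not routed.
probs = torch.zeros(num_local_tokens, world_size, num_instances)
probs[0, 0, 1], probs[0, 1, 0] = 0.6, 0.4
```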

abstractmethod setup_metadata(routing_map: torch.Tensor, probs: torch.Tensor)#

Set up metadata of routing_map and probs.

abstractmethod dispatch(hidden_states: torch.Tensor) torch.Tensor#

Dispatch the hidden_states according to the routing_map.

abstractmethod combine(hidden_states: torch.Tensor) torch.Tensor#

Combine the hidden_states after expert processing.

abstractmethod get_dispached_metadata() torch.Tensor#

Get the metadata of the dispatched hidden_states.

abstractmethod get_permuted_hidden_states_by_experts(
hidden_states: torch.Tensor,
) torch.Tensor#

Get the permuted hidden states by instances.

abstractmethod get_restored_hidden_states_by_experts(
hidden_states: torch.Tensor,
) torch.Tensor#

Get the restored hidden states by instances.

class nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager(
group: torch.distributed.ProcessGroup,
router_topk: int,
permute_fusion: bool = False,
capacity_factor: Optional[float] = None,
num_experts: Optional[int] = None,
num_local_experts: Optional[int] = None,
router_dtype: Optional[str] = None,
moe_router_expert_pad_multiple: Optional[int] = None,
)#

Bases: nemo_automodel.components.moe.megatron.token_dispatcher._DispatchManager

A manager class to handle fused all-to-all communication processes for MoE models using DeepEP backend. See https://github.com/deepseek-ai/deepep for more details.

The workflow of the DeepEP dispatcher is:

  1. setup_metadata(): Process the routing map and probabilities to prepare dispatch metadata.

  2. dispatch(): Use a fused kernel to permute tokens and perform all-to-all communication in a single step.

  3. get_permuted_hidden_states_by_experts(): Convert the routing map and probabilities to multihot format and permute tokens using the fused kernel.

  4. get_restored_hidden_states_by_experts(): Reverse the permutation using the fused kernel.

  5. combine(): Reverse the dispatch using a fused kernel to unpermute and perform all-to-all communication in a single step.

This implementation uses fused communication kernels (fused_dispatch/fused_combine) that combine permutation and communication operations for improved efficiency compared to separate permute+alltoall steps.
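
A hedged sketch of that workflow, assuming an initialized torch.distributed process group and an installed DeepEP backend; the expert computation is a placeholder and the sizes passed to the constructor are illustrative:

```python
import torch
import torch.distributed as dist
from nemo_automodel.components.moe.megatron.token_dispatcher import _DeepepManager

def deepep_round_trip(hidden_states: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
    manager = _DeepepManager(
        group=dist.group.WORLD,                      # expert-parallel process group
        router_topk=2,
        permute_fusion=True,
        num_experts=64,
        num_local_experts=64 // dist.get_world_size(),
        router_dtype="fp32",
    )
    # (1) prepare dispatch metadata from the routing probabilities
    manager.setup_metadata(hidden_states.shape[0], probs)
    # (2) fused permute + all-to-all to the owning ranks
    dispatched = manager.dispatch(hidden_states)
    # (3)/(4) group tokens by local expert, run the experts, then restore order
    permuted = manager.get_permuted_hidden_states_by_experts(dispatched)
    expert_output = permuted                          # placeholder for the expert MLPs
    restored = manager.get_restored_hidden_states_by_experts(expert_output)
    # (5) fused unpermute + all-to-all back to the source ranks
    return manager.combine(restored)
```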

Initialization

setup_metadata(num_local_tokens: int, probs: torch.Tensor)#

Process the routing map and probabilities to prepare dispatch metadata.

dispatch(
hidden_states: torch.Tensor,
async_finish: bool = False,
allocate_on_comm_stream: bool = False,
) torch.Tensor#

Dispatch the hidden_states using the fused permute and all-to-all kernel.

_indices_to_multihot(indices, probs)#

Converts a tensor of indices to a multihot vector.

Parameters:
  • indices (torch.Tensor) – [num_tokens, topk] token indices, where -1 means masked out.

  • probs (torch.Tensor) – [num_tokens, topk] token probabilities.

Returns:

  • routing_map: Multihot vector.

  • probs: Multihot probabilities.

Return type:

Tuple[torch.Tensor, torch.Tensor]
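
For intuition, the documented conversion can be reproduced in plain PyTorch. This reference helper is not the fused implementation, and the explicit num_instances argument is an addition made for the illustration:

```python
import torch

def indices_to_multihot_reference(indices: torch.Tensor, probs: torch.Tensor, num_instances: int):
    """Plain-PyTorch illustration of the documented conversion (not the fused path).

    indices: [num_tokens, topk] token indices, where -1 means masked out.
    probs:   [num_tokens, topk] matching token probabilities.
    """
    num_tokens = indices.shape[0]
    routing_map = torch.zeros(num_tokens, num_instances, dtype=torch.bool)
    multihot_probs = torch.zeros(num_tokens, num_instances, dtype=probs.dtype)
    valid = indices != -1
    token_ids = torch.arange(num_tokens).unsqueeze(1).expand_as(indices)
    routing_map[token_ids[valid], indices[valid]] = True
    multihot_probs[token_ids[valid], indices[valid]] = probs[valid]
    return routing_map, multihot_probs

# Example: token 0 is routed to instances 2 and 5, token 1 only to instance 0.
indices = torch.tensor([[2, 5], [0, -1]])
probs = torch.tensor([[0.7, 0.3], [1.0, 0.0]])
routing_map, mh_probs = indices_to_multihot_reference(indices, probs, num_instances=8)
```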

get_dispached_metadata() torch.Tensor#
get_number_of_tokens_per_expert() torch.Tensor#

Get the number of tokens per expert.

combine(
hidden_states: torch.Tensor,
async_finish: bool = False,
allocate_on_comm_stream: bool = False,
) torch.Tensor#

Reverse the dispatch using the fused kernel to unpermute tokens and perform all-to-all communication in a single step.

get_permuted_hidden_states_by_experts(
hidden_states: torch.Tensor,
) torch.Tensor#
  • Convert routing map and probabilities to multihot format

  • Permute tokens using fused kernel

get_restored_hidden_states_by_experts(
hidden_states: torch.Tensor,
) torch.Tensor#

Restore the hidden states to their original ordering before expert processing

class nemo_automodel.components.moe.megatron.token_dispatcher.MoEConfig#
moe_enable_deepep: bool#

True

Enable DeepEP for efficient token dispatch and combine in MoE models.

moe_permute_fusion: bool#

False

Fuse token rearrangement ops during token dispatching.

moe_expert_capacity_factor: Optional[float]#

None

The capacity factor for each expert; None means no tokens will be dropped. Defaults to None.

moe_router_topk: int#

2

Number of experts to route to for each token.

moe_router_expert_pad_multiple: Optional[int]#

None

Pad the number of tokens per expert to a multiple of this value.

num_moe_experts: int#

64

Number of experts to use for the MoE layer. When set, it replaces the MLP with an MoE layer. Set to None for no MoE.

moe_router_dtype: str#

'fp32'

Data type for routing and for the weighted averaging of expert outputs. Using fp32 or fp64 can improve stability, especially when the number of experts is large (e.g. fine-grained MoE). None means the dtype is left unchanged.
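
A minimal configuration sketch, assuming MoEConfig accepts its documented fields as keyword arguments (e.g. a dataclass-style config); the values shown mirror the defaults above:

```python
from nemo_automodel.components.moe.megatron.token_dispatcher import MoEConfig

# Sketch only: field names and defaults are taken from the documentation above;
# keyword construction is an assumption about the config class.
config = MoEConfig(
    moe_enable_deepep=True,             # use DeepEP fused dispatch/combine
    moe_permute_fusion=False,           # fuse token rearrangement ops
    moe_expert_capacity_factor=None,    # None => no tokens are dropped
    moe_router_topk=2,                  # experts routed to per token
    moe_router_expert_pad_multiple=None,
    num_moe_experts=64,
    moe_router_dtype="fp32",            # fp32/fp64 improve routing stability
)
```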

class nemo_automodel.components.moe.megatron.token_dispatcher.MoEFlexTokenDispatcher(
num_local_experts: int,
local_expert_indices: List[int],
config: nemo_automodel.components.moe.megatron.token_dispatcher.MoEConfig,
ep_group: torch.distributed.ProcessGroup,
)#

Flex token dispatcher using DeepEP.

Initialization

Initialize the Flex token dispatcher.

Parameters:
  • num_local_experts (int) – Number of local experts on the current device.

  • local_expert_indices (List[int]) – Indices of local experts on the current device.

  • config (MoEConfig) – Configuration for the transformer model.

  • ep_group (torch.distributed.ProcessGroup) – Process group for MoE operations.
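
A construction sketch, assuming torch.distributed is initialized and reusing the config built earlier; the contiguous assignment of local experts per expert-parallel rank is an assumption made for the illustration:

```python
import torch.distributed as dist
from nemo_automodel.components.moe.megatron.token_dispatcher import MoEFlexTokenDispatcher

# Sketch only: `config` is the MoEConfig constructed above.
ep_group = dist.group.WORLD
ep_rank = dist.get_rank(ep_group)
ep_size = dist.get_world_size(ep_group)

# Assumed contiguous expert-to-rank assignment for illustration.
num_local_experts = config.num_moe_experts // ep_size
local_expert_indices = list(
    range(ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts)
)

dispatcher = MoEFlexTokenDispatcher(
    num_local_experts=num_local_experts,
    local_expert_indices=local_expert_indices,
    config=config,
    ep_group=ep_group,
)
```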

shared_comm_manager: nemo_automodel.components.moe.megatron.token_dispatcher._DeepepManager#

None

_initialize_metadata(
num_local_tokens: int,
probs: torch.Tensor,
) torch.Tensor#

Initialize the routing map and probs in a unified format covering the TP×EP group. This design decouples the communication group from the underlying model parallelism groups, so that the token communication strategy can be agnostic to the TP and EP sizes.

dispatch_preprocess2(
hidden_states: torch.Tensor,
num_local_tokens: int,
token_probs: torch.Tensor,
token_indices: torch.Tensor,
)#

Preprocesses the hidden states and routing information before dispatching tokens to experts.

dispatch_preprocess(
hidden_states: torch.Tensor,
num_local_tokens: int,
probs: torch.Tensor,
)#

Preprocesses the hidden states and routing information before dispatching tokens to experts.

Parameters:
  • hidden_states (torch.Tensor) – Input hidden states to be processed

  • num_local_tokens (int) – Number of tokens to be processed

  • probs (torch.Tensor) – Routing probabilities for each token-expert pair

Returns:

Tuple containing:

  • torch.Tensor: Reshaped hidden states

  • torch.Tensor: Token probabilities from the communication manager

  • None: Placeholder for compatibility

dispatch_all_to_all(
hidden_states: torch.Tensor,
probs: torch.Tensor = None,
async_finish: bool = True,
allocate_on_comm_stream: bool = True,
)#

Performs all-to-all communication to dispatch tokens across expert parallel ranks.

dispatch_postprocess(hidden_states: torch.Tensor)#

Post-processes the dispatched hidden states after all-to-all communication.

This method retrieves the permuted hidden states by experts, calculates the number of tokens per expert, and returns the processed data ready for expert processing.

token_permutation(
hidden_states: torch.Tensor,
num_local_tokens: int,
probs: torch.Tensor,
) Tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

Permutes tokens according to probs and dispatches them to experts.

This method implements the token permutation process in three steps:

  1. Preprocess the hidden states

  2. Perform all-to-all communication to dispatch tokens

  3. Post-process the dispatched tokens for expert processing
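
A usage sketch for token_permutation, continuing from the dispatcher and config constructed above; the names given to the three returned tensors (permuted hidden states, tokens per local expert, permuted probabilities) are an assumption made for readability:

```python
import torch

# Usage sketch: `dispatcher` and `config` are the objects built above.
# In practice `probs` comes from the MoE router; random values stand in here.
S, B, H = 128, 1, 1024                         # illustrative sizes (see notation above)
num_local_tokens = S * B

hidden_states = torch.randn(num_local_tokens, H)
probs = torch.softmax(torch.randn(num_local_tokens, config.num_moe_experts), dim=-1)

# Steps 1-3: preprocess, fused all-to-all dispatch, postprocess.
permuted_states, tokens_per_expert, permuted_probs = dispatcher.token_permutation(
    hidden_states=hidden_states,
    num_local_tokens=num_local_tokens,
    probs=probs,
)
# `permuted_states` is grouped by local expert and ready for the expert MLPs.
```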

token_permutation2(
hidden_states: torch.Tensor,
num_local_tokens: int,
token_probs: torch.Tensor,
token_indices: torch.Tensor,
) Tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

Permutes tokens according to probs and dispatches them to experts.

This method implements the token permutation process in three steps:

  1. Preprocess the hidden states

  2. Perform all-to-all communication to dispatch tokens

  3. Post-process the dispatched tokens for expert processing

combine_preprocess(hidden_states: torch.Tensor)#

Pre-processes the hidden states before combining them after expert processing.

This method restores the hidden states to their original ordering before expert processing by using the communication manager’s restoration function.

combine_all_to_all(
hidden_states: torch.Tensor,
async_finish: bool = True,
allocate_on_comm_stream: bool = True,
)#

Performs all-to-all communication to combine tokens after expert processing.

combine_postprocess(hidden_states: torch.Tensor)#

Post-processes the combined hidden states after all-to-all communication.

This method reshapes the combined hidden states to match the original input shape.

token_unpermutation(hidden_states: torch.Tensor) torch.Tensor#

Reverses the token permutation process to restore the original token order.

This method implements the token unpermutation process in three steps:

  1. Pre-process the hidden states to restore their original ordering

  2. Perform all-to-all communication to combine tokens

  3. Post-process the combined tokens to match the original input shape
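
A round-trip sketch pairing token_permutation with token_unpermutation; the expert computation is a placeholder, and the final shape check follows from combine_postprocess restoring the original input shape:

```python
# Round-trip sketch: continues from `dispatcher`, `hidden_states`, and `probs`
# in the example above. The tuple unpacking mirrors the earlier assumption.
permuted_states, tokens_per_expert, permuted_probs = dispatcher.token_permutation(
    hidden_states, num_local_tokens=hidden_states.shape[0], probs=probs
)

expert_output = permuted_states               # placeholder: apply local experts here

# Restore the original ordering, combine across ranks, and reshape to the input layout.
output = dispatcher.token_unpermutation(expert_output)
assert output.shape == hidden_states.shape    # combine_postprocess restores the shape
```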