core.transformer.moe.token_dispatcher_inference#
CUDA-graph-compatible token dispatcher for inference.
This dispatcher is only used during CUDA-graphed inference iterations. It replaces AlltoAll with AllGather/ReduceScatter for token exchange, keeping all metadata GPU-resident to avoid host synchronizations that would break CUDA graph capture.
Supports latency-optimized NVLS collectives (multimem all-gather/reduce-scatter) on Hopper+ GPUs with BF16, with automatic fallback to NCCL.
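The semantics of the two collectives that replace AlltoAll can be illustrated in a single process. This is a sketch only: the shard lists below simulate what each EP rank holds, and `ep_size`, `local_tokens`, and `hidden_dim` are illustrative values, not Megatron API.

```python
import torch

# Single-process simulation of AllGather/ReduceScatter semantics.
ep_size, local_tokens, hidden_dim = 4, 2, 8

# Each EP rank holds a local shard of tokens.
local_shards = [torch.randn(local_tokens, hidden_dim) for _ in range(ep_size)]

# AllGather: every rank ends up with the concatenation of all shards,
# so global_tokens = local_tokens * ep_size.
gathered = torch.cat(local_shards, dim=0)
assert gathered.shape == (local_tokens * ep_size, hidden_dim)

# After expert computation, each rank holds a full [global_tokens, hidden_dim]
# tensor containing its local experts' partial contributions.
partial_outputs = [torch.randn(local_tokens * ep_size, hidden_dim) for _ in range(ep_size)]

# ReduceScatter: sum the contributions across ranks, then hand rank r
# chunk r of the result -- its local token slice.
reduced = torch.stack(partial_outputs).sum(dim=0)
local_out_per_rank = reduced.chunk(ep_size, dim=0)
assert local_out_per_rank[0].shape == (local_tokens, hidden_dim)
```

Because both collectives have output sizes that are fixed functions of the input sizes, no host-side metadata exchange is needed, which is what makes them CUDA-graph-safe.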
Module Contents#
Classes#
InferenceCUDAGraphTokenDispatcher – CUDA-graph-compatible AllGather token dispatcher for inference.
API#
- class core.transformer.moe.token_dispatcher_inference.InferenceCUDAGraphTokenDispatcher(
- num_local_experts: int,
- local_expert_indices: List[int],
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
- )
Bases: megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher
CUDA-graph-compatible AllGather token dispatcher for inference.
Only used during CUDA-graphed inference iterations. Swapped in by MoELayer.set_inference_cuda_graphed_iteration() before graph capture and swapped out by MoELayer.unset_inference_cuda_graphed_iteration() after.
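The swap-in/swap-out pattern can be sketched with a toy layer. Only the two method names come from this page; the attribute names and the string stand-ins for the dispatchers are illustrative.

```python
# Toy sketch of the dispatcher-swap lifecycle around CUDA graph capture.
# MoELayer internals are greatly simplified; attribute names are illustrative.
class ToyMoELayer:
    def __init__(self, default_dispatcher, inference_dispatcher):
        self._default_dispatcher = default_dispatcher
        self._inference_dispatcher = inference_dispatcher
        self.token_dispatcher = default_dispatcher

    def set_inference_cuda_graphed_iteration(self):
        # Called before CUDA graph capture: use the graph-safe dispatcher.
        self.token_dispatcher = self._inference_dispatcher

    def unset_inference_cuda_graphed_iteration(self):
        # Called after the graphed iteration: restore the default dispatcher.
        self.token_dispatcher = self._default_dispatcher

layer = ToyMoELayer("alltoall", "allgather_inference")
layer.set_inference_cuda_graphed_iteration()
assert layer.token_dispatcher == "allgather_inference"
layer.unset_inference_cuda_graphed_iteration()
assert layer.token_dispatcher == "alltoall"
```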
Key features:
AllGather/ReduceScatter instead of AlltoAll for CUDA graph compatibility
GPU-resident metadata (no host synchronization)
NVLS collectives on Hopper+ with automatic NCCL fallback
Initialization
Initialize the InferenceCUDAGraphTokenDispatcher.
- Parameters:
num_local_experts – Number of experts on this rank.
local_expert_indices – Global indices of experts on this rank.
config – Transformer configuration.
pg_collection – Process group collection for distributed ops.
- _maybe_allocate_ag_buffers(
- routing_map: torch.Tensor,
- probs: torch.Tensor,
- hidden_states: torch.Tensor,
- ) → dict
Allocate a single symmetric memory output buffer for fused all-gather.
Creates one contiguous symmetric memory buffer sized for the gathered (global) routing_map, probs, and hidden_states, then returns sliced views into it. This allows a single fused NVLS all-gather kernel to write all three outputs in one launch.
- Parameters:
routing_map (torch.Tensor) – Local routing map, shape [local_tokens, topk]. Boolean or integer tensor mapping each token to its selected experts.
probs (torch.Tensor) – Local routing probabilities, shape [local_tokens, topk]. Normalized weights for each token’s selected experts.
hidden_states (torch.Tensor) – Local hidden states, shape [local_tokens, hidden_dim].
- Returns:
A dictionary with the following keys:
- “handle”: Symmetric memory handle for NVLS ops, or None if symmetric memory is unavailable.
- “routing_map”: Raw byte view for the gathered routing map output.
- “routing_map_offset”: Byte offset of routing_map within the buffer.
- “probs”: Raw byte view for the gathered probs output.
- “probs_offset”: Byte offset of probs within the buffer.
- “hidden_states”: Raw byte view for the gathered hidden states output.
- “hidden_states_offset”: Byte offset of hidden_states within the buffer.
When allocation fails, all tensor views are None and offsets are 0.
- Return type:
dict
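The byte-offset bookkeeping this method performs can be sketched with a plain tensor standing in for the symmetric memory buffer. Shapes and dtypes below are illustrative; the point is that one contiguous allocation yields three typed views, so one fused all-gather kernel can write all three outputs.

```python
import torch

# Carve one contiguous byte buffer into three typed views, mirroring the
# single fused-all-gather output buffer. A plain uint8 tensor stands in
# for the symmetric memory allocation; sizes are illustrative.
ep_size, local_tokens, topk, hidden_dim = 4, 2, 2, 8
global_tokens = ep_size * local_tokens

def nbytes(shape, dtype):
    n = 1
    for s in shape:
        n *= s
    return n * torch.tensor([], dtype=dtype).element_size()

rm_shape, rm_dtype = (global_tokens, topk), torch.uint8           # gathered routing_map
pr_shape, pr_dtype = (global_tokens, topk), torch.bfloat16        # gathered probs
hs_shape, hs_dtype = (global_tokens, hidden_dim), torch.bfloat16  # gathered hidden_states

# Lay the three regions out back to back and record their byte offsets.
rm_off = 0
pr_off = rm_off + nbytes(rm_shape, rm_dtype)
hs_off = pr_off + nbytes(pr_shape, pr_dtype)
total = hs_off + nbytes(hs_shape, hs_dtype)

buf = torch.empty(total, dtype=torch.uint8)  # stand-in for the symmetric buffer

# Reinterpret each byte region as a typed, shaped view into the same storage.
routing_map = buf[rm_off:pr_off].view(rm_dtype).view(rm_shape)
probs = buf[pr_off:hs_off].view(pr_dtype).view(pr_shape)
hidden_states = buf[hs_off:total].view(hs_dtype).view(hs_shape)
```

Writing through any of the three views fills the corresponding region of `buf`, which is what lets a single collective kernel populate all three outputs in one launch.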
- _maybe_allocate_rs_buffer(x: torch.Tensor) → dict#
Allocate a symmetric memory buffer for reduce-scatter input.
The buffer has the same shape and dtype as x so that x can be copied into it before the NVLS reduce-scatter kernel.
- Parameters:
x (torch.Tensor) – The global hidden states to be reduce-scattered, shape [global_tokens, hidden_dim].
- Returns:
A dictionary with keys “handle” (symmetric memory handle, or None if unavailable) and “tensor” (the allocated buffer, or None).
- Return type:
dict
- token_dispatch(hidden_states, probs)#
Gathers tokens from all EP ranks using AllGather.
Performs all-gather on routing_map (stored in self.routing_map), probs, and hidden_states so that every rank holds the full global view. Uses latency-optimized fused NVLS multimem_all_gather on Hopper+ GPUs with BF16 when symmetric memory is available. Falls back to NCCL otherwise.
- Parameters:
hidden_states (torch.Tensor) – Local hidden states, shape [local_tokens, hidden_dim].
probs (torch.Tensor) – Local routing probabilities, shape [local_tokens, topk]. Normalized weights for each token’s selected experts.
- Returns:
(hidden_states, probs) gathered across all EP ranks.
- hidden_states (torch.Tensor): Shape [global_tokens, hidden_dim].
- probs (torch.Tensor): Shape [global_tokens, topk].
Also updates self.routing_map in-place to the gathered shape [global_tokens, topk].
- Return type:
tuple
- dispatch_postprocess(hidden_states, probs)#
Pass-through: returns inputs directly without permutation.
Unlike the training dispatcher, this does not permute tokens or compute tokens_per_expert. The downstream InferenceGroupedMLP (FlashInfer / CUTLASS fused MoE kernel) operates directly on the routing map stored in self.routing_map.
- Parameters:
hidden_states (torch.Tensor) – Gathered hidden states, shape [global_tokens, hidden_dim].
probs (torch.Tensor) – Gathered routing probabilities, shape [global_tokens, topk].
- Returns:
(hidden_states, tokens_per_expert, probs) where tokens_per_expert is always None.
- Return type:
tuple
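What it means for an expert MLP to operate "directly on the routing map" without permutation can be shown with a dense PyTorch reference. This is a sketch of the contract only, not the fused FlashInfer/CUTLASS kernel: it treats routing_map as [global_tokens, topk] integer expert indices (one of the two layouts this page allows), and the toy per-expert weight matrices are illustrative.

```python
import torch

# Dense reference for expert computation driven directly by an unpermuted
# routing map (no token permutation, no tokens_per_expert count).
torch.manual_seed(0)
global_tokens, hidden_dim, num_experts, topk = 8, 4, 3, 2

hidden_states = torch.randn(global_tokens, hidden_dim)
# routing_map[t] holds the indices of the topk experts selected for token t.
routing_map = torch.topk(torch.randn(global_tokens, num_experts), topk, dim=1).indices
# probs[t, k] is the normalized weight of token t's k-th selected expert.
probs = torch.softmax(torch.randn(global_tokens, topk), dim=1)

# Toy single-matrix "experts"; a real MoE MLP has two layers + activation.
expert_weights = [torch.randn(hidden_dim, hidden_dim) for _ in range(num_experts)]

output = torch.zeros_like(hidden_states)
for e in range(num_experts):
    sel = routing_map == e           # [global_tokens, topk], True where expert e chosen
    token_mask = sel.any(dim=1)      # tokens routed to expert e, in original order
    if token_mask.any():
        expert_out = hidden_states[token_mask] @ expert_weights[e]
        w = (probs * sel).sum(dim=1)[token_mask]  # per-token weight for expert e
        output[token_mask] += w.unsqueeze(1) * expert_out
```

Because tokens stay in their original order throughout, no unpermutation is needed afterwards, which is exactly why `combine_preprocess` below is a pass-through.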
- combine_preprocess(expert_output)#
Pass-through: InferenceGroupedMLP already produces unpermuted output.
No unpermutation is needed because dispatch_postprocess did not permute the tokens in the first place.
- Parameters:
expert_output (torch.Tensor) – Output from InferenceGroupedMLP, shape [global_tokens, hidden_dim].
- Returns:
The input tensor unchanged.
- Return type:
torch.Tensor
- token_combine(hidden_states)#
Combines expert outputs across EP ranks using Reduce-Scatter.
Reduces the global expert output (summing contributions from each rank) and scatters the result so each rank receives its local token slice. Uses latency-optimized NVLS multimem_reduce_scatter on Hopper+ GPUs with BF16 when symmetric memory is available. Falls back to NCCL otherwise.
- Parameters:
hidden_states (torch.Tensor) – Combined expert output after routing weights have been applied, shape [global_tokens, hidden_dim].
- Returns:
Local slice of the reduced output, shape [local_tokens, hidden_dim] where local_tokens = global_tokens // ep_size.
- Return type:
torch.Tensor