core.transformer.moe.token_dispatcher_inference#

CUDA-graph-compatible token dispatcher for inference.

This dispatcher is only used during CUDA-graphed inference iterations. It replaces AlltoAll with AllGather/ReduceScatter for token exchange, keeping all metadata GPU-resident to avoid host synchronizations that would break CUDA graph capture.

Supports latency-optimized NVLS collectives (multimem all-gather/reduce-scatter) on Hopper+ GPUs with BF16, with automatic fallback to NCCL.
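The graph-compatibility argument can be illustrated with a single-process sketch (all sizes and names below are illustrative, not taken from the real dispatcher):

```python
import torch

# Single-process emulation of the dispatch-side collective. AllGather's
# output size is fixed at ep_size * local_tokens and known at graph-capture
# time, so no host-side split sizes (hence no device->host sync) are needed,
# unlike AlltoAll dispatch, which requires per-rank token counts on the host.
ep_size, local_tokens, hidden_dim = 4, 8, 16
local_states = [torch.randn(local_tokens, hidden_dim) for _ in range(ep_size)]

# "AllGather": every rank ends up with the same full global token set.
global_states = torch.cat(local_states, dim=0)  # [ep_size * local_tokens, hidden_dim]
```

The price of the fixed output shape is that every rank processes the full global token set, which is the usual AllGather-dispatch trade-off against AlltoAll.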

Module Contents#

Classes#

InferenceCUDAGraphTokenDispatcher

CUDA-graph-compatible AllGather token dispatcher for inference.

API#

class core.transformer.moe.token_dispatcher_inference.InferenceCUDAGraphTokenDispatcher(
num_local_experts: int,
local_expert_indices: List[int],
config: megatron.core.transformer.transformer_config.TransformerConfig,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher

CUDA-graph-compatible AllGather token dispatcher for inference.

Only used during CUDA-graphed inference iterations. Swapped in by MoELayer.set_inference_cuda_graphed_iteration() before graph capture and swapped out by MoELayer.unset_inference_cuda_graphed_iteration() after.

Key features:

  • AllGather/ReduceScatter instead of AlltoAll for CUDA graph compatibility

  • GPU-resident metadata (no host synchronization)

  • NVLS collectives on Hopper+ with automatic NCCL fallback

Initialization

Initialize the InferenceCUDAGraphTokenDispatcher.

Parameters:
  • num_local_experts – Number of experts on this rank.

  • local_expert_indices – Global indices of experts on this rank.

  • config – Transformer configuration.

  • pg_collection – Process group collection for distributed ops.

_maybe_allocate_ag_buffers(
routing_map: torch.Tensor,
probs: torch.Tensor,
hidden_states: torch.Tensor,
) → dict#

Allocate a single symmetric memory output buffer for fused all-gather.

Creates one contiguous symmetric memory buffer sized for the gathered (global) routing_map, probs, and hidden_states, then returns sliced views into it. This allows a single fused NVLS all-gather kernel to write all three outputs in one launch.

Parameters:
  • routing_map (torch.Tensor) – Local routing map, shape [local_tokens, topk]. Boolean or integer tensor mapping each token to its selected experts.

  • probs (torch.Tensor) – Local routing probabilities, shape [local_tokens, topk]. Normalized weights for each token’s selected experts.

  • hidden_states (torch.Tensor) – Local hidden states, shape [local_tokens, hidden_dim].

Returns:

A dictionary with the following keys:

  • “handle”: Symmetric memory handle for NVLS ops, or None if symmetric memory is unavailable.

  • “routing_map”: Raw byte view for the gathered routing map output.

  • “routing_map_offset”: Byte offset of routing_map within the buffer.

  • “probs”: Raw byte view for the gathered probs output.

  • “probs_offset”: Byte offset of probs within the buffer.

  • “hidden_states”: Raw byte view for the gathered hidden states output.

  • “hidden_states_offset”: Byte offset of hidden_states within the buffer.

When allocation fails, all tensor views are None and all offsets are 0.

Return type:

dict
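The slicing scheme can be sketched with a plain CPU tensor standing in for the symmetric-memory buffer; the dtype sizes (bool routing map, BF16 probs and hidden states) and all dimensions are illustrative assumptions:

```python
import torch

# One flat byte buffer holds the three gathered outputs back to back; the
# method returns raw byte views plus their offsets. A uint8 CPU tensor
# stands in here for the symmetric-memory allocation.
global_tokens, topk, hidden_dim = 32, 2, 64
nbytes_map = global_tokens * topk * 1           # bool routing_map: 1 byte/elem
nbytes_probs = global_tokens * topk * 2         # bf16 probs: 2 bytes/elem
nbytes_hidden = global_tokens * hidden_dim * 2  # bf16 hidden states

buf = torch.empty(nbytes_map + nbytes_probs + nbytes_hidden, dtype=torch.uint8)
off_map, off_probs = 0, nbytes_map
off_hidden = off_probs + nbytes_probs
views = {
    "handle": None,  # would be the symmetric memory handle when available
    "routing_map": buf[off_map:off_probs],
    "routing_map_offset": off_map,
    "probs": buf[off_probs:off_hidden],
    "probs_offset": off_probs,
    "hidden_states": buf[off_hidden:],
    "hidden_states_offset": off_hidden,
}
# A byte view is later reinterpreted with the real dtype and shape:
probs_out = views["probs"].view(torch.bfloat16).reshape(global_tokens, topk)
```

Packing all three outputs into one allocation is what lets a single fused all-gather kernel fill them in one launch.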

_maybe_allocate_rs_buffer(x: torch.Tensor) → dict#

Allocate a symmetric memory buffer for reduce-scatter input.

The buffer has the same shape and dtype as x so that x can be copied into it before the NVLS reduce-scatter kernel.

Parameters:

x (torch.Tensor) – The global hidden states to be reduce-scattered, shape [global_tokens, hidden_dim].

Returns:

A dictionary with keys “handle” (symmetric memory handle, or None if unavailable) and “tensor” (the allocated buffer, or None).

Return type:

dict

token_dispatch(hidden_states, probs)#

Gathers tokens from all EP ranks using AllGather.

Performs all-gather on routing_map (stored in self.routing_map), probs, and hidden_states so that every rank holds the full global view. Uses latency-optimized fused NVLS multimem_all_gather on Hopper+ GPUs with BF16 when symmetric memory is available. Falls back to NCCL otherwise.

Parameters:
  • hidden_states (torch.Tensor) – Local hidden states, shape [local_tokens, hidden_dim].

  • probs (torch.Tensor) – Local routing probabilities, shape [local_tokens, topk]. Normalized weights for each token’s selected experts.

Returns:

(hidden_states, probs) gathered across all EP ranks:

  • hidden_states (torch.Tensor) – Shape [global_tokens, hidden_dim].

  • probs (torch.Tensor) – Shape [global_tokens, topk].

Also updates self.routing_map in place to the gathered shape [global_tokens, topk].

Return type:

tuple
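The shape bookkeeping can be emulated on one process with concatenation standing in for the collective (all sizes illustrative):

```python
import torch

# All three gathered tensors grow along dim 0 by the EP world size:
# global_tokens = ep_size * local_tokens.
ep_size, local_tokens, topk, hidden_dim = 2, 4, 2, 8
rank_maps = [torch.zeros(local_tokens, topk, dtype=torch.bool) for _ in range(ep_size)]
rank_probs = [torch.rand(local_tokens, topk) for _ in range(ep_size)]
rank_states = [torch.randn(local_tokens, hidden_dim) for _ in range(ep_size)]

routing_map = torch.cat(rank_maps)      # kept on self.routing_map, [8, 2]
probs = torch.cat(rank_probs)           # returned, [8, 2]
hidden_states = torch.cat(rank_states)  # returned, [8, 8]
```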

dispatch_postprocess(hidden_states, probs)#

Pass-through: returns inputs directly without permutation.

Unlike the training dispatcher, this does not permute tokens or compute tokens_per_expert. The downstream InferenceGroupedMLP (FlashInfer / CUTLASS fused MoE kernel) operates directly on the routing map stored in self.routing_map.

Parameters:
  • hidden_states (torch.Tensor) – Gathered hidden states, shape [global_tokens, hidden_dim].

  • probs (torch.Tensor) – Gathered routing probabilities, shape [global_tokens, topk].

Returns:

(hidden_states, tokens_per_expert, probs) where tokens_per_expert is always None.

Return type:

tuple
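Behaviorally the method reduces to an identity wrapper; a minimal stand-in (the function name is hypothetical):

```python
import torch

def dispatch_postprocess_sketch(hidden_states, probs):
    """Illustrative stand-in: no permutation and no tokens_per_expert count,
    since the fused inference MoE kernel reads the routing map stored on the
    dispatcher instead."""
    return hidden_states, None, probs

h, p = torch.randn(8, 16), torch.rand(8, 2)
out_h, tokens_per_expert, out_p = dispatch_postprocess_sketch(h, p)
```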

combine_preprocess(expert_output)#

Pass-through: InferenceGroupedMLP already produces unpermuted output.

No unpermutation is needed because dispatch_postprocess did not permute the tokens in the first place.

Parameters:

expert_output (torch.Tensor) – Output from InferenceGroupedMLP, shape [global_tokens, hidden_dim].

Returns:

The input tensor unchanged.

Return type:

torch.Tensor

token_combine(hidden_states)#

Combines expert outputs across EP ranks using Reduce-Scatter.

Reduces the global expert output (summing contributions from each rank) and scatters the result so each rank receives its local token slice. Uses latency-optimized NVLS multimem_reduce_scatter on Hopper+ GPUs with BF16 when symmetric memory is available. Falls back to NCCL otherwise.

Parameters:

hidden_states (torch.Tensor) – Combined expert output after routing weights have been applied, shape [global_tokens, hidden_dim].

Returns:

Local slice of the reduced output, shape [local_tokens, hidden_dim] where local_tokens = global_tokens // ep_size.

Return type:

torch.Tensor
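The reduce-scatter semantics can be emulated on one process (sizes illustrative; in practice a rank's contribution would be zero for tokens its experts did not process):

```python
import torch

# Each EP rank holds a full [global_tokens, hidden_dim] tensor of partial
# expert outputs. The collective sums them elementwise and hands rank r the
# row slice [r * local_tokens : (r + 1) * local_tokens].
ep_size, global_tokens, hidden_dim = 4, 16, 8
contribs = [torch.randn(global_tokens, hidden_dim) for _ in range(ep_size)]

reduced = torch.stack(contribs).sum(dim=0)  # elementwise reduce
local_tokens = global_tokens // ep_size     # 16 // 4 = 4
rank = 1                                    # example rank
local_out = reduced[rank * local_tokens:(rank + 1) * local_tokens]
```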