core.transformer.moe.token_dispatcher_inference
Inference token dispatchers for MoE expert parallelism.
Two dispatchers are provided, selected via config.inference_moe_token_dispatcher_type:
- NCCLAllGatherDispatcher ('nccl', default): Standard NCCL AllGather/ReduceScatter. All EP ranks must contribute the same token count per step; decode-only CUDA graphs are forced automatically.
- NVLSAllGatherVDispatcher ('nvls'): Variable-count NVLS AllGather-V/ReduceScatter-V via multimem kernels. Supports different token counts per rank per step. Requires Hopper+ GPUs with NVLink and symmetric memory. Opt-in.
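The mapping from the config value to a dispatcher can be sketched as follows. This is a hypothetical helper, not the module's selection code. `ConfigSketch` and `select_dispatcher` are illustrative names, and the hardware check simply restates the documented requirements for 'nvls'. It does not model the decode-only CUDA graphs that 'nccl' forces.

```python
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    # Hypothetical stand-in for TransformerConfig; only this field name comes from the docs.
    inference_moe_token_dispatcher_type: str = "nccl"


# Config value -> dispatcher class name in this module.
DISPATCHER_BY_TYPE = {
    "nccl": "NCCLAllGatherDispatcher",
    "nvls": "NVLSAllGatherVDispatcher",
}


def select_dispatcher(config, compute_capability, has_nvlink, has_symmetric_memory):
    """Hypothetical helper: map the config value to a dispatcher class name."""
    kind = config.inference_moe_token_dispatcher_type
    if kind not in DISPATCHER_BY_TYPE:
        raise ValueError(f"unknown dispatcher type: {kind!r}")
    if kind == "nvls":
        # Illustrative check of the documented requirements: Hopper or newer
        # (compute capability 9.0+), NVLink, and symmetric memory.
        if compute_capability < (9, 0) or not (has_nvlink and has_symmetric_memory):
            raise RuntimeError("'nvls' needs Hopper+ GPUs with NVLink and symmetric memory")
    return DISPATCHER_BY_TYPE[kind]


print(select_dispatcher(ConfigSketch(), (8, 0), False, False))  # NCCLAllGatherDispatcher
```

Because 'nvls' is opt-in, the default config always resolves to the NCCL dispatcher, whatever the hardware.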
InferenceAllGatherDispatcherBase is a minimal base class used solely for isinstance checks and to hold _valid_tokens_tensor, the shared value that mcore_fused_moe reads to restrict kernel work to the valid-token prefix. Each dispatcher defines its own update_metadata method, which is invoked from the first instance's token_dispatch so that the per-step metadata kernel is captured inside the CUDA graph.
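The following sketch shows the shared-state pattern described above using plain Python objects. It has three parts: a class-level slot allocated once, an abstract update_metadata that each subclass implements, and a refresh that only the first instance performs. Every name in it is hypothetical. A one-element list stands in for the device scalar, and `is_first_instance` is an assumed flag. It is not meant as a description of the runs_metadata_sync parameter, whose semantics this page does not spell out.

```python
from abc import ABC, abstractmethod


class DispatcherBaseSketch(ABC):
    # Class-level slot shared by all subclasses, standing in for
    # _valid_tokens_tensor; a one-element list stands in for a device scalar.
    valid_tokens_buf = None
    refreshes = 0  # counts metadata refreshes, for illustration only

    def __init__(self, is_first_instance=False):
        # Hypothetical flag: only the first MoE layer's dispatcher refreshes metadata.
        self.is_first_instance = is_first_instance

    @classmethod
    def allocate_valid_tokens(cls):
        # Allocate once, outside any graph capture; later steps only write in place.
        DispatcherBaseSketch.valid_tokens_buf = [0]

    @abstractmethod
    def update_metadata(self, local_tokens):
        """Per-step refresh; each concrete dispatcher defines its own."""

    def token_dispatch(self, tokens):
        if self.is_first_instance:
            self.update_metadata(len(tokens))
        return tokens  # communication omitted in this sketch


class CountingDispatcherSketch(DispatcherBaseSketch):
    def update_metadata(self, local_tokens):
        # The in-place write keeps the buffer's identity stable across steps,
        # analogous to writing to a fixed address during CUDA graph replay.
        DispatcherBaseSketch.valid_tokens_buf[0] = local_tokens
        DispatcherBaseSketch.refreshes += 1


DispatcherBaseSketch.allocate_valid_tokens()
buf = DispatcherBaseSketch.valid_tokens_buf
layers = [CountingDispatcherSketch(is_first_instance=(i == 0)) for i in range(4)]
for layer in layers:  # one decode step through four MoE layers
    layer.token_dispatch(["tok"] * 7)
```

After one step through four layers, the metadata has been refreshed exactly once. The shared buffer is still the same object, and every layer reads the updated value.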
Module Contents
Classes
Class | Description
InferenceAllGatherDispatcherBase | Minimal base for inference AllGather token dispatchers.
NCCLAllGatherDispatcher | AllGather token dispatcher for inference using NCCL.
NVLSAllGatherVDispatcher | Variable-count AllGather-V / ReduceScatter-V dispatcher for inference CUDA graphs.
API
- class core.transformer.moe.token_dispatcher_inference.InferenceAllGatherDispatcherBase(
    *args,
    runs_metadata_sync: bool = True,
    **kwargs,
  )
Bases: megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher
Minimal base for inference AllGather token dispatchers.
Exists for isinstance checks and to expose _valid_tokens_tensor, the single class-level value that mcore_fused_moe reads (via experts.py) to restrict kernel work to the valid-token prefix. Each concrete subclass owns its own metadata and defines update_metadata independently.
Initialization
Initialize the AllGather-based token dispatcher.
- Parameters:
num_local_experts (int) – Number of local experts.
local_expert_indices (List[int]) – Indices of local experts.
config (TransformerConfig) – Configuration for the MoE layer.
pg_collection (ProcessGroupCollection, optional) – Process groups for MoE operations.
- _valid_tokens_tensor: Optional[torch.Tensor] = None
- _host_valid_tokens_estimate: Optional[int] = None
- classmethod _valid_tokens() -> torch.Tensor
- classmethod _get_host_valid_tokens_estimate() -> Optional[int]
- classmethod allocate_valid_tokens_tensor() -> None
Allocate the per-step valid-tokens scalar shared across all dispatcher subclasses.
Called at model init from the dynamic context so that the buffer has a valid, stable device address. Must run outside CUDA graph capture so that this stable address is available during replay.
- abstractmethod update_metadata(local_tokens: int) -> None
Per-step metadata refresh, fired from the first instance's token_dispatch.
It runs only once per step, but must still be idempotent within a step and safe to capture into a CUDA graph on the decode path.
- class core.transformer.moe.token_dispatcher_inference.NCCLAllGatherDispatcher(
    num_local_experts: int,
    local_expert_indices: List[int],
    config: megatron.core.transformer.transformer_config.TransformerConfig,
    pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
    runs_metadata_sync: bool = True,
  )
Bases: core.transformer.moe.token_dispatcher_inference.InferenceAllGatherDispatcherBase
AllGather token dispatcher for inference using NCCL.
Two modes, selected by _use_allgather_v (set from the context each step):
CG path (_use_allgather_v=False): all EP ranks contribute the same token count, guaranteed by decode-only CUDA graphs. Uses standard AllGather/ReduceScatter.
Non-CG path (_use_allgather_v=True): ranks may have different token counts (prefill). Each rank pads its tensors to max_tokens, runs a standard AllGather, and then compacts the result by stripping the per-rank padding. Combine is the reverse: expand the compact output to the padded layout, run ReduceScatter, and truncate to the local token count.
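The non-CG path can be modeled on a single host with plain Python lists. This is a sketch, not the module's code. Each token row is a scalar, `None` marks padding, and both function names are hypothetical. The combine step assumes that every rank holds a partial output for every gathered token, and that ReduceScatter sums these partial outputs.

```python
def pad_allgather_compact(per_rank_tokens):
    """Dispatch sketch: per_rank_tokens[r] is rank r's list of token rows."""
    counts = [len(t) for t in per_rank_tokens]  # the all-gathered per-rank counts
    max_tokens = max(counts)
    # 1) Each rank pads to max_tokens so that a fixed-size AllGather is legal.
    padded = [t + [None] * (max_tokens - len(t)) for t in per_rank_tokens]
    # 2) AllGather: every rank receives the concatenation of all padded blocks.
    gathered = [row for block in padded for row in block]
    # 3) Compact: keep only the first counts[r] rows of each rank's block.
    compact = []
    for r, n in enumerate(counts):
        compact.extend(gathered[r * max_tokens : r * max_tokens + n])
    return compact, counts, max_tokens


def expand_reducescatter_truncate(per_rank_outputs, counts, max_tokens):
    """Combine sketch: per_rank_outputs[r] is rank r's compact partial output."""
    def expand(compact):
        # Re-insert zero padding so each rank's block is max_tokens rows long.
        padded, pos = [], 0
        for n in counts:
            padded.extend(compact[pos : pos + n])
            padded.extend([0.0] * (max_tokens - n))
            pos += n
        return padded

    expanded = [expand(out) for out in per_rank_outputs]
    # ReduceScatter: sum elementwise over ranks; rank r keeps block r.
    summed = [sum(vals) for vals in zip(*expanded)]
    blocks = [summed[r * max_tokens : (r + 1) * max_tokens] for r in range(len(counts))]
    # Truncate each block to that rank's real token count.
    return [block[:n] for block, n in zip(blocks, counts)]


compact, counts, max_tokens = pad_allgather_compact([[1, 2, 3], [4], [5, 6]])
print(compact)  # [1, 2, 3, 4, 5, 6]
```

The padding costs (max_tokens - counts[r]) wasted rows per rank in each collective. The trade-off is that a fixed-size AllGather/ReduceScatter can be used when counts differ.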
Initialization
Initialize the AllGather-based token dispatcher.
- Parameters:
num_local_experts (int) – Number of local experts.
local_expert_indices (List[int]) – Indices of local experts.
config (TransformerConfig) – Configuration for the MoE layer.
pg_collection (ProcessGroupCollection, optional) – Process groups for MoE operations.
- _use_allgather_v: bool = False
- _local_tokens_per_rank: Optional[List[int]] = None
- classmethod allocate_buffers() -> None
Allocate the per-step valid-tokens tensor read by mcore_fused_moe.
Called once at model init from the dynamic context. Must run outside any CUDA graph capture so update_metadata can write to a stable address during replay without triggering allocations inside the graph.
- update_metadata(local_tokens: int) -> None
Per-step metadata update; invoked from the first instance's token_dispatch.
CG path (_use_allgather_v=False): ranks have equal token counts by construction, so only _valid_tokens_tensor is refreshed, with a single .fill that is safe to capture.
Non-CG path (_use_allgather_v=True): ranks may differ, so the per-rank counts are all-gathered and synced to the host via .tolist() for the pad/compact logic in token_dispatch and token_combine. This path never runs under graph capture.
- token_dispatch(hidden_states, probs)
Gather hidden_states, probs, and routing_map from all EP ranks.
CG path: standard AllGather (equal token counts guaranteed). Non-CG path: pad to max_tokens, AllGather, compact (strip per-rank padding).
- Parameters:
hidden_states – [local_tokens, hidden_dim] local input.
probs – [local_tokens, topk] local routing probabilities.
- Returns:
(hidden_states, probs) gathered to [total_tokens, *] shape. Also updates self.routing_map to [total_tokens, topk].
- dispatch_postprocess(hidden_states, probs)
Pass-through: mcore_fused_moe operates directly on the gathered tensors.
- combine_preprocess(expert_output)
Pass-through: unpermute is handled inside mcore_fused_moe.
- token_combine(hidden_states)
Scatter-reduce expert outputs back to each EP rank.
CG path: standard ReduceScatter (equal token counts guaranteed). Non-CG path: expand compact output to padded layout, ReduceScatter, truncate.
- Parameters:
hidden_states – [total_tokens, hidden_dim] expert outputs.
- Returns:
[local_tokens, hidden_dim] bf16 local token outputs.
- class core.transformer.moe.token_dispatcher_inference.NVLSAllGatherVDispatcher(
    num_local_experts: int,
    local_expert_indices: List[int],
    config: megatron.core.transformer.transformer_config.TransformerConfig,
    pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
    runs_metadata_sync: bool = True,
  )
Bases: core.transformer.moe.token_dispatcher_inference.InferenceAllGatherDispatcherBase
Variable-count AllGather-V / ReduceScatter-V dispatcher for inference CUDA graphs.
Replaces the fixed AllGather/ReduceScatter of NCCLAllGatherDispatcher with variable-count NVLS collectives so ranks can hold different token counts per step. All metadata lives on-device; no host sync is needed between steps.
Requires Hopper+ GPUs with NVLink and symmetric memory.
Initialization
Initialize the AllGather-based token dispatcher.
- Parameters:
num_local_experts (int) – Number of local experts.
local_expert_indices (List[int]) – Indices of local experts.
config (TransformerConfig) – Configuration for the MoE layer.
pg_collection (ProcessGroupCollection, optional) – Process groups for MoE operations.
- _step_metadata: Optional[torch.Tensor] = None
- _per_rank_worst_case_token_count: int = 2048
- _symm_agv_routing: Optional[dict] = None
- _symm_agv_probs: Optional[dict] = None
- _symm_rsv: Optional[dict] = None
- classmethod _get_rsv_tensor() -> Optional[torch.Tensor]
Return the RSV symmetric buffer tensor so mcore_fused_moe can write unpermute output directly into it, avoiding a copy before RSV.
- classmethod _rank_token_offset() -> torch.Tensor
- classmethod _ep_max_tokens() -> torch.Tensor
- classmethod _delete_buffers()
- classmethod allocate_buffers(
    per_rank_worst_case_token_count: int,
    topk: int,
    hidden_size: int,
    ep_group: torch.distributed.ProcessGroup,
  )
Allocate all symmetric buffers and initialize class-level metadata.
Called once at model init. Allocates fixed-size AGV and RSV symmetric memory buffers so dispatch/combine can proceed without any allocation on the hot path.
- Parameters:
per_rank_worst_case_token_count – Max tokens this rank can contribute, computed by the context as round_up_tokens(max_tokens) // tp_size.
topk – MoE router top-k value.
hidden_size – Model hidden dimension.
ep_group – Expert parallel process group.
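A rough sizing model helps explain why these buffers are allocated once at init rather than per step. The model rests on three assumptions, none of which this page states: every buffer holds ep_size × per_rank_worst_case_token_count rows (taken here to be the "global_max" capacity); the routing buffer is int64 [rows, topk]; the probs buffer is fp32 [rows, topk]; and the RSV buffer is fp32 [rows, hidden_size]. The dtypes follow the token_dispatch and token_combine docstrings. `symmetric_buffer_bytes` is a hypothetical helper.

```python
# Bytes per element for the dtypes named in this module's docstrings.
BYTES_PER_ELEMENT = {"int64": 8, "fp32": 4, "bf16": 2}


def symmetric_buffer_bytes(per_rank_worst_case_token_count, topk, hidden_size, ep_size):
    """Hypothetical estimate of fixed buffer sizes, assuming capacity = ep_size * worst case."""
    rows = ep_size * per_rank_worst_case_token_count  # assumed global_max capacity
    return {
        "agv_routing": rows * topk * BYTES_PER_ELEMENT["int64"],  # [rows, topk] int64
        "agv_probs": rows * topk * BYTES_PER_ELEMENT["fp32"],     # [rows, topk] fp32
        "rsv": rows * hidden_size * BYTES_PER_ELEMENT["fp32"],    # [rows, hidden] fp32
    }


sizes = symmetric_buffer_bytes(2048, topk=8, hidden_size=4096, ep_size=8)
print(sizes["rsv"] // 2**20)  # 256 (MiB)
```

Under these assumptions, the RSV buffer dominates. It scales with hidden_size, while the routing and probs buffers scale only with topk.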
- update_metadata(local_tokens: int) -> None
Per-step metadata update; invoked from the first instance's token_dispatch.
Fires the fused NVLS allgather+reduce that publishes [valid_tokens, rank_token_offset, ep_max_tokens] into _step_metadata. For FlashInfer, it then pre-masks the routing buffer with -1 so that the GEMM ignores rows beyond valid_tokens; the subsequent AllGather-V overwrites rows [0, valid_tokens) in place.
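Below is a host-side model of the three metadata values and of the pre-mask-then-overwrite step. The formulas are inferred from the names rather than stated on this page: valid_tokens as the sum of per-rank counts, rank_token_offset as the exclusive prefix sum, and ep_max_tokens as the maximum. The real work happens on-device in a fused NVLS kernel. The sketch uses top-k = 1, so each row is a single expert index, and both function names are hypothetical.

```python
def step_metadata(counts, rank):
    """Hypothetical model of [valid_tokens, rank_token_offset, ep_max_tokens]."""
    valid_tokens = sum(counts)              # rows holding real tokens after AGV
    rank_token_offset = sum(counts[:rank])  # exclusive prefix sum: start of this rank's rows
    ep_max_tokens = max(counts)             # largest per-rank count this step
    return [valid_tokens, rank_token_offset, ep_max_tokens]


def premask_then_agv(per_rank_routing, capacity):
    """Fill a fixed routing buffer with -1, then write each rank's rows at its offset."""
    counts = [len(rows) for rows in per_rank_routing]
    buf = [-1] * capacity  # pre-mask: rows at or past valid_tokens stay -1
    for rank, rows in enumerate(per_rank_routing):
        offset = step_metadata(counts, rank)[1]
        buf[offset : offset + len(rows)] = rows  # AGV overwrites [0, valid_tokens)
    return buf


print(premask_then_agv([[3, 1], [0], [2, 2, 5]], capacity=8))  # [3, 1, 0, 2, 2, 5, -1, -1]
```

Pre-masking the whole fixed-size buffer lets a CUDA-graph-captured GEMM always run over the same shape, with -1 entries marking rows to skip.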
- dispatch_preprocess(hidden_states, routing_map, probs)
Store routing map and local token count; no inter-rank communication.
If shared_expert_overlap is enabled (set_shared_experts has been called) AND _external_shared_expert_launch is False, the entire shared-expert forward is launched on SharedExpertMLP.stream so that it runs concurrently with AGV dispatch, the expert GEMMs, and RSV combine.
When _external_shared_expert_launch is True (latent-MoE inference path), the layer launches the shared expert before its fc1_latent_proj on the full hidden_states; the dispatcher does not launch it here.
- token_dispatch(hidden_states, probs)
AllGather-V: gather hidden_states, probs, and routing_map from all EP ranks.
- Parameters:
hidden_states – [local_tokens, hidden_size] bf16 local input.
probs – [local_tokens, topk] fp32 local routing probabilities.
- Returns:
(hidden_states, probs) gathered to [global_max, *] shape. Also updates self.routing_map to [global_max, topk] int64.
- dispatch_postprocess(hidden_states, probs)
Pass-through: mcore_fused_moe operates directly on the gathered tensors.
- combine_preprocess(expert_output)
Pass-through: unpermute is handled inside mcore_fused_moe.
- token_combine(hidden_states)
ReduceScatter-V: sum expert outputs across EP ranks, scatter to local tokens.
- Parameters:
hidden_states – [global_max, hidden_size] expert outputs (fp32 when written directly to the RSV buffer, bf16 otherwise).
- Returns:
[local_tokens, hidden_size] bf16 local token outputs.
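ReduceScatter-V can be modeled on the host as follows, under two assumptions. First, each rank holds a partial output over the gathered layout. Second, rank r's tokens occupy one contiguous slice that starts at the exclusive prefix sum of the earlier ranks' counts. Both are inferred from the metadata names, not stated here. Unlike the padded NCCL path, no rank's block is padded to a common maximum. `reduce_scatter_v` is a hypothetical name, and each row is a scalar for brevity.

```python
def reduce_scatter_v(per_rank_outputs, counts):
    """per_rank_outputs[r]: rank r's partial output rows over the gathered layout."""
    offsets = [sum(counts[:r]) for r in range(len(counts))]  # exclusive prefix sum
    valid = sum(counts)
    # Reduce: elementwise sum over ranks, restricted to the valid prefix.
    summed = [sum(vals) for vals in zip(*(out[:valid] for out in per_rank_outputs))]
    # Scatter-V: rank r receives its own contiguous slice of counts[r] rows.
    return [summed[o : o + n] for o, n in zip(offsets, counts)]


# Three ranks, a buffer capacity of 8 rows, and per-rank counts of 2, 1 and 3.
outputs = [[(r + 1) * i for i in range(8)] for r in range(3)]
print(reduce_scatter_v(outputs, [2, 1, 3]))  # [[0, 6], [12], [18, 24, 30]]
```

Rows beyond the valid prefix are never read, so the unused tail of the fixed-size buffer costs nothing in the reduction.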
- combine_postprocess(hidden_states)
Restore original input shape (e.g. [S/TP, B, H] from [S*B/TP, H]).
If shared_expert_overlap is enabled AND _external_shared_expert_launch is False, join SharedExpertMLP.stream and add the shared-expert output produced concurrently during dispatch+combine.
When _external_shared_expert_launch is True (latent-MoE inference path), the join+add happens in the layer’s postprocess after fc2_latent_proj, so we only restore the shape here.