core.transformer.moe.token_dispatcher_inference#

Inference token dispatchers for MoE expert parallelism.

Two dispatchers are provided, selected via config.inference_moe_token_dispatcher_type:

NCCLAllGatherDispatcher (‘nccl’, default): Standard NCCL AllGather/ReduceScatter. All EP ranks must contribute the same token count per step; decode-only CUDA graphs are forced automatically.

NVLSAllGatherVDispatcher (‘nvls’, opt-in): Variable-count NVLS AllGather-V/ReduceScatter-V via multimem kernels. Supports different token counts per rank per step. Requires Hopper+ GPUs with NVLink and symmetric memory.

InferenceAllGatherDispatcherBase is a minimal base used solely for isinstance checks and to hold _valid_tokens_tensor — the shared interface that mcore_fused_moe reads to gate kernel work to the valid token prefix. Each dispatcher defines its own update_metadata method, invoked from the first instance’s token_dispatch so the per-step metadata kernel is captured inside the CUDA graph.
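As a rough illustration of how the two documented type strings relate to the two classes, the sketch below maps a config value to a dispatcher class name. `InferenceConfigStub` and `resolve_dispatcher_name` are hypothetical stand-ins written for this example and are not part of Megatron-Core; only the field name, the values ‘nccl’ and ‘nvls’, and the class names come from this page.

```python
from dataclasses import dataclass


# Hypothetical stand-in for the real config object. Only the field name
# and its documented default ("nccl") are taken from this page.
@dataclass
class InferenceConfigStub:
    inference_moe_token_dispatcher_type: str = "nccl"


# Documented type strings mapped to the documented class names.
_DISPATCHER_NAMES = {
    "nccl": "NCCLAllGatherDispatcher",
    "nvls": "NVLSAllGatherVDispatcher",
}


def resolve_dispatcher_name(config: InferenceConfigStub) -> str:
    """Return the dispatcher class name for the configured type string."""
    kind = config.inference_moe_token_dispatcher_type
    if kind not in _DISPATCHER_NAMES:
        raise ValueError(f"unknown inference MoE dispatcher type: {kind!r}")
    return _DISPATCHER_NAMES[kind]


default_name = resolve_dispatcher_name(InferenceConfigStub())
nvls_name = resolve_dispatcher_name(InferenceConfigStub("nvls"))
```

How the library itself performs this lookup is not shown on this page, so treat the mapping as a reading aid rather than the actual selection code.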

Module Contents#

Classes#

InferenceAllGatherDispatcherBase

Minimal base for inference AllGather token dispatchers.

NCCLAllGatherDispatcher

AllGather token dispatcher for inference using NCCL.

NVLSAllGatherVDispatcher

Variable-count AllGather-V / ReduceScatter-V dispatcher for inference CUDA graphs.

API#

class core.transformer.moe.token_dispatcher_inference.InferenceAllGatherDispatcherBase(
*args,
runs_metadata_sync: bool = True,
**kwargs,
)#

Bases: megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher

Minimal base for inference AllGather token dispatchers.

Exists for isinstance checks and to expose _valid_tokens_tensor — the single class-level value that mcore_fused_moe reads (via experts.py) to gate kernel work to the valid token prefix. Each concrete subclass owns its own metadata and defines update_metadata independently.

Initialization

Initialize the AllGather-based token dispatcher.

Parameters:
  • num_local_experts (int) – Number of local experts.

  • local_expert_indices (List[int]) – Indices of local experts.

  • config (TransformerConfig) – Configuration for the MoE layer.

  • pg_collection (ProcessGroupCollection, optional) – Process groups for MoE operations.

_valid_tokens_tensor: Optional[torch.Tensor]#

None

_host_valid_tokens_estimate: Optional[int]#

None

classmethod _valid_tokens() torch.Tensor#

classmethod _get_host_valid_tokens_estimate() Optional[int]#

classmethod allocate_valid_tokens_tensor() None#

Allocate the per-step valid-tokens scalar shared across all dispatcher subclasses.

Called at model init from the dynamic context to ensure the buffer receives a valid pointer. Must run outside CUDA graph capture so the stable address is available during replay.
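The "allocate once, then only write in place" requirement can be illustrated with a CPU-only analogy. The sketch below is not the library's code: it uses a one-element `array.array` in place of a GPU `torch.Tensor`, and every class and method name in it is hypothetical. It shows only that a class-level buffer allocated once on the base class is shared by all subclasses and keeps the same address across in-place writes, which is the property CUDA graph replay depends on.

```python
from array import array
from typing import Optional


class DispatcherBaseSketch:
    # Class-level buffer shared by every subclass and instance
    # (a stand-in for _valid_tokens_tensor).
    _valid_tokens_buf: Optional[array] = None

    @classmethod
    def allocate_valid_tokens_buf(cls) -> None:
        # Assign on the base class explicitly so subclasses do not
        # shadow it with their own copy.
        DispatcherBaseSketch._valid_tokens_buf = array("q", [0])

    @classmethod
    def set_valid_tokens(cls, n: int) -> None:
        buf = DispatcherBaseSketch._valid_tokens_buf
        if buf is None:
            raise RuntimeError("allocate_valid_tokens_buf() must run first")
        buf[0] = n  # in-place write: the buffer is not reallocated


class NcclSketch(DispatcherBaseSketch):
    pass


class NvlsSketch(DispatcherBaseSketch):
    pass


DispatcherBaseSketch.allocate_valid_tokens_buf()
address_before, _ = DispatcherBaseSketch._valid_tokens_buf.buffer_info()
NcclSketch.set_valid_tokens(7)                  # write through one subclass
address_after, _ = NvlsSketch._valid_tokens_buf.buffer_info()
shared_value = NvlsSketch._valid_tokens_buf[0]  # read through the other
```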

abstractmethod update_metadata(local_tokens: int) None#

Per-step metadata refresh fired from the first instance’s token_dispatch.

Must be idempotent across a step (only called once) and safe to capture into a CUDA graph on the decode path.

class core.transformer.moe.token_dispatcher_inference.NCCLAllGatherDispatcher(
num_local_experts: int,
local_expert_indices: List[int],
config: megatron.core.transformer.transformer_config.TransformerConfig,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
runs_metadata_sync: bool = True,
)#

Bases: core.transformer.moe.token_dispatcher_inference.InferenceAllGatherDispatcherBase

AllGather token dispatcher for inference using NCCL.

Two modes, selected by _use_allgather_v (set from the context each step):

CG path (_use_allgather_v=False): all EP ranks contribute the same token count, guaranteed by decode-only CUDA graphs. Standard AllGather/ReduceScatter.

Non-CG path (_use_allgather_v=True): ranks may have different token counts (prefill). Each rank pads its tensors to max_tokens, runs a standard AllGather, then compacts by stripping per-rank padding. Combine is the reverse: expand compact output to padded layout, ReduceScatter, truncate to local token count.
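The pad, gather, compact sequence of the non-CG path can be simulated on plain Python lists. This is a minimal single-process sketch, not the dispatcher's implementation: each integer stands for one token row, the "AllGather" is a list concatenation, and `pad_gather_compact` is a hypothetical helper name.

```python
from typing import List


def pad_gather_compact(per_rank_tokens: List[List[int]], pad_value: int = -1) -> List[int]:
    """Simulate pad -> AllGather -> compact for ranks with unequal token counts."""
    counts = [len(tokens) for tokens in per_rank_tokens]
    max_tokens = max(counts)

    # 1. Each rank pads its local tokens to max_tokens rows.
    padded = [tokens + [pad_value] * (max_tokens - len(tokens)) for tokens in per_rank_tokens]

    # 2. A standard AllGather concatenates the equal-sized per-rank chunks.
    gathered = [row for chunk in padded for row in chunk]

    # 3. Compact: keep only the first counts[rank] rows of each rank's chunk.
    compact: List[int] = []
    for rank, count in enumerate(counts):
        start = rank * max_tokens
        compact.extend(gathered[start:start + count])
    return compact


# Three ranks holding 3, 1, and 2 tokens respectively.
result = pad_gather_compact([[10, 11, 12], [20], [30, 31]])
```

In this example `result` holds the six real rows in rank order, with the three padding rows removed.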

Initialization

Initialize the AllGather-based token dispatcher.

Parameters:
  • num_local_experts (int) – Number of local experts.

  • local_expert_indices (List[int]) – Indices of local experts.

  • config (TransformerConfig) – Configuration for the MoE layer.

  • pg_collection (ProcessGroupCollection, optional) – Process groups for MoE operations.

_use_allgather_v: bool#

False

_local_tokens_per_rank: Optional[List[int]]#

None

classmethod allocate_buffers() None#

Allocate the per-step valid-tokens tensor read by mcore_fused_moe.

Called once at model init from the dynamic context. Must run outside any CUDA graph capture so update_metadata can write to a stable address during replay without triggering allocations inside the graph.

update_metadata(local_tokens: int) None#

Per-step metadata update; invoked from the first instance’s token_dispatch.

CG path (_use_allgather_v=False): ranks have equal counts by construction, so we only refresh _valid_tokens_tensor with a single .fill, which is safe to capture.

Non-CG path (_use_allgather_v=True): ranks may differ, so we all-gather the per-rank counts and host-sync via .tolist() for the pad/compact logic below. This path never runs under graph capture.

token_dispatch(hidden_states, probs)#

Gather hidden_states, probs, and routing_map from all EP ranks.

CG path: standard AllGather (equal token counts guaranteed).

Non-CG path: pad to max_tokens, AllGather, compact (strip per-rank padding).

Parameters:
  • hidden_states – [local_tokens, hidden_dim] local input.

  • probs – [local_tokens, topk] local routing probabilities.

Returns:

(hidden_states, probs) gathered to [total_tokens, *] shape. Also updates self.routing_map to [total_tokens, topk].

dispatch_postprocess(hidden_states, probs)#

Pass-through: mcore_fused_moe operates directly on the gathered tensors.

combine_preprocess(expert_output)#

Pass-through: unpermute is handled inside mcore_fused_moe.

token_combine(hidden_states)#

Scatter-reduce expert outputs back to each EP rank.

CG path: standard ReduceScatter (equal token counts guaranteed).

Non-CG path: expand compact output to padded layout, ReduceScatter, truncate.

Parameters:
  • hidden_states – [total_tokens, hidden_dim] expert outputs.

Returns:

[local_tokens, hidden_dim] bf16 local token outputs.
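The non-CG combine (expand, ReduceScatter, truncate) can be sketched the same way on plain lists. This single-process simulation makes several assumptions: each float stands for one output row, each rank holds a partial result for every gathered token, and padding rows are filled with zeros so they do not affect the sum. The helper names are hypothetical.

```python
from typing import List


def expand_to_padded(compact: List[float], counts: List[int]) -> List[float]:
    """Re-insert zero padding so every rank's chunk has max(counts) rows."""
    max_tokens = max(counts)
    padded: List[float] = []
    cursor = 0
    for count in counts:
        padded.extend(compact[cursor:cursor + count])
        padded.extend([0.0] * (max_tokens - count))
        cursor += count
    return padded


def combine(partials: List[List[float]], counts: List[int]) -> List[List[float]]:
    """Simulate expand -> ReduceScatter -> truncate; returns one list per rank."""
    max_tokens = max(counts)
    padded = [expand_to_padded(partial, counts) for partial in partials]

    # ReduceScatter, part 1: element-wise sum across ranks.
    summed = [sum(column) for column in zip(*padded)]

    # ReduceScatter, part 2: rank r receives chunk r, truncated to counts[r].
    return [
        summed[rank * max_tokens: rank * max_tokens + counts[rank]]
        for rank in range(len(counts))
    ]


counts = [2, 1]                                   # rank 0 owns 2 tokens, rank 1 owns 1
partials = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]  # each rank's partial expert output
outputs = combine(partials, counts)
```

Here rank 0 ends up with the summed rows for its two tokens and rank 1 with the summed row for its single token.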

class core.transformer.moe.token_dispatcher_inference.NVLSAllGatherVDispatcher(
num_local_experts: int,
local_expert_indices: List[int],
config: megatron.core.transformer.transformer_config.TransformerConfig,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
runs_metadata_sync: bool = True,
)#

Bases: core.transformer.moe.token_dispatcher_inference.InferenceAllGatherDispatcherBase

Variable-count AllGather-V / ReduceScatter-V dispatcher for inference CUDA graphs.

Replaces the fixed AllGather/ReduceScatter of NCCLAllGatherDispatcher with variable-count NVLS collectives so ranks can hold different token counts per step. All metadata lives on-device; no host sync is needed between steps.

Requires Hopper+ GPUs with NVLink and symmetric memory.

Initialization

Initialize the AllGather-based token dispatcher.

Parameters:
  • num_local_experts (int) – Number of local experts.

  • local_expert_indices (List[int]) – Indices of local experts.

  • config (TransformerConfig) – Configuration for the MoE layer.

  • pg_collection (ProcessGroupCollection, optional) – Process groups for MoE operations.

_step_metadata: Optional[torch.Tensor]#

None

_per_rank_worst_case_token_count: int#

2048

_symm_agv_hidden: Optional[dict]#

None

_symm_agv_routing: Optional[dict]#

None

_symm_agv_probs: Optional[dict]#

None

_symm_rsv: Optional[dict]#

None

classmethod _get_rsv_tensor() Optional[torch.Tensor]#

Return the RSV symmetric buffer tensor so mcore_fused_moe can write unpermute output directly into it, avoiding a copy before RSV.

classmethod _rank_token_offset() torch.Tensor#

classmethod _ep_max_tokens() torch.Tensor#

classmethod _delete_buffers()#

classmethod allocate_buffers(
per_rank_worst_case_token_count: int,
topk: int,
hidden_size: int,
ep_group: torch.distributed.ProcessGroup,
) None#

Allocate all symmetric buffers and initialize class-level metadata.

Called once at model init. Allocates fixed-size AGV and RSV symmetric memory buffers so dispatch/combine can proceed without any allocation on the hot path.

Parameters:
  • per_rank_worst_case_token_count – Max tokens this rank can contribute, computed by the context as round_up_tokens(max_tokens) // tp_size.

  • topk – MoE router top-k value.

  • hidden_size – Model hidden dimension.

  • ep_group – Expert parallel process group.
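The formula for per_rank_worst_case_token_count can be worked through with concrete numbers. This page does not say what multiple round_up_tokens rounds to, so the sketch below takes the multiple as an explicit argument and uses 64 purely as an assumed value; `round_up_tokens_sketch` is a hypothetical stand-in, not the context's actual function.

```python
def round_up_tokens_sketch(num_tokens: int, multiple: int) -> int:
    """Round num_tokens up to the next multiple (hypothetical stand-in)."""
    return ((num_tokens + multiple - 1) // multiple) * multiple


def per_rank_worst_case(max_tokens: int, tp_size: int, multiple: int) -> int:
    # Mirrors the documented expression: round_up_tokens(max_tokens) // tp_size
    return round_up_tokens_sketch(max_tokens, multiple) // tp_size


# Assumed values for illustration only: 1000 max tokens, TP size 2,
# and a rounding multiple of 64.
rounded = round_up_tokens_sketch(1000, 64)
worst_case = per_rank_worst_case(1000, tp_size=2, multiple=64)
```

With these assumed inputs, 1000 rounds up to 1024 and each rank's worst case is 512 tokens.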

update_metadata(local_tokens: int) None#

Per-step metadata update; invoked from the first instance’s token_dispatch.

Fires the fused NVLS allgather+reduce to publish [valid_tokens, rank_token_offset, ep_max_tokens] into _step_metadata, then (for FlashInfer) pre-masks the routing buffer with -1 so rows beyond valid_tokens are ignored by the GEMM; the AGV below overwrites [0, valid_tokens) in-place.
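One plausible reading of the three metadata fields, given per-rank token counts, is sketched below: valid_tokens as the total across ranks, rank_token_offset as the exclusive prefix sum for a given rank, and ep_max_tokens as the per-rank maximum. This page does not define the fields, so that interpretation is an assumption. The sketch also imitates the pre-mask step on a plain list, with one integer per routing row instead of a [topk] row. All function names are hypothetical.

```python
from typing import List, Tuple


def step_metadata(counts: List[int], rank: int) -> Tuple[int, int, int]:
    """Return (valid_tokens, rank_token_offset, ep_max_tokens) under the assumed reading."""
    valid_tokens = sum(counts)               # rows that hold real tokens
    rank_token_offset = sum(counts[:rank])   # where this rank's rows start
    ep_max_tokens = max(counts)              # largest per-rank contribution
    return valid_tokens, rank_token_offset, ep_max_tokens


def premask_then_gather(per_rank_routing: List[List[int]], buffer_rows: int) -> List[int]:
    """Fill the buffer with -1, then let each rank overwrite its own slice."""
    counts = [len(rows) for rows in per_rank_routing]
    buffer = [-1] * buffer_rows              # pre-mask: every row starts as "ignore"
    for rank, rows in enumerate(per_rank_routing):
        _, offset, _ = step_metadata(counts, rank)
        buffer[offset:offset + len(rows)] = rows  # variable-count gather into [0, valid_tokens)
    return buffer


meta_rank2 = step_metadata([3, 1, 2], rank=2)
routing = premask_then_gather([[0, 1, 2], [3], [4, 5]], buffer_rows=8)
```

Rows at or beyond valid_tokens keep the value -1, which is how the page describes rows being ignored by the GEMM on the FlashInfer path.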

dispatch_preprocess(hidden_states, routing_map, probs)#

Store routing map and local token count; no inter-rank communication.

If shared_expert_overlap is enabled (set_shared_experts has been called) AND _external_shared_expert_launch is False, launch the entire shared-expert forward on SharedExpertMLP.stream so it runs concurrently with AGV dispatch, expert GEMMs, and RSV combine.

When _external_shared_expert_launch is True (latent-MoE inference path), the layer launches the shared expert before its fc1_latent_proj on the full hidden_states; the dispatcher does not launch it here.

token_dispatch(hidden_states, probs)#

AllGather-V: gather hidden_states, probs, and routing_map from all EP ranks.

Parameters:
  • hidden_states – [local_tokens, hidden_size] bf16 local input.

  • probs – [local_tokens, topk] fp32 local routing probabilities.

Returns:

(hidden_states, probs) gathered to [global_max, *] shape. Also updates self.routing_map to [global_max, topk] int64.

dispatch_postprocess(hidden_states, probs)#

Pass-through: mcore_fused_moe operates directly on the gathered tensors.

combine_preprocess(expert_output)#

Pass-through: unpermute is handled inside mcore_fused_moe.

token_combine(hidden_states)#

ReduceScatter-V: sum expert outputs across EP ranks, scatter to local tokens.

Parameters:
  • hidden_states – [global_max, hidden_size] expert outputs (fp32 when written directly to the RSV buffer, bf16 otherwise).

Returns:

[local_tokens, hidden_size] bf16 local token outputs.

combine_postprocess(hidden_states)#

Restore original input shape (e.g. [S/TP, B, H] from [S*B/TP, H]).

If shared_expert_overlap is enabled AND _external_shared_expert_launch is False, join SharedExpertMLP.stream and add the shared-expert output produced concurrently during dispatch+combine.

When _external_shared_expert_launch is True (latent-MoE inference path), the join+add happens in the layer’s postprocess after fc2_latent_proj, so we only restore the shape here.