core.inference.moe.metadata
Fused NVLS metadata update kernel for MoE expert parallelism.
Replaces the multi-kernel sequence:

    dist.all_gather_into_tensor(…)        # NCCL
    local_tokens_per_rank.sum()           # kernel
    local_tokens_per_rank[:rank].sum()    # kernel
    local_tokens_per_rank.max()           # kernel
    step_metadata.copy_(…)                # kernel

with a single Triton kernel that:

1. Multicast-stores this rank’s local_tokens to the symmetric memory buffer.
2. Waits at a barrier until all ranks have written.
3. Reads all ranks’ counts and computes sum / prefix-sum / max.
4. Writes the 3-element step_metadata tensor in-place.
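The three values that the fused kernel produces can be checked against a plain host-side reference. The sketch below is not the library's code: `reference_step_metadata` is a hypothetical helper, and the mapping of sum, prefix-sum, and max onto `valid_tokens`, `rank_token_offset`, and `ep_max_tokens` is assumed from the order in which the description lists them.

```python
from typing import List, Sequence


def reference_step_metadata(tokens_per_rank: Sequence[int], rank: int) -> List[int]:
    """Hypothetical reference for [valid_tokens, rank_token_offset, ep_max_tokens]."""
    if not 0 <= rank < len(tokens_per_rank):
        raise ValueError("rank out of range")
    valid_tokens = sum(tokens_per_rank)              # total tokens across all ranks
    rank_token_offset = sum(tokens_per_rank[:rank])  # tokens held by lower-numbered ranks
    ep_max_tokens = max(tokens_per_rank)             # largest per-rank count
    return [valid_tokens, rank_token_offset, ep_max_tokens]


counts = [5, 0, 7, 3]  # illustrative per-rank token counts
print(reference_step_metadata(counts, rank=2))  # [15, 5, 7]
```

A reference like this may be useful in a unit test that compares the fused kernel's output with the multi-kernel sequence it replaces.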
Module Contents
Functions
_fused_metadata_kernel | Fused allgather + reduce kernel for MoE step metadata.
fused_metadata_update | Fused NVLS allgather + reduce for MoE step metadata.
API
- core.inference.moe.metadata._fused_metadata_kernel(
      local_tokens,
      local_buf_ptr,
      multicast_ptr,
      signal_pad_ptrs,
      step_metadata_ptr,
      RANK: triton.language.constexpr,
      WORLD_SIZE: triton.language.constexpr,
  )
Fused allgather + reduce kernel for MoE step metadata.
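Runs as a single CTA. Writes this rank’s local_tokens to the symmetric buffer via a multicast store, waits at the barrier, then reads all ranks’ values from the local buffer and computes [valid_tokens, rank_token_offset, ep_max_tokens]: the sum, the prefix sum over ranks below RANK, and the max of the per-rank counts.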
- Parameters:
local_tokens – scalar int32, this rank’s token count.
local_buf_ptr – pointer to the local symmetric memory buffer (for reads).
multicast_ptr – multicast pointer to the symmetric memory buffer (for writes).
signal_pad_ptrs – signal pads for barrier synchronization.
step_metadata_ptr – pointer to the 3-element int32 output tensor.
RANK – this rank’s index (constexpr).
WORLD_SIZE – total number of ranks (constexpr).
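The write / barrier / read ordering described above can be illustrated on the CPU. The following is only a simulation under stated assumptions: Python threads stand in for ranks, a list of per-rank lists stands in for the symmetric memory buffers, writing one slot in every list stands in for the multicast store, and `threading.Barrier` stands in for the signal-pad barrier. None of these names come from the library.

```python
import threading
from typing import List

WORLD_SIZE = 4
local_tokens_per_rank = [5, 0, 7, 3]  # illustrative inputs, one per simulated rank

# One buffer per simulated rank, standing in for that rank's symmetric memory.
buffers: List[List[int]] = [[0] * WORLD_SIZE for _ in range(WORLD_SIZE)]
results: List[List[int]] = [[0, 0, 0] for _ in range(WORLD_SIZE)]
barrier = threading.Barrier(WORLD_SIZE)


def simulated_kernel(rank: int) -> None:
    # 1. "Multicast store": write this rank's count into slot `rank` of every buffer.
    for buf in buffers:
        buf[rank] = local_tokens_per_rank[rank]
    # 2. Barrier: no rank reads until every rank has written.
    barrier.wait()
    # 3. Read all counts from this rank's own (local) buffer.
    counts = buffers[rank]
    # 4. Write [valid_tokens, rank_token_offset, ep_max_tokens] in place.
    results[rank][:] = [sum(counts), sum(counts[:rank]), max(counts)]


def run_simulation() -> List[List[int]]:
    threads = [
        threading.Thread(target=simulated_kernel, args=(r,))
        for r in range(WORLD_SIZE)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


for rank, meta in enumerate(run_simulation()):
    print(rank, meta)
```

Because every rank's slot is written into every buffer before the barrier releases, each rank can compute its metadata from local reads alone. Without the barrier, a rank could read a slot that a peer has not yet written.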
- core.inference.moe.metadata.fused_metadata_update(
      local_tokens: int,
      local_buf: torch.Tensor,
      symm_mem_hdl: torch._C._distributed_c10d._SymmetricMemory,
      step_metadata: torch.Tensor,
  )
Fused NVLS allgather + reduce for MoE step metadata.
- Parameters:
local_tokens – number of tokens on this rank this step.
local_buf – the local symmetric memory buffer tensor ([WORLD_SIZE] int32). Used for reads after the barrier.
symm_mem_hdl – symmetric memory handle for the metadata buffer. Provides the multicast pointer for writes and signal pads for barrier.
step_metadata – [3] int32 CUDA tensor to write [valid_tokens, rank_token_offset, ep_max_tokens] into.
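One way a caller might interpret the three output values is sketched below. This is an assumption about usage, not documented behaviour: `StepMetadata` and `rank_token_range` are hypothetical names, and the sketch assumes that tokens from all ranks are laid out contiguously in rank order, so that `rank_token_offset` marks where this rank's tokens begin.

```python
from typing import NamedTuple, Sequence


class StepMetadata(NamedTuple):
    """Hypothetical named view of the 3-element step_metadata contents."""
    valid_tokens: int
    rank_token_offset: int
    ep_max_tokens: int


def rank_token_range(step_metadata: Sequence[int], local_tokens: int) -> range:
    """Index range of this rank's tokens, assuming a rank-ordered global layout."""
    meta = StepMetadata(*step_metadata)
    start = meta.rank_token_offset
    stop = start + local_tokens
    if stop > meta.valid_tokens:
        raise ValueError("local range exceeds the global token count")
    return range(start, stop)


# Values copied to the host for illustration: rank 2 of counts [5, 0, 7, 3].
print(rank_token_range([15, 5, 7], local_tokens=7))  # range(5, 12)
```

Under the same assumption, `ep_max_tokens` would give an upper bound on the per-rank token count, which could be used to size buffers uniformly across ranks.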