core.inference.moe.metadata

Fused NVLS metadata update kernel for MoE expert parallelism.

Replaces the multi-kernel sequence:

    dist.all_gather_into_tensor(…)        # NCCL
    local_tokens_per_rank.sum()           # kernel
    local_tokens_per_rank[:rank].sum()    # kernel
    local_tokens_per_rank.max()           # kernel
    step_metadata.copy_(…)                # kernel

with a single Triton kernel that:

  1. Multicast-stores this rank’s local_tokens to the symmetric memory buffer.
  2. Barrier (all ranks have written).
  3. Reads all ranks’ counts, computes sum / prefix-sum / max.
  4. Writes the 3-element step_metadata tensor in-place.
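As a rough illustration of what the replaced sequence computes, the following pure-Python sketch derives the three metadata values from an already-gathered list of per-rank token counts. The function name `reference_step_metadata` is hypothetical, and the mapping of sum, prefix sum, and max onto `[valid_tokens, rank_token_offset, ep_max_tokens]` in that order is an assumption based on the order in which this page lists them.

```python
def reference_step_metadata(local_tokens_per_rank, rank):
    """Unfused reference (hypothetical helper, not part of the module).

    Takes the per-rank token counts that an all-gather would produce and
    returns [valid_tokens, rank_token_offset, ep_max_tokens].
    """
    if not 0 <= rank < len(local_tokens_per_rank):
        raise ValueError("rank must index into local_tokens_per_rank")

    # Assumed mapping: total across ranks -> valid_tokens.
    valid_tokens = sum(local_tokens_per_rank)
    # Exclusive prefix sum up to this rank -> rank_token_offset.
    rank_token_offset = sum(local_tokens_per_rank[:rank])
    # Largest per-rank count -> ep_max_tokens.
    ep_max_tokens = max(local_tokens_per_rank)

    return [valid_tokens, rank_token_offset, ep_max_tokens]


if __name__ == "__main__":
    counts = [5, 0, 7, 3]  # example token counts for 4 ranks
    print(reference_step_metadata(counts, rank=2))  # [15, 5, 7]
```

A reference like this can be useful for checking the fused kernel's output on small inputs, since each value is computed independently and is easy to verify by hand.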

Module Contents

Functions

  • _fused_metadata_kernel – Fused allgather + reduce kernel for MoE step metadata.

  • fused_metadata_update – Fused NVLS allgather + reduce for MoE step metadata.

API

core.inference.moe.metadata._fused_metadata_kernel(
local_tokens,
local_buf_ptr,
multicast_ptr,
signal_pad_ptrs,
step_metadata_ptr,
RANK: triton.language.constexpr,
WORLD_SIZE: triton.language.constexpr,
)

Fused allgather + reduce kernel for MoE step metadata.

Runs as a single CTA (one thread block). The kernel writes this rank’s local_tokens to the symmetric buffer via a multicast store, waits at a barrier until every rank has written, then reads all ranks’ values from the local buffer and computes [valid_tokens, rank_token_offset, ep_max_tokens].

Parameters:
  • local_tokens – scalar int32, this rank’s token count.

  • local_buf_ptr – pointer to the local symmetric memory buffer (for reads).

  • multicast_ptr – multicast pointer to the symmetric memory buffer (for writes).

  • signal_pad_ptrs – signal pads for barrier synchronization.

  • step_metadata_ptr – pointer to the 3-element int32 output tensor.

  • RANK – this rank’s index (constexpr).

  • WORLD_SIZE – total number of ranks (constexpr).
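The write / barrier / read flow described above can be mimicked on the CPU with threads. This is only an analogy under stated assumptions: each thread stands in for one rank, a list per rank stands in for that rank's symmetric memory buffer, writing into every list stands in for the multicast store, and `threading.Barrier` stands in for the signal-pad barrier. The name `simulate_fused_update` is hypothetical, and the order of the three output values is assumed to follow the order listed in the kernel description.

```python
import threading


def simulate_fused_update(tokens_per_rank):
    """CPU analogy of the fused kernel (hypothetical helper, not the real kernel)."""
    world_size = len(tokens_per_rank)
    # One buffer per rank, standing in for each rank's symmetric memory buffer.
    buffers = [[0] * world_size for _ in range(world_size)]
    # Stands in for the signal-pad barrier.
    barrier = threading.Barrier(world_size)
    results = [None] * world_size

    def run_rank(rank):
        # 1. "Multicast store": write this rank's count into slot `rank`
        #    of every rank's buffer.
        for buf in buffers:
            buf[rank] = tokens_per_rank[rank]
        # 2. Wait until all ranks have written.
        barrier.wait()
        # 3. Read all counts from this rank's own buffer and reduce.
        local = buffers[rank]
        valid_tokens = sum(local)
        rank_token_offset = sum(local[:rank])
        ep_max_tokens = max(local)
        # 4. Write the 3-element metadata for this rank.
        results[rank] = [valid_tokens, rank_token_offset, ep_max_tokens]

    threads = [threading.Thread(target=run_rank, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


if __name__ == "__main__":
    for rank, meta in enumerate(simulate_fused_update([5, 0, 7, 3])):
        print(rank, meta)
```

The barrier is what makes the reads safe: without it, a rank could read its local buffer before a peer's store has landed and compute metadata from a stale slot.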

core.inference.moe.metadata.fused_metadata_update(
local_tokens: int,
local_buf: torch.Tensor,
symm_mem_hdl: torch._C._distributed_c10d._SymmetricMemory,
step_metadata: torch.Tensor,
) → None

Fused NVLS allgather + reduce for MoE step metadata.

Parameters:
  • local_tokens – number of tokens on this rank this step.

  • local_buf – the local symmetric memory buffer tensor ([WORLD_SIZE] int32). Used for reads after the barrier.

  • symm_mem_hdl – symmetric memory handle for the metadata buffer. Provides the multicast pointer for writes and signal pads for barrier.

  • step_metadata – [3] int32 CUDA tensor to write [valid_tokens, rank_token_offset, ep_max_tokens] into.