core.transformer.moe.ops.paged_stash#

Triton kernels for MoE paged stash.

Module Contents#

Functions#

paged_stash_copy_kernel

Stash variable-length MoE activations into a paged buffer (CUDA, or pinned host).

paged_stash_pop_kernel

Restore variable-length MoE activations from a paged buffer (CUDA, or pinned host).

Data#

API#

core.transformer.moe.ops.paged_stash.GLOBAL_BLOCK_SIZE#

1024

core.transformer.moe.ops.paged_stash.__all__#

['GLOBAL_BLOCK_SIZE', 'paged_stash_copy_kernel', 'paged_stash_pop_kernel']

core.transformer.moe.ops.paged_stash.paged_stash_copy_kernel(
src_ptr,
cuda_dst_ptr,
host_dst_ptr,
num_tokens_ptr,
free_list_cuda_ptr,
free_list_host_ptr,
free_list_head_ptr,
free_list_tail_ptr,
free_list_capacity_ptr,
page_record_ptr,
overflow_ptr,
host_spill_global_ptr,
spilled_to_host_ptr,
new_free_list_head_ptr,
PAGE_SIZE: triton.language.constexpr,
HIDDEN_SIZE: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr,
HAS_HOST_BUFFER: triton.language.constexpr,
)#

Stash variable-length MoE activations into a paged buffer (CUDA, or pinned host).

Uses a custom Triton kernel because the token count is known only at runtime and lives on the device. Page allocation from the circular freelist, the page_record metadata update, and the activation copy are fused into one GPU launch, which avoids a host sync and keeps the stash CUDA-graph friendly. Fixed-size pages reduce fragmentation compared with oversized static expert buffers.
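The fragmentation point above can be illustrated with simple arithmetic. The following is a minimal sketch, not taken from the module: the helper names, the per-expert token counts, the static capacity of 1024, and the page size of 64 are all hypothetical values chosen for illustration.

```python
import math


def static_slots(num_experts, capacity_per_expert):
    """Token slots reserved when every expert gets a worst-case static buffer."""
    return num_experts * capacity_per_expert


def paged_slots(token_counts, page_size):
    """Token slots reserved when each expert takes only the pages it needs."""
    return sum(math.ceil(n / page_size) * page_size for n in token_counts)


# Hypothetical per-expert token counts for one step (not measured data).
token_counts = [100, 900, 30, 250]
used = sum(token_counts)

print("static waste:", static_slots(len(token_counts), 1024) - used)
print("paged waste: ", paged_slots(token_counts, 64) - used)
```

Under these assumptions, the unused space per expert in the paged layout is always less than one page, whereas a static buffer wastes everything between the actual token count and the worst-case capacity.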

Per launch (program 0 handles metadata; all programs run the copy):

1. If overflow is already set, restore freelist heads and return.
2. Compute pages needed from num_tokens. Try the CUDA freelist; if full, try the host freelist when available; otherwise set overflow and return.
3. Copy tokens in parallel: resolve page_id per token, record page_ids in page_record, write hidden vectors into the chosen CUDA or host pages.
4. Program 0 writes updated freelist heads for the caller to copy_ back.
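The four steps above can be sketched as a sequential, host-side reference model in plain Python. This is not the Triton kernel: the `FreeList` class, `make_buffer`, and `stash` are hypothetical names, and the model assumes that head and tail are monotonically increasing counters indexed modulo the capacity, and that a request falls through to the host freelist when the CUDA freelist cannot supply all the pages it needs. The real kernel's memory layout and parallel decomposition may differ.

```python
import math


class FreeList:
    """Circular list of free page ids (hypothetical host-side model)."""

    def __init__(self, num_pages):
        self.slots = list(range(num_pages))  # ring storage of page ids
        self.capacity = num_pages
        self.head = 0           # assumed: count of pages handed out so far
        self.tail = num_pages   # assumed: count of pages made available so far

    def available(self):
        return self.tail - self.head

    def take(self, n):
        """Allocate n page ids starting at head, wrapping around the ring."""
        ids = [self.slots[(self.head + i) % self.capacity] for i in range(n)]
        self.head += n
        return ids


def make_buffer(num_pages, page_size):
    """A paged buffer: num_pages pages, each holding page_size token slots."""
    return [[None] * page_size for _ in range(num_pages)]


def stash(tokens, page_size, cuda_fl, cuda_buf,
          host_fl=None, host_buf=None, overflow=False):
    """Return (page_record, spilled_to_host, overflow)."""
    # Step 1: a previously set overflow flag short-circuits the launch.
    if overflow:
        return [], False, True

    # Step 2: pick a freelist that can hold the whole request.
    pages_needed = math.ceil(len(tokens) / page_size)
    if cuda_fl.available() >= pages_needed:
        free_list, buf, spilled = cuda_fl, cuda_buf, False
    elif host_fl is not None and host_fl.available() >= pages_needed:
        free_list, buf, spilled = host_fl, host_buf, True
    else:
        return [], False, True
    page_record = free_list.take(pages_needed)

    # Step 3: copy each token's hidden vector into its page slot.
    for t, vec in enumerate(tokens):
        page_id = page_record[t // page_size]
        buf[page_id][t % page_size] = list(vec)

    # Step 4: in this model the head was already advanced in place by take().
    return page_record, spilled, False
```

A call such as `stash([[1.0], [2.0], [3.0]], 2, FreeList(2), make_buffer(2, 2))` allocates two pages and leaves the second slot of the last page unused, which is the only per-stash waste in a paged layout.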

core.transformer.moe.ops.paged_stash.paged_stash_pop_kernel(
cuda_src_ptr,
host_src_ptr,
dst_ptr,
num_tokens_ptr,
page_record_ptr,
spilled_to_host_ptr,
overflow_ptr,
free_list_cuda_ptr,
free_list_host_ptr,
free_list_tail_ptr,
free_list_capacity_ptr,
new_free_list_tail_ptr,
PAGE_SIZE: triton.language.constexpr,
HIDDEN_SIZE: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr,
)#

Restore variable-length MoE activations from a paged buffer (CUDA, or pinned host).

Inverse of paged_stash_copy_kernel. Uses a custom Triton kernel for the same reasons: the runtime token count and the stash metadata live on the device, so the reload, the page_record lookup, and the freelist recycle must be fused on the GPU without a host sync.

Per launch (program 0 handles metadata; all programs run the copy):

1. If overflow is already set, restore freelist tails and return.
2. Read spilled_to_host from the matching stash: CUDA buffer by default, host buffer when the forward stash spilled to pinned memory.
3. Copy tokens in parallel: look up page_id from page_record, read hidden vectors from the stash pages into dst, return each page_id to the freelist.
4. Program 0 writes updated freelist tails for the caller to copy_ back.
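The restore path above can be sketched the same way, as a sequential plain-Python model rather than the Triton kernel. The function name `pop_stash` and the dictionary layout keyed by `"cuda"` and `"host"` are hypothetical, and the model assumes the tail is a monotonically increasing counter indexed modulo the ring length.

```python
def pop_stash(page_record, num_tokens, page_size, spilled_to_host,
              buffers, free_lists, tails, overflow=False):
    """Return (dst, new_tails).

    buffers, free_lists, and tails are dicts keyed by "cuda" and "host"
    (a hypothetical layout chosen for this sketch).
    """
    # Step 1: with overflow set, nothing was stashed; pass the tails through.
    if overflow:
        return None, dict(tails)

    # Step 2: choose the buffer the matching stash wrote into.
    key = "host" if spilled_to_host else "cuda"
    src, ring = buffers[key], free_lists[key]

    # Step 3a: gather each token's hidden vector through page_record.
    dst = [src[page_record[t // page_size]][t % page_size]
           for t in range(num_tokens)]

    # Step 3b: return every page id to the ring, starting at the old tail.
    for i, page_id in enumerate(page_record):
        ring[(tails[key] + i) % len(ring)] = page_id

    # Step 4: report the advanced tail so the caller can store it.
    new_tails = dict(tails)
    new_tails[key] = tails[key] + len(page_record)
    return dst, new_tails
```

Returning the new tails instead of mutating the input mirrors the description's separate output for updated tails, which the caller copies back after the launch.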