core.transformer.moe.ops.paged_stash
Triton kernels for MoE paged stash.
Module Contents
Functions
paged_stash_copy_kernel | Stash variable-length MoE activations into a paged buffer (CUDA, or pinned host). |
paged_stash_pop_kernel | Restore variable-length MoE activations from a paged buffer (CUDA, or pinned host). |
Data
API
- core.transformer.moe.ops.paged_stash.GLOBAL_BLOCK_SIZE
1024
- core.transformer.moe.ops.paged_stash.__all__
['GLOBAL_BLOCK_SIZE', 'paged_stash_copy_kernel', 'paged_stash_pop_kernel']
- core.transformer.moe.ops.paged_stash.paged_stash_copy_kernel(
    src_ptr,
    cuda_dst_ptr,
    host_dst_ptr,
    num_tokens_ptr,
    free_list_cuda_ptr,
    free_list_host_ptr,
    free_list_head_ptr,
    free_list_tail_ptr,
    free_list_capacity_ptr,
    page_record_ptr,
    overflow_ptr,
    host_spill_global_ptr,
    spilled_to_host_ptr,
    new_free_list_head_ptr,
    PAGE_SIZE: triton.language.constexpr,
    HIDDEN_SIZE: triton.language.constexpr,
    BLOCK_SIZE: triton.language.constexpr,
    HAS_HOST_BUFFER: triton.language.constexpr,
  )
Stash variable-length MoE activations into a paged buffer (CUDA, or pinned host).
Uses a custom Triton kernel because the token count is known only at runtime and lives on the device. Page allocation from the circular freelist, the page_record metadata update, and the activation copy are fused into one GPU launch, which avoids a host sync and keeps the stash CUDA-graph friendly. Fixed-size pages reduce fragmentation compared with oversized static expert buffers.
Per launch (program 0 handles metadata; all programs run the copy):
1. If overflow is already set, restore freelist heads and return.
2. Compute pages needed from num_tokens. Try the CUDA freelist; if full, try the host freelist when available; otherwise set overflow and return.
3. Copy tokens in parallel: resolve page_id per token, record page_ids in page_record, write hidden vectors into the chosen CUDA or host pages.
4. Program 0 writes updated freelist heads for the caller to copy_ back.
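The stash steps can be pictured with a small, sequential Python model. This is only a sketch under stated assumptions, not the Triton kernel: `PagedBuffer`, `stash`, and the `state` dict are hypothetical names, the freelist cursors are assumed to be monotonically increasing counters wrapped with a modulo, and the CUDA and host buffers are assumed to share one page size. The real kernel runs the copy across many programs and works through device pointers.

```python
class PagedBuffer:
    """Hypothetical host-side stand-in for one paged buffer plus its circular freelist."""

    def __init__(self, num_pages, page_size, hidden_size):
        self.page_size = page_size
        # pages[page_id][slot] holds one hidden vector.
        self.pages = [[[0.0] * hidden_size for _ in range(page_size)]
                      for _ in range(num_pages)]
        self.free_list = list(range(num_pages))  # every page starts free
        self.head = 0          # next freelist entry to allocate (assumed monotonic)
        self.tail = num_pages  # next freelist entry to refill (assumed monotonic)

    def available(self):
        return self.tail - self.head

    def take(self, n):
        # Read n page ids starting at head, wrapping around the circular list.
        capacity = len(self.free_list)
        ids = [self.free_list[(self.head + i) % capacity] for i in range(n)]
        self.head += n
        return ids


def stash(tokens, cuda, host, state):
    """Return (page_record, spilled_to_host), or None when overflow is set."""
    if state["overflow"]:  # step 1: an earlier stash already overflowed
        return None
    page_size = cuda.page_size
    pages_needed = (len(tokens) + page_size - 1) // page_size  # step 2: ceil division
    if cuda.available() >= pages_needed:
        target, spilled = cuda, False
    elif host is not None and host.available() >= pages_needed:
        target, spilled = host, True  # spill to the (pinned) host buffer
    else:
        state["overflow"] = True
        return None
    page_record = target.take(pages_needed)
    for t, vec in enumerate(tokens):  # step 3: one iteration per token
        page_id = page_record[t // page_size]
        target.pages[page_id][t % page_size] = list(vec)
    # Step 4: in this model the advanced head already lives on `target`.
    return page_record, spilled
```

Passing `host=None` plays the role of `HAS_HOST_BUFFER` being false in this model: a full CUDA freelist then goes straight to overflow.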
- core.transformer.moe.ops.paged_stash.paged_stash_pop_kernel(
    cuda_src_ptr,
    host_src_ptr,
    dst_ptr,
    num_tokens_ptr,
    page_record_ptr,
    spilled_to_host_ptr,
    overflow_ptr,
    free_list_cuda_ptr,
    free_list_host_ptr,
    free_list_tail_ptr,
    free_list_capacity_ptr,
    new_free_list_tail_ptr,
    PAGE_SIZE: triton.language.constexpr,
    HIDDEN_SIZE: triton.language.constexpr,
    BLOCK_SIZE: triton.language.constexpr,
  )
Restore variable-length MoE activations from a paged buffer (CUDA, or pinned host).
Inverse of paged_stash_copy_kernel. Uses a custom Triton kernel for the same reasons: the runtime token count and the stash metadata live on the device, so the reload, the page_record lookup, and the freelist recycle must be fused on the GPU without a host sync.
Per launch (program 0 handles metadata; all programs run the copy):
1. If overflow is already set, restore freelist tails and return.
2. Read spilled_to_host from the matching stash: CUDA buffer by default, host buffer when the forward stash spilled to pinned memory.
3. Copy tokens in parallel: look up page_id from page_record, read hidden vectors from the stash pages into dst, return each page_id to the freelist.
4. Program 0 writes updated freelist tails for the caller to copy_ back.
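The restore steps can likewise be modelled sequentially in plain Python. This is a hedged sketch, not the kernel itself: `pop_reference` and its arguments are hypothetical, the buffers are nested lists indexed as `pages[page_id][slot]`, and the tails are assumed to be monotonically increasing counters kept per buffer and wrapped with a modulo. The function returns new tails rather than mutating its input, mirroring the idea that the caller copies the updated tails back.

```python
def pop_reference(cuda_pages, host_pages, page_record, num_tokens,
                  spilled_to_host, page_size, free_lists, tails, overflow=False):
    """Return (tokens, new_tails) for one stash; hypothetical reference model.

    free_lists and tails are dicts keyed by "cuda" and "host".
    """
    if overflow:  # step 1: nothing was stashed, so tails stay as they were
        return [], dict(tails)

    # Step 2: choose the source buffer that the forward stash wrote to.
    key = "host" if spilled_to_host else "cuda"
    pages = host_pages if spilled_to_host else cuda_pages

    # Step 3a: per-token reload through the page_record lookup.
    tokens = []
    for t in range(num_tokens):
        page_id = page_record[t // page_size]
        tokens.append(list(pages[page_id][t % page_size]))

    # Step 3b: hand every used page back at the tail of the circular freelist.
    pages_used = (num_tokens + page_size - 1) // page_size
    free_list = free_lists[key]
    capacity = len(free_list)
    for i in range(pages_used):
        free_list[(tails[key] + i) % capacity] = page_record[i]

    # Step 4: report the advanced tail for the caller to store.
    new_tails = dict(tails)
    new_tails[key] = tails[key] + pages_used
    return tokens, new_tails
```

Because pages are recycled by id, the order in which they return to the freelist follows page_record, not the original allocation order.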