core.inference.contexts.routing_metadata#
Module Contents#
Classes#
RoutingMetadata | Manages routing indices metadata for MoE layers during inference.
API#
- class core.inference.contexts.routing_metadata.RoutingMetadata(
- context: megatron.core.inference.contexts.dynamic_context.DynamicInferenceContext,
- moe_router_topk: int,
- )
Manages routing indices metadata for MoE layers during inference.
This class provides static buffers for CUDA graph compatibility when recording routing decisions. It holds a reference to the inference context to automatically determine whether to use static buffers based on CUDA graph state.
- Parameters:
context (DynamicInferenceContext) – The inference context.
moe_router_topk (int) – Number of experts selected per token.
Initialization
- _ensure_buffer_allocated() → None#
Allocate the static buffer if not already allocated.
Gets the actual number of MoE layers from RouterReplay instances.
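The lazy allocation described above can be sketched in isolation. This is a hypothetical, simplified stand-in (plain Python lists instead of tensors; `StaticRoutingBuffer`, `ensure_allocated`, `max_tokens` are illustrative names, not the Megatron-Core API); the point it demonstrates is that repeated calls must be no-ops so the buffer's storage stays at a fixed address, which CUDA graph replay requires.

```python
# Hypothetical sketch of the lazy-allocation pattern behind
# _ensure_buffer_allocated(): allocate once, on first use, sized from the
# actual number of MoE layers; later calls leave the buffer untouched.

class StaticRoutingBuffer:
    def __init__(self, max_tokens, topk):
        self.max_tokens = max_tokens
        self.topk = topk
        self.buffer = None  # allocated lazily

    def ensure_allocated(self, num_moe_layers):
        # Allocate only if not already allocated; keeping the same object
        # alive across calls is what makes the buffer "static".
        if self.buffer is None:
            self.buffer = [
                [[0] * self.topk for _ in range(num_moe_layers)]
                for _ in range(self.max_tokens)
            ]

buf = StaticRoutingBuffer(max_tokens=4, topk=2)
buf.ensure_allocated(num_moe_layers=3)
first = buf.buffer
buf.ensure_allocated(num_moe_layers=3)
assert buf.buffer is first  # second call did not reallocate
```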
- get_routing_indices() → Optional[torch.Tensor]#
Get the recorded routing indices.
Automatically uses the static buffer when CUDA graphs are active, otherwise retrieves from RouterReplay utility.
- Returns:
Tensor of shape [num_tokens, num_moe_layers, topk] or None if not available.
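The dispatch logic described above (static buffer when CUDA graphs are active, RouterReplay otherwise, `None` when nothing is available) can be sketched as follows. All names here are hypothetical stand-ins, not the real signature, which takes no arguments and reads state from the held context.

```python
# Illustrative-only sketch of the branch inside get_routing_indices():
# the static buffer has a stable memory address and is therefore the only
# safe source while a CUDA graph is being replayed.

def get_routing_indices(context, static_buffer, replay_records):
    if context.get("cuda_graphs_active"):
        return static_buffer        # graph-safe: fixed, pre-allocated storage
    return replay_records or None   # dynamic path; None if nothing recorded

static = [[1, 3], [0, 2]]
assert get_routing_indices({"cuda_graphs_active": True}, static, None) is static
assert get_routing_indices({"cuda_graphs_active": False}, static, None) is None
assert get_routing_indices({"cuda_graphs_active": False}, static, [[5, 7]]) == [[5, 7]]
```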
- enable_static_buffer_recording() → None#
Enable recording into the static buffer for CUDA graph compatibility.
This sets up RouterReplay instances to copy routing indices into the pre-allocated static buffer instead of creating new tensors. The buffer is allocated lazily on the first call.
- disable_static_buffer_recording() → None#
Disable static buffer recording, reverting to normal tensor assignment.
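The enable/record/disable lifecycle of the two methods above can be sketched end to end. `ToyRoutingMetadata` below is a stand-in written under assumptions, not the Megatron-Core implementation: it shows only the toggle semantics, i.e. that enabling lazily allocates the buffer and that recording after disable reverts to a no-op on that buffer.

```python
# Hypothetical usage sketch: enable static-buffer recording before graph
# capture/replay, record routing decisions, then disable to return to
# normal tensor assignment.

class ToyRoutingMetadata:
    def __init__(self):
        self.static_recording = False
        self.buffer = None

    def enable_static_buffer_recording(self):
        if self.buffer is None:          # allocated lazily on first enable
            self.buffer = []
        self.static_recording = True

    def record(self, indices):
        if self.static_recording:
            self.buffer.append(indices)  # copy into the static buffer

    def disable_static_buffer_recording(self):
        self.static_recording = False

meta = ToyRoutingMetadata()
meta.enable_static_buffer_recording()
meta.record([0, 2])
meta.disable_static_buffer_recording()
meta.record([1, 3])                      # ignored: recording is disabled
assert meta.buffer == [[0, 2]]
```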