core.inference.contexts.routing_metadata#

Module Contents#

Classes#

RoutingMetadata

Manages routing indices metadata for MoE layers during inference.

API#

class core.inference.contexts.routing_metadata.RoutingMetadata(
context: megatron.core.inference.contexts.dynamic_context.DynamicInferenceContext,
moe_router_topk: int,
)#

Manages routing indices metadata for MoE layers during inference.

This class provides static buffers for CUDA graph compatibility when recording routing decisions. It holds a reference to the inference context to automatically determine whether to use static buffers based on CUDA graph state.

Parameters:
  • context (DynamicInferenceContext) – The inference context.

  • moe_router_topk (int) – Number of experts selected per token.

Initialization

_ensure_buffer_allocated() → None#

Allocate the static buffer if not already allocated.

Determines the actual number of MoE layers from the RouterReplay instances.

get_routing_indices() → Optional[torch.Tensor]#

Get the recorded routing indices.

Automatically reads from the static buffer when CUDA graphs are active; otherwise retrieves the indices from the RouterReplay utility.

Returns:

Tensor of shape [num_tokens, num_moe_layers, topk] or None if not available.
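To make the return shape concrete, here is a toy example indexing a `[num_tokens, num_moe_layers, topk]` structure. Plain Python lists stand in for the tensor, and all values are illustrative, not taken from a real inference run:

```python
# Plain-Python stand-in for the [num_tokens, num_moe_layers, topk] tensor
# returned by get_routing_indices(); the expert ids are made up.
num_tokens, num_moe_layers, topk = 3, 2, 2

# routing[t][l] lists the topk expert ids selected for token t at MoE layer l.
routing = [
    [[0, 2], [1, 3]],  # token 0: layer 0 picked experts 0 and 2; layer 1 picked 1 and 3
    [[1, 0], [2, 2]],  # token 1
    [[3, 1], [0, 1]],  # token 2
]

# Experts selected for token 1 at MoE layer 0:
assert routing[1][0] == [1, 0]
```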

enable_static_buffer_recording() None#

Enable recording into the static buffer for CUDA graph compatibility.

This configures the RouterReplay instances to copy routing indices into the pre-allocated static buffer instead of creating new tensors, which keeps the destination addresses fixed across CUDA graph replays. The buffer is allocated lazily on the first call.

disable_static_buffer_recording() → None#

Disable static buffer recording, reverting to normal tensor assignment.
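The enable/disable pair above can be illustrated with a minimal, self-contained sketch of the static-buffer recording pattern. The `ToyRoutingRecorder` class below is a hypothetical stand-in, not the Megatron implementation: it uses nested Python lists instead of a GPU tensor, and its `record` method is invented for the example. It shows the two behaviors the docstrings describe: copying into a lazily allocated, pre-allocated buffer when static recording is enabled, and normal assignment otherwise.

```python
# Toy stand-in for the static-buffer recording pattern described above.
# Names and buffer layout are illustrative only; a real implementation
# would use one persistent GPU tensor so CUDA graph replay always writes
# to the same memory address.
class ToyRoutingRecorder:
    def __init__(self, max_tokens, num_moe_layers, topk):
        self.shape = (max_tokens, num_moe_layers, topk)
        self._static_buffer = None      # allocated lazily
        self._static_recording = False
        self._dynamic_indices = None

    def _ensure_buffer_allocated(self):
        # Lazy allocation on first use.
        if self._static_buffer is None:
            n_tok, n_layer, topk = self.shape
            self._static_buffer = [
                [[0] * topk for _ in range(n_layer)] for _ in range(n_tok)
            ]

    def enable_static_buffer_recording(self):
        self._ensure_buffer_allocated()
        self._static_recording = True

    def disable_static_buffer_recording(self):
        self._static_recording = False

    def record(self, token, layer, expert_ids):
        if self._static_recording:
            # Copy in place: the buffer's storage never changes.
            self._static_buffer[token][layer][:] = expert_ids
        else:
            # Normal assignment: build a fresh structure each time.
            if self._dynamic_indices is None:
                self._dynamic_indices = {}
            self._dynamic_indices[(token, layer)] = list(expert_ids)

    def get_routing_indices(self):
        # Mirror the documented switch: return the static buffer while
        # static recording is active, otherwise the dynamic record.
        if self._static_recording:
            return self._static_buffer
        return self._dynamic_indices


rec = ToyRoutingRecorder(max_tokens=2, num_moe_layers=1, topk=2)
rec.enable_static_buffer_recording()
rec.record(0, 0, [3, 1])            # written into the static buffer
static = rec.get_routing_indices()
rec.disable_static_buffer_recording()
rec.record(1, 0, [2, 0])            # normal assignment path
```

The key design point the real class shares with this sketch is that enabling static recording only changes *where* indices are written, never the recording call sites, so the same code path is valid inside and outside CUDA graph capture.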