core.inference.config#

Module Contents#

Classes#

MambaInferenceStateConfig

Config for initializing Mamba model inference state tensors.

PrefixCachingEvictionPolicy

Eviction policy for prefix caching blocks.

PrefixCachingCoordinatorPolicy

Routing policy for the DP inference coordinator with prefix caching.

KVCacheManagementMode

Mode for handling large tensors (KV cache, Mamba states) during suspend/resume.

CudaGraphSizingDistribution

How CUDA graph token-count sizes are spaced when generating the captured graphs.

InferenceConfig

Config for inference.

API#

class core.inference.config.MambaInferenceStateConfig#

Config for initializing Mamba model inference state tensors.

Note that we maintain separate metadata for decode, regular prefill, and chunked prefill requests because the Mamba kernels do not yet support mixing these. Once the kernels have been updated we can simplify this code.

layer_type_list: List[str]#

None

A list of strings that indicates the layer type (Mamba / Attention / MLP) for each layer. See megatron/core/models/hybrid/hybrid_layer_allocation.py for the list of symbols.

conv_states_shape: Tuple[int]#

None

Mamba conv states shape per request.

ssm_states_shape: Tuple[int]#

None

Mamba SSM states shape per request.

conv_states_dtype: torch.dtype#

None

The dtype to use for the Mamba conv state tensor. Defaults to the model dtype.

ssm_states_dtype: torch.dtype#

None

The dtype to use for the Mamba SSM state tensor. Defaults to the model dtype.

mamba_chunk_size: int#

128

The chunk size used by the Mamba SSM Triton kernels.

classmethod from_model(
model: megatron.core.transformer.module.MegatronModule,
conv_states_dtype: Optional[torch.dtype] = None,
ssm_states_dtype: Optional[torch.dtype] = None,
) Optional[core.inference.config.MambaInferenceStateConfig]#

Returns Mamba inference state config from the model if it is a hybrid model.

class core.inference.config.PrefixCachingEvictionPolicy#

Bases: str, enum.Enum

Eviction policy for prefix caching blocks.

Only applies when enable_prefix_caching is True.

REF_ZERO#

‘ref_zero’

Deregister blocks immediately when ref_count hits 0. No caching after release.

LRU#

‘lru’

Keep released blocks in hash table. Evict oldest ref=0 blocks when space is needed.
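The difference between the two policies can be sketched with a toy block table. `ToyBlockCache` and its methods are hypothetical names invented for this illustration; they are not part of core.inference.config, and the real allocator is likely more involved.

```python
from collections import OrderedDict


class ToyBlockCache:
    """Hypothetical stand-in for a prefix-cache block table (not the real allocator)."""

    def __init__(self, policy: str, capacity: int):
        self.policy = policy            # "ref_zero" or "lru"
        self.capacity = capacity        # max number of registered blocks
        self.ref_counts = {}            # block hash -> reference count
        self.released = OrderedDict()   # ref=0 blocks kept for reuse (LRU only), oldest first

    def acquire(self, block_hash: str) -> bool:
        """Register or reuse a block; returns True on a cache hit."""
        hit = block_hash in self.ref_counts
        if hit:
            self.released.pop(block_hash, None)  # block is in use again
        else:
            if len(self.ref_counts) >= self.capacity:
                self._evict_one()
            self.ref_counts[block_hash] = 0
        self.ref_counts[block_hash] += 1
        return hit

    def release(self, block_hash: str) -> None:
        self.ref_counts[block_hash] -= 1
        if self.ref_counts[block_hash] == 0:
            if self.policy == "ref_zero":
                del self.ref_counts[block_hash]   # deregister immediately
            else:
                self.released[block_hash] = None  # keep until space is needed

    def _evict_one(self) -> None:
        if not self.released:
            raise RuntimeError("no evictable blocks")
        oldest, _ = self.released.popitem(last=False)
        del self.ref_counts[oldest]
```

Under `"ref_zero"` a released block is forgotten, so a later request with the same prefix misses; under `"lru"` it stays registered and can be hit again until capacity pressure evicts it.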

class core.inference.config.PrefixCachingCoordinatorPolicy#

Bases: str, enum.Enum

Routing policy for the DP inference coordinator with prefix caching.

LONGEST_PREFIX#

‘longest_prefix’

Route to the rank with the longest consecutive prefix match.

FIRST_PREFIX_BLOCK#

‘first_prefix_block’

Route to the rank that has the first block hash cached. O(ranks) check.

ROUND_ROBIN#

‘round_robin’

Route requests to ranks in round-robin order, ignoring prefix affinity.
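As a rough illustration of the three routing policies, the sketch below models each rank's cache as a set of block hashes. The function names, the set-based cache model, the tie-breaking rule, and the `fallback` parameter are all assumptions made for this example, not the coordinator's actual implementation.

```python
from itertools import count
from typing import Dict, Iterator, List, Set


def longest_prefix_rank(request_blocks: List[str], rank_caches: Dict[int, Set[str]]) -> int:
    """Pick the rank whose cache covers the longest consecutive prefix of the request."""

    def match_len(cache: Set[str]) -> int:
        n = 0
        for block_hash in request_blocks:
            if block_hash not in cache:
                break
            n += 1
        return n

    # Ties resolve to the lowest rank id (an arbitrary choice for this sketch).
    return max(sorted(rank_caches), key=lambda r: match_len(rank_caches[r]))


def first_prefix_block_rank(
    request_blocks: List[str], rank_caches: Dict[int, Set[str]], fallback: int = 0
) -> int:
    """One membership check per rank, on the first block hash only."""
    for rank in sorted(rank_caches):
        if request_blocks and request_blocks[0] in rank_caches[rank]:
            return rank
    return fallback  # hypothetical fallback when no rank has the first block


def round_robin_ranks(num_ranks: int) -> Iterator[int]:
    """Yield rank ids cyclically, ignoring cache contents."""
    for i in count():
        yield i % num_ranks
```

The first-block check touches each rank once, which is where the O(ranks) cost comes from, while the longest-prefix check also walks the request's block hashes for every rank.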

class core.inference.config.KVCacheManagementMode#

Bases: str, enum.Enum

Mode for handling large tensors (KV cache, Mamba states) during suspend/resume.

PERSIST#

‘persist’

Do not deallocate and reallocate large tensors; keep them on GPU.

OFFLOAD#

‘offload’

Offload large tensors to CPU during deallocation; onload during allocation.

RECOMPUTE#

‘recompute’

Deallocate large tensors and recompute them from scratch during allocation.
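Because the enums in this module derive from both str and enum.Enum, members can be looked up by value and compared with plain strings. The snippet below redeclares the enum locally so it runs without Megatron installed; in real code it would be imported from the module instead.

```python
import enum


class KVCacheManagementMode(str, enum.Enum):
    """Local mirror of the documented enum, redeclared so the snippet runs standalone."""

    PERSIST = "persist"
    OFFLOAD = "offload"
    RECOMPUTE = "recompute"


# Lookup by value, e.g. when the mode arrives as a string from a CLI flag.
mode = KVCacheManagementMode("offload")

# The str mixin makes members compare equal to their raw string values.
matches_raw_string = mode == "offload"

# Identity checks against members are the usual way to branch on the mode.
keeps_tensors_on_gpu = mode is KVCacheManagementMode.PERSIST
```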

class core.inference.config.CudaGraphSizingDistribution#

Bases: str, enum.Enum

How CUDA graph token-count sizes are spaced when generating the captured graphs.

EXPONENTIAL (default): token counts halve from cuda_graph_max_tokens down to tp_size, giving a log-spaced distribution. This bounds relative padding (~2x worst case) at every scale and yields about log2(max_tokens) graphs in total.

LINEAR: includes size-1 and size-2 graphs where applicable, a linear stride of 8 up to 256, and a sparser linear stride of 16 past 256, e.g. [1, 2, 4] + range(8, 256, 8) + range(256, max+1, 16). Gives higher graph density at the top end than EXPONENTIAL.
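The two spacings can be sketched as plain list builders. The helper names `exponential_sizes` and `linear_sizes` are hypothetical, and the real generator may round or filter sizes differently (for example, to respect tp_size in the linear case).

```python
from typing import List


def exponential_sizes(cuda_graph_max_tokens: int, tp_size: int = 1) -> List[int]:
    """Halve from the max token count down to tp_size (log-spaced)."""
    sizes = []
    size = cuda_graph_max_tokens
    while size >= tp_size and size >= 1:
        sizes.append(size)
        size //= 2
    return sorted(sizes)


def linear_sizes(max_tokens: int) -> List[int]:
    """[1, 2, 4] + range(8, 256, 8) + range(256, max_tokens + 1, 16), clipped to max_tokens."""
    candidates = (
        [1, 2, 4]
        + list(range(8, 256, 8))
        + list(range(256, max_tokens + 1, 16))
    )
    return [s for s in candidates if s <= max_tokens]
```

For a 512-token maximum, this sketch gives 10 exponential sizes against 51 linear ones, which illustrates the trade-off between capture cost and padding.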

EXPONENTIAL#

‘exponential’

LINEAR#

‘linear’

class core.inference.config.InferenceConfig#

Config for inference.

NOTE: Must remain mutually exclusive with the TransformerConfig.

block_size_tokens: int#

256

Size of each KV cache block, in tokens.

buffer_size_gb: int#

20

Buffer size reserved on the GPU for the KV cache. If unified_memory_level >= 1, then CPU memory is additionally utilized, resulting in a total buffer size of buffer_size_gb + paused_buffer_size_gb.

paused_buffer_size_gb: Optional[int]#

None

Portion of buffer reserved for paused requests. Active requests are paused when there are not enough active blocks available to continue generating a request. The total buffer size (active + paused) depends on unified_memory_level (uvm):

  • uvm 0: buffer_size_gb (paused buffer is inclusive)

  • uvm 1: buffer_size_gb + paused_buffer_size_gb
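The two cases above amount to the following arithmetic. `total_buffer_size_gb` is a hypothetical helper written for this illustration, and treating a paused_buffer_size_gb of None as 0 is an assumption rather than documented behavior.

```python
from typing import Optional


def total_buffer_size_gb(
    buffer_size_gb: int,
    paused_buffer_size_gb: Optional[int],
    unified_memory_level: int,
) -> int:
    """Total (active + paused) buffer size for a given unified memory level."""
    paused = paused_buffer_size_gb or 0  # assumption: None means no extra paused buffer
    if unified_memory_level == 0:
        # The paused portion is carved out of buffer_size_gb itself.
        return buffer_size_gb
    # With unified memory, CPU memory adds the paused buffer on top.
    return buffer_size_gb + paused
```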

mamba_inference_state_config: Optional[core.inference.config.MambaInferenceStateConfig]#

None

The Mamba inference state config if the model is a hybrid model.

mamba_memory_ratio: Optional[float]#

None

Percentage of memory buffer to allocate for Mamba states. If not specified, allocates Mamba state tensors for each KV cache block. Only used for hybrid models.

max_requests: Optional[int]#

None

Max number of active requests to use for decode-only forward passes. This is primarily limited by the combination of buffer_size_gb and max_sequence_length.

max_tokens: Optional[int]#

None

Max number of tokens to use for forward passes. This is primarily limited by prefill activation memory usage. (Defaults to 16384).

unified_memory_level: int#

0

Sets unified memory usage within the dynamic inference context. The levels are:

  • 0: no unified memory (default)

  • 1: allocate memory_buffer in unified memory

Eventually, additional levels will be included to control other tensors within the context.

kv_cache_management_mode: core.inference.config.KVCacheManagementMode#

None

Mode used to determine how large tensors are handled by the allocate and deallocate methods. See KVCacheManagementMode for options.

num_cuda_graphs: Optional[int]#

None

Maximum number of CUDA graphs to capture. Graph token counts are spaced from 1 up to a per-graph-type budget:

  • Decode-only graphs are always bounded by max_requests * (num_speculative_tokens + 1).

  • Prefill/mixed graphs share that same bound by default, or extend up to max_tokens when cuda_graph_all_prefills is set.

Due to rounding, the actual number of CUDA graphs may not equal this argument.

cuda_graph_mixed_prefill_count: Optional[int]#

16

The number of mixed prefill graphs to capture if mixed prefill/decode graphs are enabled.

cuda_graph_sizing_distribution: core.inference.config.CudaGraphSizingDistribution#

None

How CUDA graph token counts are spaced. EXPONENTIAL (default) halves from cuda_graph_max_tokens down to tp_size (log-spaced, ~log2(max_tokens) graphs). LINEAR uses piecewise-linear strides: a few small graphs, a fine linear stride through the mid-range, and a larger stride at the top end.

use_cuda_graphs_for_non_decode_steps: bool#

True

Whether to use CUDA graphs for non-decode steps.

cuda_graph_all_prefills: bool#

False

Whether prefill/mixed CUDA graphs should span up to max_tokens. When False (default), prefill/mixed graphs are bounded by the same token limit as decode graphs: max_requests * (num_speculative_tokens + 1). When True, prefill/mixed graph capture is extended to cover the full max_tokens budget.
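The two token bounds described here reduce to a small calculation. `graph_token_bounds` is a hypothetical helper used only to make the arithmetic concrete; the parameter names mirror the config fields.

```python
from typing import Tuple


def graph_token_bounds(
    max_requests: int,
    max_tokens: int,
    num_speculative_tokens: int = 0,
    cuda_graph_all_prefills: bool = False,
) -> Tuple[int, int]:
    """Return (decode_bound, prefill_bound) in tokens."""
    # Each decode request contributes one token plus its speculative tokens.
    decode_bound = max_requests * (num_speculative_tokens + 1)
    # Prefill/mixed graphs reuse the decode bound unless extended to max_tokens.
    prefill_bound = max_tokens if cuda_graph_all_prefills else decode_bound
    return decode_bound, prefill_bound
```

For example, with 256 requests and 2 speculative tokens, both bounds are 768 tokens by default, and the prefill bound rises to max_tokens once cuda_graph_all_prefills is enabled.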

static_kv_memory_pointers: bool#

False

Whether the KV cache (and Mamba states) will reside at the same memory addresses after suspend/resume as before. When True, CUDA graphs that reference these buffers remain valid across suspend/resume cycles and do not need to be recaptured. Requires either UVM or torch_memory_saver when kv_cache_management_mode is not PERSIST.

max_sequence_length: int#

2560

Max possible sequence length (prompt + output) that will occur.

pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection]#

None

A ProcessGroupCollection for distributed execution.

use_flashinfer_fused_rope: Optional[bool]#

False

If True, use flashinfer’s fused RoPE implementation. If None, use flashinfer when it is available.

materialize_only_last_token_logits: bool#

True

Whether to only materialize logits for the last token. This should be set to False if returning log probs.

enable_chunked_prefill: bool#

False

Whether to enable chunked prefill.

num_speculative_tokens: int#

0

The number of speculative tokens to generate for decode steps.

enable_prefix_caching: bool#

False

Whether to enable prefix caching for KV cache block sharing.

prefix_caching_eviction_policy: core.inference.config.PrefixCachingEvictionPolicy#

None

Eviction policy for prefix caching blocks. See PrefixCachingEvictionPolicy for options.

Only applies when enable_prefix_caching is True.

prefix_caching_coordinator_policy: core.inference.config.PrefixCachingCoordinatorPolicy#

None

Routing policy for the DP inference coordinator. See PrefixCachingCoordinatorPolicy for options.

Only applies when enable_prefix_caching is True and using a coordinator.

prefix_caching_routing_alpha: float#

0.5

Weight for prefix-aware scoring: score = alpha * match + (1 - alpha) * normalized_load. Higher alpha favors prefix cache hits; lower alpha favors load balance. Must be in [0, 1]. Only applies when enable_prefix_caching is True and using a coordinator.
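The scoring formula can be written out directly. This sketch assumes that both `match` and `normalized_load` lie in [0, 1] and that `normalized_load` is defined so that a less loaded rank receives a higher value; neither assumption is confirmed by the text above, and `routing_score` is a hypothetical name.

```python
def routing_score(match: float, normalized_load: float, alpha: float = 0.5) -> float:
    """score = alpha * match + (1 - alpha) * normalized_load."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    return alpha * match + (1.0 - alpha) * normalized_load


# alpha = 1.0 ranks purely by prefix match; alpha = 0.0 ranks purely by load.
prefix_only = routing_score(match=1.0, normalized_load=0.0, alpha=1.0)
load_only = routing_score(match=1.0, normalized_load=0.0, alpha=0.0)
```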

prefix_caching_mamba_gb: Optional[float]#

None

GPU memory budget (in GB) for the Mamba state cache used by prefix caching on hybrid models. Each cache slot stores SSM and conv states for all Mamba layers at a single block boundary. When set, Mamba states at KV divergence and last-aligned block boundaries are cached and reused across requests with matching prefixes.

track_paused_request_events: bool#

False

Whether to track paused request events. If True, add_event_pause() is called on requests when they are paused during bookkeeping.

track_generated_token_events: bool#

False

Whether to track per-token events with timestamps for each generated token. When enabled, each generated token creates a GENERATED_TOKEN event with a timestamp, useful for per-token latency analysis.

metrics_writer: Optional[WandbModule]#

None

Wandb module for writing metrics.

logging_step_interval: int#

0

The step interval at which to log inference metrics to wandb. Defaults to 0, which means no logging.

sampling_backend: Literal['torch', 'flashinfer']#

‘torch’

Which sampling kernels to use during inference.

request_metadata_types: Optional[List[Tuple[str, torch.dtype]]]#

None

A list of the per-request metadata types to track. Each entry is a tuple consisting of the string label and the target dtype.

use_synchronous_zmq_collectives: bool#

False

Whether to use synchronous ZMQ collectives for inference. If True, the all_reduce_max operation will be performed synchronously, which can help reduce performance variability for MoEs.

disable_ep_consensus: bool#

False

If True, the engine skips the EP-group consensus all-reduce in run_engine_with_coordinator and decides whether to step based on local state alone. The rank still calls controller.dummy_forward() whenever local_pending == 0, so EP collectives (NCCL all-to-all, etc.) stay in sync. Without this, a peer running a real forward would deadlock waiting on this rank’s all-to-all participation. This trades the CPU cost of the consensus all-reduce for unconditional dummy_forward calls on idle ranks.

ep_consensus_interval: int#

20

How many steps to skip between EP-consensus all-reduces when the engine has pending work. Consensus is always run immediately when there is no global work (to detect new arrivals quickly); this interval only applies to the busy case, where skipping avoids per-step all-reduce overhead. In the worst case, pausing is delayed by this many steps (~10–20 ms per step at typical decode throughput).
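A minimal sketch of the schedule and the worst-case delay it implies. The modulo-based schedule in `should_run_consensus` is an assumption about how the interval might be applied, and both function names are hypothetical; only the interval semantics and the per-step latency range come from the description above.

```python
def should_run_consensus(
    step: int, has_global_work: bool, ep_consensus_interval: int = 20
) -> bool:
    """Decide whether to run the EP-consensus all-reduce on this step."""
    if not has_global_work:
        # Idle case: always run, so new arrivals are detected quickly.
        return True
    # Busy case: run only once every ep_consensus_interval steps (assumed schedule).
    return step % ep_consensus_interval == 0


def worst_case_pause_delay_ms(ep_consensus_interval: int = 20, step_ms: float = 15.0) -> float:
    """Upper bound on how long a pause can be delayed, given a per-step latency."""
    return ep_consensus_interval * step_ms
```

With the default interval of 20 and steps of 10 to 20 ms, the worst-case pause delay works out to roughly 200 to 400 ms.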

verbose: dataclasses.InitVar[bool]#

False

Whether to log detailed context configuration at initialization. This is an InitVar and is not stored as a field on the config.

__post_init__(verbose: bool)#