core.inference.contexts.dynamic_context#

Module Contents#

Classes#

ContextErrorFactory

Factory class for serializing/deserializing context errors.

DynamicInferenceContext

Inference context that is passed to the main model in order to efficiently calculate and store the KV cache during inference.

Functions#

get_mem_size_str

Convert number of bytes to human-readable string.

API#

exception core.inference.contexts.dynamic_context.ContextOverflowError(
request_id: Optional[int],
message: Optional[str] = None,
*,
is_transient: bool = True,
)#

Bases: Exception

Base exception for when a new request does not fit.

Parameters:

is_transient (bool) – Flag marking whether error is transient (i.e., may work if we try again, but fails due to the current context state), or permanent (i.e., request will never fit in this context).

Initialization

Initialize self. See help(type(self)) for accurate signature.

exception core.inference.contexts.dynamic_context.RequestOverflowError(
request_id: Optional[int],
message: Optional[str] = None,
*,
is_transient: bool = True,
)#

Bases: core.inference.contexts.dynamic_context.ContextOverflowError

Adding request would overflow max request count.

Initialization

Initialize self. See help(type(self)) for accurate signature.

exception core.inference.contexts.dynamic_context.TokenOverflowError(
request_id: Optional[int],
message: Optional[str] = None,
*,
is_transient: bool = True,
)#

Bases: core.inference.contexts.dynamic_context.ContextOverflowError

Adding request would overflow max token count.

Initialization

Initialize self. See help(type(self)) for accurate signature.

exception core.inference.contexts.dynamic_context.MaxSequenceLengthOverflowError(
request_id,
message: Optional[str] = None,
)#

Bases: core.inference.contexts.dynamic_context.ContextOverflowError

Adding request would overflow max sequence length.

Initialization

Initialize self. See help(type(self)) for accurate signature.

exception core.inference.contexts.dynamic_context.BlockOverflowError(
request_id: Optional[int],
message: Optional[str] = None,
*,
is_transient: bool = True,
)#

Bases: core.inference.contexts.dynamic_context.ContextOverflowError

Adding request would overflow available memory blocks.

Initialization

Initialize self. See help(type(self)) for accurate signature.

exception core.inference.contexts.dynamic_context.ActiveRequestCountOverflowError(
max_request_count,
active_request_count,
)#

Bases: core.inference.contexts.dynamic_context.ContextOverflowError

Used when initialize_attention_state() is called with num_warmup_requests > max_active_requests.

Initialization

Initialize self. See help(type(self)) for accurate signature.

exception core.inference.contexts.dynamic_context.TensorStateDeallocatedError(
request_id: Optional[int],
message: Optional[str] = None,
*,
is_transient: bool = True,
)#

Bases: core.inference.contexts.dynamic_context.ContextOverflowError

Context’s tensor state is currently deallocated, such as when the engine has been suspended.

Initialization

Initialize self. See help(type(self)) for accurate signature.
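To illustrate the transient vs. permanent distinction, a minimal sketch (assuming is_transient is stored as an attribute on the raised exception, which the class docs imply but do not state explicitly):

```python
from core.inference.contexts.dynamic_context import (
    ContextOverflowError,
    MaxSequenceLengthOverflowError,
    TokenOverflowError,
)

# Transient overflow: the context is currently full, but the same request may
# fit on a later step once other requests finish (is_transient defaults to True).
err = TokenOverflowError(request_id=3, message="token budget exhausted this step")
assert isinstance(err, ContextOverflowError)
assert err.is_transient  # assumed attribute access

# MaxSequenceLengthOverflowError does not expose is_transient in its signature;
# per its description it is a permanent failure (the request will never fit).
err = MaxSequenceLengthOverflowError(3, "prompt exceeds max_sequence_length")
```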

class core.inference.contexts.dynamic_context.ContextErrorFactory#

Factory class for serializing/deserializing context errors.

classmethod serialize(
error: core.inference.contexts.dynamic_context.ContextOverflowError,
) dict#

Serialize error.

Parameters:

error (ContextOverflowError) – Error.

Returns:

(dict) Serialized error data.

classmethod deserialize(
obj: dict,
) core.inference.contexts.dynamic_context.ContextOverflowError#

Deserialize error.

Parameters:

obj (dict) – Serialized error data.

Returns:

(ContextOverflowError) Deserialized error.
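A hedged sketch of the serialize/deserialize round trip (useful, e.g., when shipping errors across process boundaries); the fields of the serialized dict are not specified here:

```python
from core.inference.contexts.dynamic_context import (
    ContextErrorFactory,
    TokenOverflowError,
)

# Serialize an overflow error into a plain dict (e.g., to send over IPC) ...
error = TokenOverflowError(request_id=7, message="prompt too long for this step")
payload = ContextErrorFactory.serialize(error)  # -> dict

# ... and reconstruct an equivalent exception object on the receiving side.
restored = ContextErrorFactory.deserialize(payload)
assert isinstance(restored, TokenOverflowError)
```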

core.inference.contexts.dynamic_context.get_mem_size_str(n_bytes: int) str#

Convert number of bytes to human-readable string.
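A brief usage sketch; only the documented signature is assumed, and the printed string is indicative rather than exact:

```python
from core.inference.contexts.dynamic_context import get_mem_size_str

print(get_mem_size_str(40 * 1024**3))  # e.g., roughly "40.0 GB" (exact format may differ)
```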

class core.inference.contexts.dynamic_context.DynamicInferenceContext(
*,
params_dtype: torch.dtype,
num_layers: int,
kv_channels: int,
num_attention_heads: int,
max_sequence_length: int,
buffer_size_gb: float,
max_requests: int = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
block_size_tokens: int = 256,
tensor_model_parallel_size: Optional[int] = None,
cache_mla_latent: bool = False,
kv_lora_rank: Optional[int] = None,
qk_pos_emb_head_dim: Optional[int] = None,
num_cuda_graphs: Optional[int] = None,
materialize_only_last_token_logits: Optional[bool] = True,
mamba_inference_state_config: Optional[core.inference.contexts.attention_context.mamba_metadata.MambaInferenceStateConfig] = None,
use_cuda_graphs_for_non_decode_steps: bool = True,
use_flashinfer_fused_rope: bool = False,
unified_memory_level: Optional[int] = 1,
cuda_graph_max_tokens: Optional[int] = None,
cuda_graph_mixed_prefill_count: Optional[int] = 16,
metrics_writer: Optional[wandb] = None,
request_metadata_types: Optional[List[Tuple[str, torch.dtype, bool]]] = None,
)#

Bases: core.inference.contexts.base_context.BaseInferenceContext

Inference context that is passed to the main model in order to efficiently calculate and store the KV cache during inference.

The dynamic inference context manages both: 1) in-flight batching, and 2) a memory buffer for the block-level KV cache. For in-flight batching, requests of arbitrary sequence length may be added, paused, or removed from the context at any step. The only constraint is the maximum number of requests or tokens that the context is defined to support. For the block-level KV cache, a memory buffer is allocated up front (of size buffer_size_gb if unified_memory_level == 0, or 2 * buffer_size_gb if unified_memory_level == 1), which is divided into blocks that are dynamically assigned to requests. At any given step, any unassigned blocks equate to unused space.

Parameters:
  • params_dtype (torch.dtype) – Dtype used for KV cache.

  • num_layers (int) – Number of layers on this pipeline parallel rank.

  • kv_channels (int) – Hidden dimension per attention head.

  • num_attention_heads (int) – Number of attention heads.

  • max_sequence_length (int) – Max possible sequence length (prompt + output) that will occur.

  • buffer_size_gb (float) – Buffer size reserved on the GPU for the KV cache. If unified_memory_level >= 1, then CPU memory is additionally utilized, resulting in a total buffer size of 2 * buffer_size_gb. Regardless of total buffer size, the KV cache is conceptually divided into 50% active requests and 50% paused requests.

  • max_requests (int) – Max number of active requests to use for decode-only forward passes. This value is primarily limited by the combination of buffer_size_gb and max_sequence_length.

  • max_tokens (int) – Max number of tokens to use for forward passes. This is primarily limited by prefill activation memory usage. (Defaults to 16384).

  • block_size_tokens (int) – Number of tokens per KV cache block.

  • tensor_model_parallel_size (Optional[int]) – Tensor model parallel size.

  • num_cuda_graphs (Optional[int]) – Maximum number of cuda graphs to capture, where the cuda graph batch sizes range from 1 to max_active_requests (as computed below). Due to rounding, the actual number of cuda graphs may not equal this argument.

  • materialize_only_last_token_logits (Optional[bool]) – Whether to only materialize logits for the last token. This should be set to False if returning log probs.

  • mamba_inference_state_config (Optional[MambaInferenceStateConfig]) – The Mamba inference state config if the model is a hybrid model.

  • use_cuda_graphs_for_non_decode_steps (bool) – If True, use cuda graphs for non-decode engine steps.

  • unified_memory_level (Optional[int]) – Set unified memory usage within the dynamic inference context. The levels are: 0) no unified memory, 1) allocate memory_buffer in unified memory. Eventually, additional levels will be included to control other tensors within the context.

  • use_flashinfer_fused_rope (bool) – If True, use flashinfer’s fused RoPE implementation. If None, defaults to using flashinfer if available.

  • metrics_writer (Optional['WandbModule']) – Wandb module for writing metrics.

  • request_metadata_types (Optional[List[Tuple[str, torch.dtype, bool]]]) – A list of the per-request metadata types to track. Each entry is a tuple consisting of the string label, the target dtype, and whether to store the data on GPU.

Initialization

Parameters:

materialize_only_last_token_logits (bool) – If True, only the last-token logits will be extracted during decode.
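A minimal construction sketch using only the parameters documented above; the concrete values are illustrative assumptions:

```python
import torch
from core.inference.contexts.dynamic_context import DynamicInferenceContext

context = DynamicInferenceContext(
    params_dtype=torch.bfloat16,       # dtype of the KV cache
    num_layers=32,                     # layers on this pipeline-parallel rank
    kv_channels=128,                   # hidden dimension per attention head
    num_attention_heads=32,
    max_sequence_length=4096,          # max prompt + output length
    buffer_size_gb=40.0,               # GPU memory reserved for the KV cache
    block_size_tokens=256,             # tokens per KV cache block
    tensor_model_parallel_size=1,
)
```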

DEFAULT_MAX_TOKENS#

16384

TOKEN_ROUNDER#

64

REQUEST_ROUNDER#

4

allocate_all_tensors(*, is_init: bool) None#

Allocate GPU state.

This method is used for both 1) initial allocation, and 2) resuming the GPU state after a suspend.

Parameters:

is_init (bool) – True if this is being called from __init__().

deallocate_all_tensors()#

Deallocate GPU state.

This method is used for suspending the dynamic engine.
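A sketch of the suspend/resume pattern these two methods support, reusing the context constructed in the earlier sketch (the surrounding engine logic is assumed):

```python
# Suspend: free the context's GPU state, e.g. while the engine is suspended.
context.deallocate_all_tensors()

# Resume: re-allocate the GPU state; is_init=False marks that this is not the
# initial allocation performed during __init__().
context.allocate_all_tensors(is_init=False)
```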

classmethod round_up_tokens(value, tp_size=None)#

Round up to nearest multiple of TOKEN_ROUNDER (above) that is also divisible by tensor model parallel size.

classmethod from_config(
inference_config: megatron.core.inference.model_inference_wrappers.inference_wrapper_config.InferenceWrapperConfig,
model,
max_batch_size: int,
buffer_size_gb: float = 40,
num_cuda_graphs: int = None,
mamba_inference_state_config: Optional[core.inference.contexts.attention_context.mamba_metadata.MambaInferenceStateConfig] = None,
)#

Instantiate a DynamicInferenceContext from a TransformerConfig and an InferenceWrapperConfig.

classmethod round_up_requests(value, tp_size=None)#

Round up to nearest multiple of REQUEST_ROUNDER (above) that is also divisible by tensor model parallel size.
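A worked example of the rounding, assuming TOKEN_ROUNDER = 64 and REQUEST_ROUNDER = 4 (listed above), a tensor model parallel size of 8, and a literal reading of the documented rule (round up to the nearest multiple of the rounder that is also divisible by tp_size):

```python
from core.inference.contexts.dynamic_context import DynamicInferenceContext

# Tokens: multiples of 64 are already divisible by 8, so 1000 rounds up to 1024.
assert DynamicInferenceContext.round_up_tokens(1000, tp_size=8) == 1024

# Requests: the value must be a multiple of 4 that is divisible by 8, so 10 rounds up to 16.
assert DynamicInferenceContext.round_up_requests(10, tp_size=8) == 16
```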

classmethod round_up(value)#

Deprecated in favor of round_up_tokens and round_up_requests.

is_static_batching() bool#

Is static batching? False.

is_decode_only() bool#

Return whether this iteration runs the decode-only implementation.

using_cuda_graph_this_step() bool#

Returns True if cuda graphs are being used for this step.

has_unfinished_requests() bool#

Test if any requests remain.

cu_query_lengths() Tuple[torch.Tensor, int]#

Cumulative query sequence lengths.

cu_kv_lengths() Tuple[torch.Tensor, torch.Tensor, int]#

Cumulative key/value sequence lengths.
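For orientation, the "cu_" prefix refers to the packed-sequence offset convention used by varlen attention kernels; the sketch below illustrates the convention itself, not these methods' internals:

```python
import torch

# Three packed requests with query lengths 5, 1, and 3 tokens this step.
query_lengths = torch.tensor([5, 1, 3])

# Cumulative offsets with a leading zero: request i occupies tokens
# cu[i]:cu[i+1] of the flattened token dimension.
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=query_lengths.dtype), torch.cumsum(query_lengths, dim=0)]
)
# cu_seqlens == tensor([0, 5, 6, 9])
```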

get_active_sequence_lengths() torch.Tensor#

Total sequence length (query + key) for active requests.

get_max_sequence_lengths() torch.Tensor#

Maximum sequence length for active requests.

get_active_request_count()#

Returns the current number of active requests.

append_key_value_cache(
layer_number: int,
key: torch.Tensor,
value: torch.Tensor,
) None#

Append to KV cache.

Parameters:
  • layer_number (int) – Layer number.

  • key (Tensor) – Key tensor.

  • value (Tensor) – Value tensor.

key_value_cache(
layer_number: int,
) Tuple[torch.Tensor, torch.Tensor]#

Read from KV cache.

Parameters:

layer_number (int) – Layer number.

Returns:

(Tuple[Tensor, Tensor]) The key and value pointer tensors that point to blocks within the block-level memory buffer.
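A hedged sketch of how an attention layer might combine these two calls during a forward pass; the helper function and its arguments are hypothetical:

```python
def cached_attention_kv(inference_context, layer_number, key, value):
    """Hypothetical helper: write this step's K/V projections into the
    block-level cache, then read back the pointer tensors for the kernel."""
    inference_context.append_key_value_cache(layer_number, key, value)
    key_cache, value_cache = inference_context.key_value_cache(layer_number)
    return key_cache, value_cache
```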

mamba_states_cache(
layer_number: int,
) Tuple[torch.Tensor, torch.Tensor]#

Returns the Mamba state tensors for the given layer.

apply_fused_qk_rotary_emb(
query: torch.Tensor,
key: torch.Tensor,
cos_sin_emb: torch.Tensor,
config: megatron.core.transformer.TransformerConfig,
) Tuple[torch.Tensor, torch.Tensor]#

Apply rotary embedding to query and key tensors using flashinfer’s fused rope.

Parameters:
  • query (Tensor) – Query tensor.

  • key (Tensor) – Key tensor.

  • cos_sin_emb (Tensor) – Rotary embeddings.

  • config (TransformerConfig) – Transformer config.

Returns:

(Tuple[Tensor, Tensor]) Query and Key tensors after applying rotary embeddings.

apply_rotary_emb_query(
query: torch.Tensor,
query_emb: torch.Tensor,
config: megatron.core.transformer.TransformerConfig,
cu_seqlens_q: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
mscale: float = 1.0,
) torch.Tensor#

Apply rotary embedding to query tensor.

Parameters:
  • query (Tensor) – Query tensor.

  • query_emb (Tensor) – Query rotary embeddings.

  • config (TransformerConfig) – Transformer config.

  • cu_seqlens_q (Tensor) – Cumulative sequence lengths.

  • cp_group (torch.distributed.ProcessGroup) – Process group for context parallel.

Returns:

(Tensor) Query tensor after applying rotary embeddings.

apply_rotary_emb_key(
key: torch.Tensor,
key_emb: torch.Tensor,
config: megatron.core.transformer.TransformerConfig,
cp_group: torch.distributed.ProcessGroup,
mscale: float = 1.0,
) torch.Tensor#

Apply rotary embedding to key tensor.

Parameters:
  • key (Tensor) – Key tensor.

  • key_emb (Tensor) – Key rotary embeddings.

  • config (TransformerConfig) – Transformer config.

  • cp_group (torch.distributed.ProcessGroup) – Process group for context parallel.

Returns:

(Tensor) Key tensor after applying rotary embeddings.

reset_attention_state() None#

Reset state used within attention, after each step.

reset_mamba_state() None#

Reset state used within Mamba layers.

add_dummy_requests_parallel(
requests: Sequence[megatron.core.inference.inference_request.DynamicInferenceRequest],
*,
count_as_prefill: bool = True,
) None#

Fast path to add dummy requests without allocating real KV blocks.

add_dummy_requests_for_cudagraph_capture(
graph_dimensions: megatron.core.inference.batch_dimensions_utils.InferenceBatchDimensions,
) None#

Adds dummy requests to reflect the number of prefill and decode requests in the graph config. These are used during cuda graph capture.

property num_decode_requests: int#

Returns the number of decode requests.

initialize_attention_state(
*,
construct_graph_dimensions: Optional[megatron.core.inference.batch_dimensions_utils.InferenceBatchDimensions] = None,
) None#

Initialize attention state so that every layer can use it.

Parameters:

construct_graph_dimensions (Optional[InferenceBatchDimensions]) – The graph config to use for constructing the cuda graphs.

Returns:

None.

reset() None#

Reset entire context.

This method does:

  • Reset active/paused request/token counts to zero.

  • Reset available blocks to entire memory.

  • Reset other tensors to zeros (unnecessary, just for sanity checking).

This method is useful after cuda graph warmup iterations, where the context’s memory buffer is referenced by the cuda graph system and cannot be deallocated.

current_input_and_position_ids(
*,
num_warmup_tokens: Optional[int] = None,
) Tuple[torch.Tensor, torch.Tensor]#

Flattened input and position IDs for forward pass.

Parameters:

num_warmup_tokens (Optional[int]) – Number of tokens to return for warming up cuda graphs. Must be less than or equal to max_tokens.

Returns:

(Tuple[Tensor, Tensor]) Flattened active input and position IDs.

last_token_logits(logits: torch.Tensor) torch.Tensor#

Last tokens of logits.

Parameters:

logits (Tensor) – Output logits of forward pass.

Returns:

(Tensor) Last token logits.

check_availability(
req: megatron.core.inference.inference_request.DynamicInferenceRequest,
)#

Check if the request can be added to the context.

add_request(
req: megatron.core.inference.inference_request.DynamicInferenceRequest,
chunk_length: Optional[int] = None,
) None#

Add request to context. At this stage, we assume that the request is valid and can be added, as the checks are done in the schedule function.

Parameters:
  • req (DynamicInferenceRequest) – Request to add.

  • chunk_length (Optional[int]) – Length of chunk to add. If None, the request will be fully added.

Returns:

None
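A sketch of a scheduling pass that combines check_availability() with add_request(); it assumes check_availability() raises one of the ContextOverflowError subclasses above when the request does not fit, and the queue handling is hypothetical:

```python
from core.inference.contexts.dynamic_context import ContextOverflowError

def schedule_pending(context, pending_requests):
    """Hypothetical scheduler step: add as many pending requests as fit."""
    still_pending = []
    for req in pending_requests:
        try:
            context.check_availability(req)  # assumed to raise on failure
            context.add_request(req)
        except ContextOverflowError as e:
            if e.is_transient:
                still_pending.append(req)  # retry on a later step
            else:
                raise  # permanent: this request will never fit
    return still_pending
```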

_move_book_keeping_tensors(src_idxs, dst_idxs, next_tokens)#

Move all the relevant bookkeeping tensors from src idxs to dst idxs.

_swap_book_keeping_tensors(src_idxs, dst_idxs, next_tokens)#

Swap all the relevant bookkeeping tensors between src idxs and dst idxs.

get_index_of_chunked_prefill_request() int#

Get the index of the chunked prefill request in the context.

Returns:

(int) Index of the chunked prefill request, or -1 if none exists.

update_requests(
active_requests_mask: torch.Tensor,
new_tokens: torch.Tensor,
) torch.Tensor#

Update context state after calling engine.step().

This method is responsible for:

  • Update prefill requests to decode requests.

  • Persist decode requests as decode requests.

  • Terminate requests by length or termination id.

Note: All bookkeeping tensors (i.e., self.request_*) are laid out contiguously, with a conceptual division between paused requests on the ‘left’ (or, lower indices) and active requests in the ‘middle’ (or, middle indices) and completed requests on the ‘right’ (or, higher indices). The integers paused_request_count and total_request_count are used to track the boundaries between these request groups.

  • 0:paused_request_count -> paused requests

  • paused_request_count:total_request_count -> active requests

  • total_request_count:max_active_requests -> completed requests are moved here.

The reason for maintaining contiguous tensors rather than multiple smaller (e.g., per-group or per-request) tensors is both 1) speed (avoid unnecessary tensor allocations), and 2) compatibility with the Flash Attention kernels, which expect packed contiguous tensors.
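To make the layout concrete, an illustrative sketch of the three index regions (the names below are placeholders, not the context's actual self.request_* tensors):

```python
import torch

max_active_requests = 8
paused_request_count = 2   # boundary between paused and active requests
total_request_count = 5    # boundary between active and completed requests

request_ids = torch.tensor([11, 12, 3, 4, 5, 90, 91, 92])

paused    = request_ids[:paused_request_count]                      # left:   11, 12
active    = request_ids[paused_request_count:total_request_count]   # middle: 3, 4, 5
completed = request_ids[total_request_count:max_active_requests]    # right:  90, 91, 92
```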

The following happens in this code:

  1. The active token mask tells us which requests are still active and which are completed

  2. If there are no paused requests and no active requests, we release all memory and reset.

  3. Concatenate the paused tokens to the active tokens

  4. For the finished requests we release memory blocks and move them to the right

  5. We identify requests that require a new block and add them to the paused requests (i.e., move them left)

  6. We determine how many requests we can resume and resume them

  7. We update the request bookkeeping tensors and set up the tokens for the next iteration

  8. We resume those requests by assigning blocks and updating bookkeeping tensors

  9. We make relevant changes to the token bookkeeping tensors

Parameters:
  • active_requests_mask (Tensor) – 1D Mask tensor marking active requests.

  • new_tokens (Tensor) – Newly sampled tokens, with one token per active request.

Returns:

(Tensor) Newly paused request IDs.

calculate_log_probs(
logits: torch.Tensor,
new_tokens: torch.Tensor,
only_last_token_logits: Optional[bool] = False,
) Tuple[List[List[float]], torch.Tensor]#

Calculate log probs for all active requests and return them.

TODO: @wdykas support top-n log probs.

Parameters:
  • logits (Tensor) – Raw model output logits with shape [1, sequence_length, vocab_size].

  • new_tokens (Tensor) – The newly sampled tokens.

  • only_last_token_logits (bool) – If set, the logits are from only the last token in each request.

Returns:

  • List of lists, where each inner list contains the log probs for a request, in the same order as the active requests (from paused_request_count to total_request_count).

  • log_probs (Tensor): Used to compute top-n log probs later if required.
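A hedged sketch of the underlying math (log-softmax over the vocabulary, then a gather at the sampled token ids); this illustrates the computation only and is not the method's implementation:

```python
import torch

def token_log_probs(logits: torch.Tensor, new_tokens: torch.Tensor) -> torch.Tensor:
    """logits: [num_tokens, vocab_size]; new_tokens: [num_tokens] sampled token ids."""
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    # Select the log prob of each sampled token.
    return log_probs.gather(dim=-1, index=new_tokens.unsqueeze(-1)).squeeze(-1)
```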

get_kvcache_utilization_stats() dict#

Compute KV cache buffer utilization stats for the current step.

Returns a dictionary with counts and percentages for both allocated block usage (overall buffer occupancy) and active usage (blocks referenced by currently active requests this step).

Returns:

{
    'total_blocks': int,
    'allocated_blocks': int,
    'active_unique_blocks': int,
    'allocated_utilization': float,
    'active_utilization': float,
    'active_request_count': int,
    'paused_request_count': int,
    'gtd_block_count': int,
}
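An example of consuming the returned dictionary for per-step logging; the key names follow the structure above, and whether the utilization floats are fractions or percentages is not specified here:

```python
stats = context.get_kvcache_utilization_stats()
print(
    f"KV cache blocks: {stats['allocated_blocks']}/{stats['total_blocks']} allocated, "
    f"{stats['active_unique_blocks']} active-unique; "
    f"requests: {stats['active_request_count']} active, "
    f"{stats['paused_request_count']} paused"
)
```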

maybe_initialize_symmetric_memory()#

Initializes symmetric memory for inference, if not already initialized