core.inference.engines.dynamic_engine#
Module Contents#
Classes#
| Class | Description |
|---|---|
| RequestEntry | Entry in the engine's self.requests dict. |
| DynamicInferenceEngine | The dynamic inference engine. |
Functions#
| Function | Description |
|---|---|
| format_mem_bytes | Convert a byte count to a human-readable string in tb, gb, mb, kb, or bytes. |
API#
- exception core.inference.engines.dynamic_engine.EngineSuspendedError#
Bases: Exception
Engine is currently suspended and not performing steps.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- core.inference.engines.dynamic_engine.format_mem_bytes(mem_bytes)#
Convert a byte count to a human-readable string in tb, gb, mb, kb, or bytes.
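The exact thresholds and formatting are implementation details of the library; as a rough orientation, a minimal sketch of this kind of conversion (written independently of the actual helper) might look like:

```python
def format_mem_bytes_sketch(mem_bytes: int) -> str:
    """Illustrative re-implementation; unit thresholds and formatting are assumptions."""
    for unit, size in (("tb", 1024**4), ("gb", 1024**3), ("mb", 1024**2), ("kb", 1024)):
        if mem_bytes >= size:
            return f"{mem_bytes / size:.1f} {unit}"
    return f"{mem_bytes} bytes"

print(format_mem_bytes_sketch(3 * 1024**3))  # e.g. "3.0 gb"
```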
- class core.inference.engines.dynamic_engine.RequestEntry#
Entry in the engine's self.requests dict.
- record: megatron.core.inference.inference_request.DynamicInferenceRequestRecord#
None
- future: asyncio.Future#
None
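Conceptually, an entry pairs the request record with the future that the caller awaits. A simplified, hypothetical sketch of that pairing (field types reduced for illustration, not the engine's actual code):

```python
import asyncio
from dataclasses import dataclass
from typing import Any

@dataclass
class RequestEntrySketch:
    """Simplified stand-in: pairs a request record with the future the caller awaits."""
    record: Any             # a DynamicInferenceRequestRecord in the real engine
    future: asyncio.Future  # resolved by the engine once the request finishes

async def demo() -> None:
    loop = asyncio.get_running_loop()
    entry = RequestEntrySketch(record={"request_id": 0}, future=loop.create_future())
    entry.future.set_result("finished request")  # the engine would do this on completion
    print(await entry.future)

asyncio.run(demo())
```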
- class core.inference.engines.dynamic_engine.DynamicInferenceEngine(
- controller: megatron.core.inference.text_generation_controllers.text_generation_controller.TextGenerationController,
- context: megatron.core.inference.contexts.dynamic_context.DynamicInferenceContext,
- enable_cuda_graph: Optional[bool] = None,
- random_seed: Optional[int] = None,
- *,
- track_paused_request_events: bool = False,
- enable_chunked_prefill: bool = True,
- inference_logging_step_interval: int = 0,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
Bases: megatron.core.inference.engines.abstract_engine.AbstractEngine
The dynamic inference engine.
This engine allows requests of varying length to be dynamically added and removed in each inference step. In contrast to the static engine, which uses a fixed batch size and sequence length during the forward pass, each request in the dynamic engine can have a different prompt and output length at any given step, and processing is limited only by a maximum number of total tokens across all requests.
- Parameters:
text_generation_controller (TextGenerationController) – A text generation controller that defines how to preprocess prompts, generate outputs, and detokenize the output tokens.
inference_context (DynamicInferenceContext) – Context for managing in-flight batching and a dynamic block-level KV cache (similar to paged attention).
random_seed (Optional[int]) – Use a random seed if you want deterministic results. Defaults to None.
inference_logging_step_interval (int) – The step interval at which to log inference metrics to wandb. Defaults to 0, which means no logging.
Initialization
- reset() None#
Reset the engine by removing all requests and resetting all state.
- create_cuda_graphs(reset_context: bool = True)#
Create cuda graphs.
This method iterates the dynamic context's cuda_graph_request_counts to record and capture CUDA graphs.
- Parameters:
reset_context (bool) – Whether to reset the context after building cuda graphs.
- async start_listening_to_data_parallel_coordinator(
- inference_coordinator_port: int,
- launch_inference_coordinator: bool = True,
- *,
- loop: Optional[asyncio.AbstractEventLoop] = None,
Initializes ZMQ communication to connect the engine with an inference coordinator.
This asynchronous method sets up the distributed communication infrastructure that allows this inference engine to act as a worker under a central InferenceCoordinator. It configures different ZMQ socket patterns based on the rank's role within the distributed topology.
Note that this method must be called on all ranks, as it uses blocking torch broadcasts.
The setup involves two primary roles within each data-parallel group:
MP Coordinator (TP_rank=0, PP_rank=0): This rank connects directly to the central coordinator via a ZMQ DEALER socket. It receives requests and uses a ZMQ PUB (publisher) socket to broadcast them to all other ranks within its model-parallel (MP) group.
MP Workers (all other ranks): These ranks use ZMQ SUB (subscriber) sockets to listen for requests broadcast by their local MP Coordinator.
This architecture uses TCP sockets for both inter-node and intra-node broadcasts within an MP group.
Finally, after setting up the communication channels and ensuring all ranks are synchronized, this method starts the main engine processing loop (self.run_engine) as a background asyncio task.
- Parameters:
inference_coordinator_port (int) – The network port where the central InferenceCoordinator is or will be listening.
launch_inference_coordinator (bool, optional) – If True, the global rank 0 process will spawn and manage the InferenceCoordinator process. Defaults to True.
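The PUB/SUB leg of this topology follows standard pyzmq patterns. A standalone sketch of that leg only (ports, hostnames, and payloads are placeholders, not the engine's actual wire format):

```python
import zmq

def mp_coordinator_sketch(pub_port: int = 5556) -> None:
    """Rank with TP_rank=0, PP_rank=0: rebroadcast received requests to its MP group."""
    ctx = zmq.Context.instance()
    pub = ctx.socket(zmq.PUB)
    pub.bind(f"tcp://*:{pub_port}")
    # Subscribers must be connected before publishing (the PUB/SUB "slow joiner"
    # caveat); the engine synchronizes ranks before sending.
    pub.send(b"request-payload")  # placeholder payload

def mp_worker_sketch(coordinator_host: str, pub_port: int = 5556) -> bytes:
    """All other ranks: subscribe to the local MP coordinator's broadcasts."""
    ctx = zmq.Context.instance()
    sub = ctx.socket(zmq.SUB)
    sub.connect(f"tcp://{coordinator_host}:{pub_port}")
    sub.setsockopt(zmq.SUBSCRIBE, b"")  # subscribe to all messages
    return sub.recv()                   # blocks until the coordinator publishes
```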
- static suspend_resume_ctx(key: str, *, unified_memory_level: int) None#
Context manager for suspending and resuming the engine.
This context manager records the time and memory usage when suspending and resuming the context. TODO(@lmcafee): add argument to optionally return nullcontext, to avoid overhead.
- Parameters:
key (str) – Key that identifies caller (e.g., ‘suspend’ or ‘resume’).
- Returns:
None.
- suspend()#
Suspend engine by deallocating context’s GPU state.
- resume()#
Resume engine by reallocating context’s GPU state.
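A hypothetical pairing of suspend()/resume() with the timing context manager above, assuming the static suspend_resume_ctx can be entered directly in a with statement; the unified_memory_level value here is a placeholder, not a documented default:

```python
from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine

def suspend_then_resume(engine: DynamicInferenceEngine, unified_memory_level: int = 0) -> None:
    """Illustrative usage sketch; key names 'suspend'/'resume' follow the docstring example."""
    with DynamicInferenceEngine.suspend_resume_ctx("suspend", unified_memory_level=unified_memory_level):
        engine.suspend()  # deallocate the context's GPU state
    # ... other GPU work (e.g., training) could run here while the engine is suspended ...
    with DynamicInferenceEngine.suspend_resume_ctx("resume", unified_memory_level=unified_memory_level):
        engine.resume()   # reallocate the context's GPU state
```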
- async _notify_cond_for_new_request()#
Helper function to notify condition variable when a new request is added.
- has_unfinished_requests() bool#
Test if context contains unfinished requests.
- get_request(
- request_id: int,
Get the most recent request from a request record.
- Parameters:
request_id (int) – Request id.
- Returns:
(DynamicInferenceRequest) The most recent request in the record.
- _add_request(
- request: megatron.core.inference.inference_request.DynamicInferenceRequest,
- add_request(
- request_id: int,
- prompt: Union[str, List[int], torch.Tensor],
- sampling_params: Optional[megatron.core.inference.sampling_params.SamplingParams] = None,
Add request to inference context.
- Parameters:
request_id (int) – Unique ID of request.
prompt (Union[str, List[int], torch.Tensor]) – Prompt as a text string, a list of token IDs, or a token tensor.
sampling_params (Optional[SamplingParams]) – Sampling parameters for the request.
- Returns:
Returns an asyncio Future[DynamicInferenceRequest] for the user to wait on.
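Because add_request returns an asyncio future, a typical (hypothetical) usage pattern is to submit the prompt and await the completed request; sampling parameter construction is elided here:

```python
from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine

async def submit_and_wait(engine: DynamicInferenceEngine, request_id: int, prompt: str):
    """Illustrative pattern: add one request, then await its finished DynamicInferenceRequest."""
    future = engine.add_request(request_id, prompt)  # default sampling params assumed
    return await future                              # resolves once the engine completes the request
```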
- post_process_requests(
- request_ids: torch.Tensor,
- finished_request_ids: torch.Tensor,
- step_time: float,
- sample: torch.Tensor,
- log_probs: torch.Tensor,
- top_n_logprobs: Optional[Dict[int, List[Tuple[torch.Tensor, torch.Tensor]]]] = None,
Handles post-processing for requests after a step.
- Parameters:
request_ids (torch.Tensor) – A list of request_ids
finished_request_ids (torch.Tensor) – A list of finished request ids
step_time (float) – The latency of the last step
sample (torch.Tensor) – The newly generated tokens for each request
log_probs (List) – Log probs for each request
top_n_logprobs (Dict) – Top-n log probs for each request. Maps request_idx to a list of (top_n_logprobs, top_n_indices) tuples.
- Returns:
A list of active requests and completed requests as DynamicInferenceRequest objects.
- schedule_waiting_requests()#
Tries to schedule any requests in the waiting pool.
- schedule_non_chunked_prefill()#
Performs the original scheduling logic for non-chunked prefill runs.
- schedule_chunked_prefill()#
This function schedules chunked prefill requests. Invariants:
- There is at most one chunked prefill request in the waiting pool, and it must be the head.
- There is at most one chunked prefill request in the context, and it must be the last active request.
- context.chunked_prefill_request_id is -1 if no chunked prefill request is scheduled; otherwise it is the request id of the chunked prefill request.
- For each request, finished_chunk_token_count is the number of tokens that have already been prefilled; a non-zero value means the request is in the middle of a chunked prefill.
- For each request, remaining_prompt_tokens holds the unprefilled prompt tokens.
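To make the bookkeeping concrete, here is a standalone sketch of how a prompt could be consumed chunk by chunk under these invariants; the chunk size, loop, and printout are illustrative only and do not reflect the engine's actual fields or control flow:

```python
from typing import List

def chunked_prefill_progress(prompt_tokens: List[int], chunk_size: int) -> None:
    """Illustrative accounting: one chunk of the prompt is prefilled per step,
    finished_chunk_token_count grows accordingly, and remaining_prompt_tokens shrinks."""
    finished_chunk_token_count = 0
    remaining_prompt_tokens = list(prompt_tokens)
    while remaining_prompt_tokens:
        chunk = remaining_prompt_tokens[:chunk_size]            # tokens prefilled this step
        remaining_prompt_tokens = remaining_prompt_tokens[chunk_size:]
        finished_chunk_token_count += len(chunk)
        print(f"prefilled {finished_chunk_token_count}/{len(prompt_tokens)} tokens")

chunked_prefill_progress(list(range(10)), chunk_size=4)
```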
- async async_forward() Tuple[Dict, Dict, float, int]#
Uses asyncio for continuous generation. Sleeps when no requests are available, until new requests have been added.
- Returns:
step_result (Optional[Dict]): The result of the step.
context_state (Dict): The state of the context: is_decode_only, total/paused request count, active token count.
step_time (float): How long this step took.
- Return type:
A tuple comprised of
- async async_bookkeep(
- step_result: Optional[Dict],
- context_state: Dict,
- step_time: float,
- step_count: int,
Uses asyncio for continuous bookkeeping.
- Parameters:
step_result (Optional[Dict]) – The result of the step.
context_state (Dict) – is_decode_only, total/paused request count, active token count.
step_time (float) – How long this step took.
step_count (int) – The count of the step.
- Returns:
active_requests (List): Requests that ran in the last step and are still active.
finished_requests (List): Requests that ran in the last step and have now finished.
step_time (float): The step time in seconds.
cuda_graph_request_count (int): The CUDA graph batch size matching this step.
- Return type:
A dictionary containing
- async async_step() Tuple[List[megatron.core.inference.inference_request.DynamicInferenceRequest], List[megatron.core.inference.inference_request.DynamicInferenceRequest], float]#
Wrapper for controller.generate_output_tokens_dynamic_batch(), to match the vLLM API. Uses asyncio for continuous generation, which allows this method to sleep and wake up when new requests are available.
- Returns:
Requests that ran in the last step and are still active.
Requests that ran in the last step and have now finished.
The step time in seconds.
- Return type:
A tuple comprised of
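A hedged sketch of driving the engine with this API, assuming requests have already been added (e.g., via add_request):

```python
from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine

async def drain_engine(engine: DynamicInferenceEngine) -> list:
    """Illustrative driver loop: step the engine until no unfinished requests remain."""
    all_finished = []
    while engine.has_unfinished_requests():
        active_requests, finished_requests, step_time = await engine.async_step()
        all_finished.extend(finished_requests)
    return all_finished
```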
- step_modern() Tuple[List[megatron.core.inference.inference_request.DynamicInferenceRequest], List[megatron.core.inference.inference_request.DynamicInferenceRequest], float]#
Synchronous wrapper for self.async_step.
- step_legacy(
- sampling_params: megatron.core.inference.sampling_params.SamplingParams,
Synchronous wrapper for self.async_step.
- step#
None
- generate(
- prompts: List[str],
- sampling_params: Optional[megatron.core.inference.sampling_params.SamplingParams] = SamplingParams(),
Generates completions for a static list of prompts.
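A minimal usage sketch; the prompts are placeholders and the engine is assumed to be fully initialized elsewhere:

```python
from megatron.core.inference.sampling_params import SamplingParams

def generate_batch(engine, prompts):
    """Illustrative synchronous usage of generate() over a static list of prompts."""
    return engine.generate(
        prompts,
        sampling_params=SamplingParams(),  # library defaults; adjust fields as needed
    )

# results = generate_batch(engine, ["Hello, world!", "The capital of France is"])
```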
- schedule_requests() int#
Drains the ZMQ socket for a batch of requests and adds them to the engine.
This method is a collective and synchronous operation that must be called by all ranks in a Model Parallel (MP) group at the same time. It ensures that all ranks process the exact same batch of incoming requests and control signals.
The synchronization works as follows:
The MP rank 0 drains all pending messages from its subscriber socket in a non-blocking manner.
MP rank 0 then broadcasts the number of messages it received to all other ranks in its MP group using a dedicated publisher socket.
The other MP ranks wait to receive this count, and then receive exactly that many messages from their subscriber sockets.
Once all ranks have the same batch of messages, they are unpacked and processed. New requests are added to the engine’s queue, and control signals (PAUSE, UNPAUSE, SUSPEND, RESUME, STOP) update the engine’s internal state.
Note: This function is synchronous and must be called collectively by all ranks in an MP group. It should not be launched in a separate coroutine, so that all ranks execute it in lockstep before proceeding to the next engine step.
- Returns:
The number of messages that were received and processed in this batch.
- Return type:
int
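The count-then-receive synchronization can be sketched with pyzmq primitives; the socket setup, the broadcast of the count between ranks, and the message format are all simplified assumptions here:

```python
import zmq

def drain_and_count(sub: zmq.Socket) -> list:
    """MP rank 0: drain all pending messages without blocking."""
    messages = []
    while True:
        try:
            messages.append(sub.recv(zmq.NOBLOCK))
        except zmq.Again:
            break
    return messages  # len(messages) is then broadcast to the other MP ranks

def receive_exactly(sub: zmq.Socket, count: int) -> list:
    """Other MP ranks: after learning the count, receive exactly that many messages."""
    return [sub.recv() for _ in range(count)]
```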
- stop()#
Stops the inference engine by terminating the inference coordinator process, if it exists, and destroying the model parallel state. This ensures that any running inference coordinator subprocess is properly terminated and that resources associated with model parallelism are cleaned up.
- async run_engine(*, loop: Optional[asyncio.AbstractEventLoop] = None)#
Continually steps the engine asynchronously.
- async _ep_group_has_work(local_work: int) bool#
Determines whether there are pending requests in the expert parallel group this rank is a part of.
- Parameters:
local_work (int) – The local work count for this rank. This is the sum of active and waiting requests.
- Returns:
True if there is some work in the EP group, False otherwise.
- Return type:
bool
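A common way to decide this collectively is to all-reduce the local work counts over the expert-parallel process group; the following is a hedged sketch of that technique, not the engine's actual implementation:

```python
import torch
import torch.distributed as dist

def ep_group_has_work_sketch(local_work: int, ep_group: dist.ProcessGroup) -> bool:
    """Illustrative: sum local work counts across the EP group; a nonzero total means work remains.
    Assumes torch.distributed is initialized; with an NCCL backend the tensor must live on GPU."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    work = torch.tensor([local_work], dtype=torch.long, device=device)
    dist.all_reduce(work, op=dist.ReduceOp.SUM, group=ep_group)
    return bool(work.item() > 0)
```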
- async run_engine_with_coordinator(
- *,
- loop: Optional[asyncio.AbstractEventLoop] = None,
Continually steps the engine asynchronously.