core.inference.inference_client#
Module Contents#
Classes#
InferenceClient | An asynchronous client for communicating with an inference coordinator service.
API#
- class core.inference.inference_client.InferenceClient(inference_coordinator_address: str)#
An asynchronous client for communicating with an inference coordinator service.
This client uses ZeroMQ (ZMQ) for messaging and MessagePack (msgpack) for serialization. It is designed to work within an asyncio event loop. It can submit inference requests, listen for completed results, and send control signals (e.g., pause, stop) to the inference engines.
The client operates by connecting a ZMQ DEALER socket to the inference coordinator’s ROUTER socket. Requests are sent with a unique ID, and an asyncio.Future is created for each request. A background task listens for replies from the coordinator, and when a reply is received, it resolves the corresponding future with the result.
.. attribute:: context
The ZeroMQ context.
- Type:
zmq.Context
.. attribute:: socket
The ZMQ DEALER socket used for communication.
- Type:
zmq.Socket
.. attribute:: completion_futures
A dictionary mapping request IDs to the asyncio Future objects that will hold the results.
- Type:
dict[int, asyncio.Future]
.. attribute:: next_request_id
A counter for generating unique request IDs.
- Type:
int
.. attribute:: listener_task
The background task that listens for completed requests.
- Type:
asyncio.Task
Initialization
Initializes the InferenceClient.
- Parameters:
inference_coordinator_address (str) – The address on which the inference coordinator is listening.
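The request and future bookkeeping described in the class overview can be sketched with the standard library alone. This is a minimal illustration, not the actual implementation: `MiniClient`, `fake_coordinator`, and the `asyncio.Queue` transport are hypothetical stand-ins for the ZMQ DEALER socket and msgpack framing, and only the attribute names (`completion_futures`, `next_request_id`, `listener_task`) are taken from the documentation above.

```python
import asyncio


class MiniClient:
    """Hypothetical stand-in that mimics only the request/future bookkeeping."""

    def __init__(self, outbox: asyncio.Queue, inbox: asyncio.Queue):
        self.outbox = outbox  # assumption: replaces socket.send
        self.inbox = inbox    # assumption: replaces socket.recv
        self.completion_futures = {}  # request ID -> asyncio.Future
        self.next_request_id = 0
        self.listener_task = None

    def start(self) -> None:
        # Spawn the background listener on the running event loop.
        loop = asyncio.get_running_loop()
        self.listener_task = loop.create_task(self._recv_task())

    def add_request(self, prompt: str) -> asyncio.Future:
        # Allocate a unique ID, register a future under it, then send.
        request_id = self.next_request_id
        self.next_request_id += 1
        future = asyncio.get_running_loop().create_future()
        self.completion_futures[request_id] = future
        self.outbox.put_nowait((request_id, prompt))
        return future

    async def _recv_task(self) -> None:
        # Resolve the future that matches each reply's request ID.
        while True:
            request_id, result = await self.inbox.get()
            future = self.completion_futures.pop(request_id, None)
            if future is not None and not future.done():
                future.set_result(result)

    async def stop(self) -> None:
        # Cancel the listener and wait for the cancellation to finish.
        self.listener_task.cancel()
        try:
            await self.listener_task
        except asyncio.CancelledError:
            pass


async def fake_coordinator(requests: asyncio.Queue, replies: asyncio.Queue, n: int) -> None:
    # Collect n requests, then reply in reverse order to show that
    # correlation depends on the ID, not on arrival order.
    pending = [await requests.get() for _ in range(n)]
    for request_id, prompt in reversed(pending):
        await replies.put((request_id, prompt.upper()))


async def demo() -> list:
    to_coord, from_coord = asyncio.Queue(), asyncio.Queue()
    client = MiniClient(to_coord, from_coord)
    client.start()
    futures = [client.add_request(p) for p in ("alpha", "beta", "gamma")]
    await fake_coordinator(to_coord, from_coord, len(futures))
    results = await asyncio.gather(*futures)
    await client.stop()
    return results
```

Running `asyncio.run(demo())` returns the results in submission order even though the stand-in coordinator replies out of order, because each reply is matched to its future by request ID.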
- add_request(
- prompt: Union[str, List[int]],
- sampling_params: megatron.core.inference.sampling_params.SamplingParams,
- ) → asyncio.Future#
Submits a new inference request to the coordinator.
This method sends the prompt and sampling parameters to the inference coordinator. It immediately returns an asyncio.Future, which can be awaited to get the result of the inference request when it is complete.
- Parameters:
prompt (Union[str, List[int]]) – The input prompt to send to the language model, given either as text or as a list of token IDs.
sampling_params – An object containing the sampling parameters for text generation (e.g., temperature, top_p). It must have a serialize() method.
- Returns:
A future that will be resolved with a DynamicInferenceRequestRecord object containing the completed result.
- Return type:
asyncio.Future
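Because add_request returns a plain asyncio.Future, the usual asyncio tools apply to it. The sketch below bounds the wait with a timeout. It uses a bare future created with `loop.create_future()` as a stand-in for the value add_request would return, so no coordinator is involved; `await_with_timeout` is a hypothetical helper, not part of the client.

```python
import asyncio


async def await_with_timeout(future: asyncio.Future, timeout_s: float):
    """Wait for a result, returning None if it does not arrive in time."""
    try:
        # shield() keeps the underlying future pending when the wait times
        # out, so a late reply can still resolve it afterwards.
        return await asyncio.wait_for(asyncio.shield(future), timeout=timeout_s)
    except asyncio.TimeoutError:
        return None


async def demo():
    loop = asyncio.get_running_loop()
    # Stand-in for the future returned by add_request().
    future = loop.create_future()
    first = await await_with_timeout(future, 0.01)  # nothing has replied yet
    # Simulate the listener resolving the future a little later.
    loop.call_later(0.01, future.set_result, "completed record")
    second = await await_with_timeout(future, 1.0)
    return first, second
```

Without `asyncio.shield`, `asyncio.wait_for` would cancel the future on timeout, and a reply arriving later could no longer be delivered through it.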
- async _recv_task()#
Listens for completed inference requests from the coordinator.
This coroutine runs in an infinite loop, continuously polling the socket for data. When a request reply is received, it unpacks the message, finds the corresponding Future using the request ID, and sets the result. Other control packets are handled appropriately.
This method is started as a background task by the start() method.
- _connect_with_inference_coordinator()#
Performs the initial handshake with the inference coordinator.
Sends a CONNECT signal and waits for a CONNECT_ACK reply to ensure the connection is established and acknowledged by the coordinator.
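The handshake can be sketched as a send followed by a checked receive. This is an illustration under stated assumptions: the `Headers` enum here is a hypothetical two-member subset, the `send` and `recv` callables replace the ZMQ socket, and the timeout is an addition for the example that the documentation above does not mention.

```python
import asyncio
from enum import Enum, auto


class Headers(Enum):
    # Hypothetical subset; the real enum's members and values may differ.
    CONNECT = auto()
    CONNECT_ACK = auto()


async def handshake(send, recv, timeout_s: float = 5.0) -> None:
    """Send CONNECT and require CONNECT_ACK as the next message."""
    await send(Headers.CONNECT)
    reply = await asyncio.wait_for(recv(), timeout=timeout_s)
    if reply is not Headers.CONNECT_ACK:
        raise ConnectionError(f"expected CONNECT_ACK, got {reply!r}")


async def demo(ack: bool) -> str:
    to_coord, from_coord = asyncio.Queue(), asyncio.Queue()

    async def coordinator():
        # Stand-in coordinator: acknowledge the CONNECT only if asked to.
        message = await to_coord.get()
        if message is Headers.CONNECT and ack:
            await from_coord.put(Headers.CONNECT_ACK)

    task = asyncio.create_task(coordinator())
    try:
        await handshake(to_coord.put, from_coord.get, timeout_s=0.05)
        return "connected"
    except asyncio.TimeoutError:
        return "timed out"
    finally:
        await task
```

Failing fast on a missing or unexpected reply keeps a misconfigured address from surfacing later as requests whose futures never resolve.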
- start(loop: Optional[asyncio.AbstractEventLoop] = None)#
Connects to the coordinator and starts the background listener task.
This must be called before submitting any requests. It handles the initial handshake and spawns the _recv_task coroutine as a background task.
- _send_signal_to_engines(signal)#
Sends a generic control signal to the inference coordinator.
- Parameters:
signal – The signal to send, typically a value from the Headers enum.
- pause_engines()#
Sends PAUSE to all engines via coordinator.
The coordinator broadcasts PAUSE. Each engine reaches EP consensus, then synchronizes via a world-wide barrier before transitioning to PAUSED. Callers should await engine.paused for confirmation.
- unpause_engines() → None#
Sends UNPAUSE to all engines. No synchronization needed.
- increment_staleness()#
Sends a signal to increment staleness on all in-flight requests.
- suspend_engines()#
Sends SUSPEND to all engines via coordinator. Requires PAUSED.
Callers should await engine.suspended for confirmation.
- resume_engines()#
Sends RESUME to all engines via coordinator. Requires SUSPENDED.
Callers should await engine.paused (or engine.running after UNPAUSE) for confirmation.
- stop_engines()#
Sends STOP to all engines via coordinator. Requires PAUSED or SUSPENDED.
Callers should await engine.stopped for confirmation. Does not affect the coordinator.
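The preconditions stated for the engine control signals can be collected into a small transition table. This is a sketch for reasoning about call order, not engine code: `EngineState`, `TRANSITIONS`, and the helper functions are hypothetical, signals are plain strings rather than `Headers` values, and the preconditions for PAUSE (RUNNING) and UNPAUSE (PAUSED) are assumptions, since the text above only states them for SUSPEND, RESUME, and STOP.

```python
from enum import Enum


class EngineState(Enum):
    RUNNING = "running"
    PAUSED = "paused"
    SUSPENDED = "suspended"
    STOPPED = "stopped"


# signal -> (states in which it is accepted, resulting state)
TRANSITIONS = {
    "PAUSE": ({EngineState.RUNNING}, EngineState.PAUSED),    # precondition assumed
    "UNPAUSE": ({EngineState.PAUSED}, EngineState.RUNNING),  # precondition assumed
    "SUSPEND": ({EngineState.PAUSED}, EngineState.SUSPENDED),
    "RESUME": ({EngineState.SUSPENDED}, EngineState.PAUSED),
    "STOP": ({EngineState.PAUSED, EngineState.SUSPENDED}, EngineState.STOPPED),
}


def is_allowed(state: EngineState, signal: str) -> bool:
    """Return True if the signal's precondition holds in this state."""
    allowed, _ = TRANSITIONS[signal]
    return state in allowed


def apply_signal(state: EngineState, signal: str) -> EngineState:
    """Return the next state, or raise if the precondition is not met."""
    if not is_allowed(state, signal):
        raise ValueError(f"{signal} is not valid in state {state.name}")
    return TRANSITIONS[signal][1]


def run_sequence(signals, state: EngineState = EngineState.RUNNING) -> EngineState:
    """Apply signals in order and return the final state."""
    for signal in signals:
        state = apply_signal(state, signal)
    return state
```

For example, `run_sequence(["PAUSE", "SUSPEND", "RESUME", "STOP"])` ends in `EngineState.STOPPED`, while sending STOP to a running engine is rejected because the engine must be paused or suspended first.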
- shutdown_coordinator()#
Tells the coordinator process to exit its main loop.
Does not affect the engines.
- stop()#
Stops the client and cleans up all resources.
This method cancels the background listener task, closes the ZMQ socket, and terminates the ZMQ context. It should be called when the client is no longer needed to ensure a graceful shutdown.
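One way to make sure stop() always runs is to pair it with start() in a context manager. The sketch below assumes that start() and stop() are plain synchronous calls, as the signatures above suggest; if start() is a coroutine in your version, an async context manager would be the equivalent. `StubClient` and `running_client` are hypothetical names used only for illustration.

```python
from contextlib import contextmanager


class StubClient:
    """Hypothetical stand-in exposing only start() and stop()."""

    def __init__(self):
        self.calls = []

    def start(self):
        self.calls.append("start")

    def stop(self):
        self.calls.append("stop")


@contextmanager
def running_client(client):
    """Start the client and guarantee stop() on exit, even after an error."""
    client.start()
    try:
        yield client
    finally:
        client.stop()


def demo():
    client = StubClient()
    try:
        with running_client(client):
            # Simulate a failure while requests are in flight.
            raise RuntimeError("request handling failed")
    except RuntimeError:
        pass
    return client.calls
```

Here `demo()` records both calls in order, showing that cleanup runs even though the body of the `with` block raised.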