core.inference.inference_client#
Module Contents#
Classes#
InferenceClient: An asynchronous client for communicating with an inference coordinator service.
API#
- class core.inference.inference_client.InferenceClient(inference_coordinator_port: int)#
An asynchronous client for communicating with an inference coordinator service.
This client uses ZeroMQ (ZMQ) for messaging and MessagePack (msgpack) for serialization. It is designed to work within an asyncio event loop. It can submit inference requests, listen for completed results, and send control signals (e.g., pause, stop) to the inference engines.
The client operates by connecting a ZMQ DEALER socket to the inference coordinator’s ROUTER socket. Requests are sent with a unique ID, and an asyncio.Future is created for each request. A background task listens for replies from the coordinator and, when a reply is received, resolves the corresponding future with the result.
- context#
The ZeroMQ context.
- Type:
zmq.Context
- socket#
The ZMQ DEALER socket used for communication.
- Type:
zmq.Socket
- completion_futures#
A dictionary mapping request IDs to the asyncio Future objects that will hold the results.
- Type:
dict[int, asyncio.Future]
- next_request_id#
A counter for generating unique request IDs.
- Type:
int
- listener_task#
The background task that listens for completed requests.
- Type:
asyncio.Task
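The request/future bookkeeping described by these attributes can be sketched with plain asyncio. This is a hypothetical, self-contained sketch: an asyncio.Queue stands in for the ZMQ socket, and the class and method names only mirror the attributes documented here.

```python
import asyncio

# Hypothetical sketch of the request-id -> future pattern described above.
# An asyncio.Queue replaces the ZMQ DEALER socket so the example is runnable.
class FutureRegistryClient:
    def __init__(self):
        self.completion_futures: dict[int, asyncio.Future] = {}
        self.next_request_id = 0
        self.replies: asyncio.Queue = asyncio.Queue()  # stand-in for the socket

    def add_request(self, prompt: str) -> asyncio.Future:
        request_id = self.next_request_id
        self.next_request_id += 1
        future = asyncio.get_running_loop().create_future()
        self.completion_futures[request_id] = future
        # The real client would send (request_id, prompt, params) to the
        # coordinator here; this sketch only registers the future.
        return future

    async def _recv_task(self):
        # Resolve the matching future for each reply, keyed by request id.
        while True:
            request_id, result = await self.replies.get()
            self.completion_futures.pop(request_id).set_result(result)

async def main():
    client = FutureRegistryClient()
    listener = asyncio.create_task(client._recv_task())
    fut = client.add_request("hello")
    await client.replies.put((0, "world"))  # simulate a coordinator reply
    result = await fut
    listener.cancel()
    return result

print(asyncio.run(main()))  # prints: world
```

The key property the real client relies on is the same: the caller never polls; the background listener wakes it by resolving the future.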
Initialization
Initializes the InferenceClient.
- Parameters:
inference_coordinator_port (int) – The port number on which the inference coordinator is listening.
- add_request(
- prompt: Union[str, List[int]],
- sampling_params: megatron.core.inference.sampling_params.SamplingParams,
- ) asyncio.Future#
Submits a new inference request to the coordinator.
This method sends the prompt and sampling parameters to the inference coordinator. It immediately returns an asyncio.Future, which can be awaited to get the result of the inference request when it is complete.
- Parameters:
prompt (Union[str, List[int]]) – The input prompt to send to the language model, as text or as token IDs.
sampling_params – An object containing the sampling parameters for text generation (e.g., temperature, top_p). It must have a serialize() method.
- Returns:
A future that will be resolved with a DynamicInferenceRequestRecord object containing the completed result.
- Return type:
asyncio.Future
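A caller typically awaits the returned future, optionally bounding the wait. The sketch below uses a plain future as a stand-in for the one add_request returns; the consume helper and the timeout value are illustrative, not part of the API.

```python
import asyncio

async def consume(future: asyncio.Future, timeout: float):
    # The future returned by add_request can be awaited like any other;
    # wrapping it in wait_for bounds how long the caller blocks.
    try:
        return await asyncio.wait_for(future, timeout)
    except asyncio.TimeoutError:
        return None

async def main():
    loop = asyncio.get_running_loop()
    fut = loop.create_future()  # stand-in for the future add_request returns
    loop.call_later(0.01, fut.set_result, "completed record")
    return await consume(fut, timeout=1.0)

print(asyncio.run(main()))  # prints: completed record
```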
- async _recv_task()#
Listens for completed inference requests from the coordinator.
This coroutine runs in an infinite loop, continuously polling the socket for data. When a request reply is received, it unpacks the message, finds the corresponding Future using the request ID, and sets the result. Other control packets are handled appropriately.
This method is started as a background task by the start() method.
- _connect_with_inference_coordinator()#
Performs the initial handshake with the inference coordinator.
Sends a CONNECT signal and waits for a CONNECT_ACK reply to ensure the connection is established and acknowledged by the coordinator.
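The handshake can be illustrated as follows. This sketch substitutes asyncio streams and JSON for the client's ZMQ DEALER socket and msgpack serialization, so the names here (coordinator, connect_with_coordinator) are hypothetical stand-ins.

```python
import asyncio
import json

# Hypothetical sketch of the CONNECT / CONNECT_ACK handshake, using asyncio
# streams and JSON lines in place of ZMQ and msgpack.
async def coordinator(reader, writer):
    msg = json.loads(await reader.readline())
    if msg.get("signal") == "CONNECT":
        writer.write(json.dumps({"signal": "CONNECT_ACK"}).encode() + b"\n")
        await writer.drain()
    writer.close()

async def connect_with_coordinator(port: int) -> bool:
    reader, writer = await asyncio.open_connection("127.0.0.1", port)
    # Send CONNECT, then block until the coordinator acknowledges.
    writer.write(json.dumps({"signal": "CONNECT"}).encode() + b"\n")
    await writer.drain()
    reply = json.loads(await reader.readline())
    writer.close()
    return reply.get("signal") == "CONNECT_ACK"

async def main():
    server = await asyncio.start_server(coordinator, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    ok = await connect_with_coordinator(port)
    server.close()
    await server.wait_closed()
    return ok

print(asyncio.run(main()))  # prints: True
```

Blocking on the acknowledgement before returning ensures no requests are submitted over a connection the coordinator has not accepted.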
- async start(loop: Optional[asyncio.AbstractEventLoop] = None)#
Connects to the coordinator and starts the background listener task.
This method must be awaited before submitting any requests. It handles the initial handshake and spawns the _recv_task coroutine as the background listener.
- _send_signal_to_engines(signal)#
Sends a generic control signal to the inference coordinator.
- Parameters:
signal – The signal to send, typically a value from the Headers enum.
- pause_engines() Awaitable#
Sends a signal to pause all inference engines.
The signal first propagates through the coordinator to all engines. All engines acknowledge this signal and clear their running flags. The coordinator awaits all acknowledgements before forwarding the ACK back to the client, as well as to the engines. The engines set their paused flags upon seeing the ACK.
- Returns:
An awaitable that resolves when all engines have paused.
- Return type:
Awaitable
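The acknowledgement barrier described above behaves like asyncio.gather: the coordinator replies only after every engine has acknowledged. A hypothetical sketch, with engine_ack and pause_all standing in for the real signal handling:

```python
import asyncio

async def engine_ack(name: str, acks: list):
    # Stand-in for an engine clearing its running flag and replying with an ACK.
    await asyncio.sleep(0)
    acks.append(name)

async def pause_all(engine_names: list) -> str:
    acks: list = []
    # gather acts as the coordinator's barrier: it resolves only once every
    # engine has acknowledged the pause signal.
    await asyncio.gather(*(engine_ack(n, acks) for n in engine_names))
    assert len(acks) == len(engine_names)
    return "ACK"  # forwarded to the client (and engines) only after the barrier

print(asyncio.run(pause_all(["engine0", "engine1", "engine2"])))  # prints: ACK
```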
- unpause_engines() None#
Sends a signal to unpause all inference engines.
- suspend_engines()#
Sends a signal to suspend all inference engines.
- resume_engines()#
Sends a signal to resume all inference engines.
- stop_engines() Awaitable#
Sends a signal to gracefully stop all inference engines.
The signal first propagates through the coordinator to all engines. All engines acknowledge this signal and clear their running flags. The coordinator awaits all acknowledgements before forwarding the ACK back to the client, as well as to the engines. The engines set their stopped flags upon seeing the ACK.
- Returns:
An awaitable that resolves when all engines have stopped.
- Return type:
Awaitable
- stop()#
Stops the client and cleans up all resources.
This method cancels the background listener task, closes the ZMQ socket, and terminates the ZMQ context. It should be called when the client is no longer needed to ensure a graceful shutdown.
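The shutdown sequence can be sketched as cancelling the listener task and then releasing resources. The ClientShutdown class below is a hypothetical stand-in; a closed flag represents the ZMQ socket and context teardown.

```python
import asyncio

# Hypothetical sketch of the stop() sequence: cancel the background listener,
# then release transport resources (represented here by a flag).
class ClientShutdown:
    def __init__(self):
        self.listener_task = None
        self.closed = False

    async def start(self):
        self.listener_task = asyncio.create_task(self._recv_task())

    async def _recv_task(self):
        while True:  # poll loop; the real client polls the ZMQ socket here
            await asyncio.sleep(0.01)

    def stop(self):
        self.listener_task.cancel()
        self.closed = True  # stands in for socket.close() and context.term()

async def main():
    client = ClientShutdown()
    await client.start()
    client.stop()
    await asyncio.sleep(0.01)  # let the cancellation propagate
    return client.closed and client.listener_task.cancelled()

print(asyncio.run(main()))  # prints: True
```

Cancelling the listener before closing the transport avoids the task waking up on a dead socket.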