core.inference.inference_client#

Module Contents#

Classes#

InferenceClient

An asynchronous client for communicating with an inference coordinator service.

API#

class core.inference.inference_client.InferenceClient(inference_coordinator_port: int)#

An asynchronous client for communicating with an inference coordinator service.

This client uses ZeroMQ (ZMQ) for messaging and MessagePack (msgpack) for serialization. It is designed to work within an asyncio event loop. It can submit inference requests, listen for completed results, and send control signals (e.g., pause, stop) to the inference engines.

The client operates by connecting a ZMQ DEALER socket to the inference coordinator’s ROUTER socket. Requests are sent with a unique ID, and an asyncio.Future is created for each request. A background task listens for replies from the coordinator, and when a reply is received, it resolves the corresponding future with the result.
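The request/reply correlation described above can be illustrated without any networking. The following is a minimal sketch, not the client's actual implementation: `FutureCorrelator`, `fake_coordinator`, and `demo` are hypothetical names, and an `asyncio.Queue` stands in for the ZMQ DEALER socket.

```python
import asyncio
import itertools


class FutureCorrelator:
    """Toy stand-in for the client's request/reply bookkeeping (no ZMQ involved)."""

    def __init__(self) -> None:
        # Maps request IDs to the futures that will hold their results.
        self.completion_futures = {}
        # Monotonic counter used to generate unique request IDs.
        self._ids = itertools.count()
        # Assumption: a queue plays the role of the DEALER socket.
        self.outbox = asyncio.Queue()

    def submit(self, payload: str) -> asyncio.Future:
        """Register a future under a fresh ID and 'send' the request."""
        request_id = next(self._ids)
        future = asyncio.get_running_loop().create_future()
        self.completion_futures[request_id] = future
        self.outbox.put_nowait((request_id, payload))
        return future

    def on_reply(self, request_id: int, result: str) -> None:
        """Resolve the future that matches the reply's request ID."""
        future = self.completion_futures.pop(request_id)
        if not future.done():
            future.set_result(result)


async def fake_coordinator(client: FutureCorrelator, count: int) -> None:
    """Pretend coordinator: answers each request with its upper-cased payload."""
    for _ in range(count):
        request_id, payload = await client.outbox.get()
        client.on_reply(request_id, payload.upper())


async def demo() -> list:
    client = FutureCorrelator()
    worker = asyncio.create_task(fake_coordinator(client, 2))
    futures = [client.submit("hello"), client.submit("world")]
    results = await asyncio.gather(*futures)
    await worker
    return list(results)
```

Keying futures by request ID lets replies arrive in any order, since each reply only needs to carry its ID to find the right caller.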

.. attribute:: context

The ZeroMQ context.

Type:

zmq.Context

.. attribute:: socket

The ZMQ DEALER socket used for communication.

Type:

zmq.Socket

.. attribute:: completion_futures

A dictionary mapping request IDs to the asyncio Future objects that will hold the results.

Type:

dict[int, asyncio.Future]

.. attribute:: next_request_id

A counter for generating unique request IDs.

Type:

int

.. attribute:: listener_task

The background task that listens for completed requests.

Type:

asyncio.Task

Initialization

Initializes the InferenceClient.

Parameters:

inference_coordinator_port (int) – The port number on which the inference coordinator is listening.

add_request(
prompt: Union[str, List[int]],
sampling_params: megatron.core.inference.sampling_params.SamplingParams,
) → asyncio.Future#

Submits a new inference request to the coordinator.

This method sends the prompt and sampling parameters to the inference coordinator. It immediately returns an asyncio.Future, which can be awaited to get the result of the inference request when it is complete.

Parameters:
  • prompt (Union[str, List[int]]) – The input prompt to send to the language model, given either as a string or as a list of integers (token IDs).

  • sampling_params (SamplingParams) – An object containing the sampling parameters for text generation (e.g., temperature, top_p). It must have a serialize() method.

Returns:

A future that will be resolved with a DynamicInferenceRequestRecord object containing the completed result.

Return type:

asyncio.Future
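Because add_request returns immediately, several requests can be submitted before any of them is awaited. The sketch below shows that fan-out pattern under stated assumptions: `generate_all` is a hypothetical helper, and `_EchoClient` is a stub that only mimics the documented `add_request` signature so the example runs without Megatron or a coordinator.

```python
import asyncio


async def generate_all(client, prompts, sampling_params):
    """Submit every prompt first, then await all futures together."""
    futures = [client.add_request(prompt, sampling_params) for prompt in prompts]
    return await asyncio.gather(*futures)


class _EchoClient:
    """Stub (assumption): resolves each future with an echo of the prompt."""

    def add_request(self, prompt, sampling_params):
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        # Resolve on the next loop iteration to imitate an asynchronous reply.
        loop.call_soon(future.set_result, f"echo:{prompt}")
        return future


async def demo_fan_out():
    return await generate_all(_EchoClient(), ["a", "b"], sampling_params=None)
```

With the real client, each awaited result would be the record object described under Returns rather than a string.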

async _recv_task()#

Listens for completed inference requests from the coordinator.

This coroutine runs in an infinite loop, continuously polling the socket for data. When a request reply is received, it unpacks the message, finds the corresponding Future using the request ID, and sets the result. Other control packets are handled appropriately.

This method is started as a background task by the start() method.

_connect_with_inference_coordinator()#

Performs the initial handshake with the inference coordinator.

Sends a CONNECT signal and waits for a CONNECT_ACK reply to ensure the connection is established and acknowledged by the coordinator.
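The handshake amounts to sending one message and blocking until the expected acknowledgement arrives. The following sketch simulates it with in-process queues and a thread in place of ZMQ sockets. The `Signal` enum, `coordinator`, `connect`, and the timeout behavior are illustrative assumptions, not the library's actual names or semantics.

```python
import queue
import threading
from enum import Enum, auto


class Signal(Enum):
    """Hypothetical stand-in for the two handshake messages."""
    CONNECT = auto()
    CONNECT_ACK = auto()


def coordinator(inbox: queue.Queue, outbox: queue.Queue) -> None:
    """Pretend coordinator: acknowledges a CONNECT if one arrives in time."""
    try:
        message = inbox.get(timeout=1.0)
    except queue.Empty:
        return
    if message is Signal.CONNECT:
        outbox.put(Signal.CONNECT_ACK)


def connect(to_coord: queue.Queue, from_coord: queue.Queue, timeout: float = 1.0) -> bool:
    """Send CONNECT and report whether CONNECT_ACK came back before the timeout."""
    to_coord.put(Signal.CONNECT)
    try:
        reply = from_coord.get(timeout=timeout)
    except queue.Empty:
        return False
    return reply is Signal.CONNECT_ACK


def demo_handshake() -> bool:
    to_coord, from_coord = queue.Queue(), queue.Queue()
    thread = threading.Thread(target=coordinator, args=(to_coord, from_coord))
    thread.start()
    acknowledged = connect(to_coord, from_coord)
    thread.join()
    return acknowledged
```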

async start(loop: Optional[asyncio.AbstractEventLoop] = None)#

Connects to the coordinator and starts the background listener task.

This method must be awaited before submitting any requests. It handles the initial handshake and spawns the background listener coroutine, _recv_task().

_send_signal_to_engines(signal)#

Sends a generic control signal to the inference coordinator.

Parameters:

signal – The signal to send, typically a value from the Headers enum.

pause_engines() → Awaitable#

Sends a signal to pause all inference engines.

The signal first propagates through the coordinator to all engines. Each engine acknowledges the signal and clears its running flag. The coordinator waits for every acknowledgement before forwarding the ACK back to the client and to the engines. Each engine sets its paused flag upon seeing the ACK.

Returns:

An awaitable that resolves when all engines have paused.

Return type:

Awaitable
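The two-phase acknowledgement described for pause_engines can be modeled in a few lines. This is a simplified in-process simulation under stated assumptions: `Engine`, `on_pause`, `on_pause_ack`, and `coordinator_pause` are hypothetical names, and direct coroutine calls replace the real message passing.

```python
import asyncio


class Engine:
    """Toy engine holding the two flags mentioned in the description."""

    def __init__(self) -> None:
        self.running = True
        self.paused = False

    async def on_pause(self) -> str:
        # Phase 1: clear the running flag and acknowledge the signal.
        self.running = False
        return "ACK"

    def on_pause_ack(self) -> None:
        # Phase 2: set the paused flag once the coordinator's ACK is seen.
        self.paused = True


async def coordinator_pause(engines) -> bool:
    """Wait for every engine's acknowledgement, then broadcast the ACK."""
    acks = await asyncio.gather(*(engine.on_pause() for engine in engines))
    if not all(ack == "ACK" for ack in acks):
        return False
    for engine in engines:
        engine.on_pause_ack()
    # Returning True here plays the role of the ACK forwarded to the client.
    return True
```

Waiting for all acknowledgements before the second phase means no engine marks itself paused while another is still running.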

unpause_engines() → None#

Sends a signal to unpause all inference engines.

suspend_engines()#

Sends a signal to suspend all inference engines.

resume_engines()#

Sends a signal to resume all inference engines.

stop_engines() → Awaitable#

Sends a signal to gracefully stop all inference engines.

The signal first propagates through the coordinator to all engines. Each engine acknowledges the signal and clears its running flag. The coordinator waits for every acknowledgement before forwarding the ACK back to the client and to the engines. Each engine sets its stopped flag upon seeing the ACK.

Returns:

An awaitable that resolves when all engines have stopped.

Return type:

Awaitable

stop()#

Stops the client and cleans up all resources.

This method cancels the background listener task, closes the ZMQ socket, and terminates the ZMQ context. It should be called when the client is no longer needed to ensure a graceful shutdown.
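Cancelling a background asyncio task is the first step of the shutdown described above. The sketch below shows that step in isolation with a hypothetical `listener` coroutine. Awaiting the cancelled task, as done here, is an addition for the sake of the example and is not stated to be part of stop(); the socket and context cleanup are omitted.

```python
import asyncio
import contextlib


async def listener() -> None:
    """Stand-in for a background loop that would poll a socket forever."""
    while True:
        await asyncio.sleep(0.01)


async def demo_shutdown() -> bool:
    task = asyncio.create_task(listener())
    # Yield once so the listener actually starts running.
    await asyncio.sleep(0)
    task.cancel()
    # Awaiting lets the cancellation be delivered and the task finish.
    with contextlib.suppress(asyncio.CancelledError):
        await task
    return task.cancelled()
```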