nemo_rl.models.generation.vllm.vllm_worker_async#

Module Contents#

Classes#

VllmAsyncGenerationWorker

API#

class nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker(
config: nemo_rl.models.generation.vllm.config.VllmConfig,
bundle_indices: Optional[list[int]] = None,
fraction_of_gpus: float = 1.0,
seed: Optional[int] = None,
)#

Bases: nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker

Initialization

Initialize a vLLM worker for distributed inference.

Parameters:
  • config – Configuration dictionary for the policy

  • bundle_indices – List of local bundle indices within a node for parallelism. Only needed for the first worker in each tied worker group.

  • fraction_of_gpus – Fraction of GPUs to use for this worker

  • seed – Random seed for initialization
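
A construction sketch (hypothetical: `my_vllm_config` is a placeholder, and the worker is shown instantiated directly rather than through a Ray worker group):

```python
from nemo_rl.models.generation.vllm.config import VllmConfig
from nemo_rl.models.generation.vllm.vllm_worker_async import VllmAsyncGenerationWorker

# Hypothetical: assume my_vllm_config is a fully populated VllmConfig
# (see VllmConfig for the required fields).
my_vllm_config: VllmConfig = ...  # filled in by the caller

worker = VllmAsyncGenerationWorker(
    config=my_vllm_config,
    bundle_indices=[0, 1],  # only needed for the first worker in a tied worker group
    fraction_of_gpus=0.5,   # fraction of GPUs to use for this worker
    seed=42,
)
```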

_create_engine(llm_kwargs: dict[str, Any]) → None#
async post_init_async()#
async init_collective_async(
rank_prefix: int,
ip: str,
port: int,
world_size: int,
) → None#
async generate_async(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None]#

Generate a batch of data using vLLM’s AsyncLLMEngine, yielding results as they are ready.

Parameters:
  • data – BatchedDataDict with input_ids and input_lengths

  • greedy – Whether to use greedy decoding instead of sampling

Yields:

Tuple of (original_index, BatchedDataDict conforming to GenerationOutputSpec for the single sequence)
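
A usage sketch for this streaming interface, assuming a constructed worker and a BatchedDataDict built from padded `input_ids` plus `input_lengths` (the batch construction shown here is an assumption, not the canonical NeMo RL data path):

```python
import torch

from nemo_rl.distributed.batched_data_dict import BatchedDataDict


async def stream_generations(worker):
    # Hypothetical batch: two right-padded prompts with their true lengths.
    data = BatchedDataDict(
        {
            "input_ids": torch.tensor([[1, 2, 3, 0], [4, 5, 6, 7]]),
            "input_lengths": torch.tensor([3, 4]),
        }
    )
    # Results are yielded per sequence as they finish, tagged with the
    # original batch index so the caller can reassemble the batch.
    async for original_index, single_result in worker.generate_async(data, greedy=False):
        # "output_ids" is assumed to be one of the GenerationOutputSpec fields.
        print(original_index, single_result["output_ids"])
```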

async generate_text_async(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None]#

Generate text responses asynchronously, yielding results as they are ready.

Parameters:
  • data – BatchedDataDict containing prompts with text strings

  • greedy – Whether to use greedy decoding instead of sampling

Yields:

Tuple of (original_index, BatchedDataDict containing single text response)
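
A similar sketch for the text path; the batch key names used below ("prompts", "texts") are assumptions for illustration only:

```python
from nemo_rl.distributed.batched_data_dict import BatchedDataDict


async def stream_text(worker):
    # Hypothetical text batch; key names are assumed, not taken from the API docs.
    data = BatchedDataDict({"prompts": ["Hello!", "Write a haiku about GPUs."]})
    async for original_index, single_result in worker.generate_text_async(data, greedy=True):
        print(original_index, single_result["texts"])
```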

async report_device_id_async() → list[str]#

Async version of report_device_id.

async prepare_refit_info_async(
state_dict_info: dict[str, Any],
) → None#

Async version of prepare_refit_info.

async update_weights_from_ipc_handles_async(
ipc_handles: dict[str, Any],
) → bool#

Async version of update_weights_from_ipc_handles.

Parameters:

ipc_handles (dict) – Dictionary mapping device UUIDs (str) to parameter IPC handles.

Returns:

True if weights were successfully updated, False otherwise.

Return type:

bool
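
A sketch of an IPC-based refit step, pairing this call with reset_prefix_cache_async; the handle payload itself is produced elsewhere and only referenced abstractly:

```python
async def refit_from_ipc(worker, ipc_handles):
    # ipc_handles: dict mapping device UUID (str) -> parameter IPC handles,
    # produced by a co-located training process (payload format not shown here).
    ok = await worker.update_weights_from_ipc_handles_async(ipc_handles)
    if not ok:
        raise RuntimeError("vLLM weight update from IPC handles failed")
    # Dropping the prefix cache after a weight change avoids reusing stale KV entries.
    await worker.reset_prefix_cache_async()
```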

async update_weights_from_collective_async() → bool#

Async version of update_weights_from_collective.
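
A sketch of the collective refit path, combining init_collective_async, prepare_refit_info_async, and this method; the call ordering, placeholder addresses, and the contents of `state_dict_info` are assumptions:

```python
async def refit_from_collective(worker, state_dict_info):
    # Join the collective group used to broadcast updated weights
    # (rank_prefix, address, and world size are placeholders).
    await worker.init_collective_async(rank_prefix=0, ip="10.0.0.1", port=29500, world_size=2)
    # Describe the incoming state dict (e.g. parameter metadata; contents assumed).
    await worker.prepare_refit_info_async(state_dict_info)
    # Receive the new weights over the collective.
    ok = await worker.update_weights_from_collective_async()
    assert ok, "collective weight update failed"
```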

async reset_prefix_cache_async()#

Async version of reset_prefix_cache.

async sleep_async()#

Async version of sleep.

async wake_up_async(**kwargs)#

Async version of wake_up.

shutdown() → bool#

Clean up vLLM resources.
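
A sketch of the sleep/wake/shutdown lifecycle around one generation phase; no wake_up_async keyword arguments are assumed:

```python
async def generation_phase(worker, data):
    await worker.wake_up_async()   # bring the engine back after a previous sleep
    async for original_index, single_result in worker.generate_async(data):
        ...                        # consume per-sequence results as they arrive
    await worker.sleep_async()     # release engine resources between generation phases

# Final teardown (synchronous); returns True on a clean shutdown.
# worker.shutdown()
```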