nemo_rl.models.generation.vllm#

Module Contents#

Classes#

  • VllmSpecificArgs

  • VllmConfig

  • VllmGenerationWorker

  • VllmGeneration

API#

class nemo_rl.models.generation.vllm.VllmSpecificArgs#

Bases: typing.TypedDict

tensor_parallel_size: int#

pipeline_parallel_size: int#

gpu_memory_utilization: float#

max_model_len: int#

skip_tokenizer_init: bool#

async_engine: bool#

load_format: NotRequired[str]#

precision: NotRequired[str]#

enforce_eager: NotRequired[bool]#

class nemo_rl.models.generation.vllm.VllmConfig#

Bases: nemo_rl.models.generation.interfaces.GenerationConfig

vllm_cfg: nemo_rl.models.generation.vllm.VllmSpecificArgs#

vllm_kwargs: NotRequired[dict[str, Any]]#

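Both structures are TypedDicts, so a configuration is just a nested dict. The sketch below is illustrative: the values are arbitrary, and the fields inherited from GenerationConfig (for example, the model name and sampling settings) are not listed on this page, so they are only indicated by a comment.

```python
from nemo_rl.models.generation.vllm import VllmConfig, VllmSpecificArgs

# vLLM-specific settings; all values here are illustrative, not defaults.
vllm_cfg: VllmSpecificArgs = {
    "tensor_parallel_size": 2,
    "pipeline_parallel_size": 1,
    "gpu_memory_utilization": 0.85,
    "max_model_len": 4096,
    "skip_tokenizer_init": False,
    "async_engine": False,
    # Optional (NotRequired) keys such as "load_format", "precision",
    # or "enforce_eager" may be added when needed.
}

config: VllmConfig = {
    # ... fields inherited from GenerationConfig go here (not shown on this page) ...
    "vllm_cfg": vllm_cfg,
    "vllm_kwargs": {},  # optional extra keyword arguments forwarded to vLLM
}
```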
class nemo_rl.models.generation.vllm.VllmGenerationWorker(
config: nemo_rl.models.generation.vllm.VllmConfig,
bundle_indices: Optional[list[int]] = None,
fraction_of_gpus: float = 1.0,
seed: Optional[int] = None,
)#

Initialization

Initialize a vLLM worker for distributed inference.

Parameters:
  • config – Configuration dictionary for the policy

  • bundle_indices – List of local bundle indices within a node for parallelism. Only needed for the first worker in each tied worker group.

  • fraction_of_gpus – Fraction of GPUs to use for this worker

  • seed – Random seed for initialization

__repr__() → str#

Customizes the actor’s prefix in the Ray logs.

This makes it easier to identify which worker is producing specific log messages.

static configure_worker(
num_gpus: int | float,
bundle_indices: Optional[tuple[int, list[int]]] = None,
) → tuple[dict[str, Any], dict[str, str], dict[str, Any]]#

Provides complete worker configuration for vLLM tensor and pipeline parallelism.

This method configures the worker based on its role in tensor and pipeline parallelism, which is determined directly from the bundle_indices parameter.

Parameters:
  • num_gpus – Original GPU allocation for this worker based on the placement group

  • bundle_indices – Tuple of (node_idx, local_bundle_indices) for parallelism (if applicable)

Returns:

  • 'resources': Resource allocation (e.g., num_gpus)

  • 'env_vars': Environment variables for this worker

  • 'init_kwargs': Parameters to pass to the worker's __init__

Return type:

tuple with complete worker configuration
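A hedged sketch of how the returned triple might be consumed when spawning a worker. The GPU fraction and bundle indices are made-up values, and in practice NeMo RL's worker-group machinery (not user code) applies the resources and environment variables to the Ray actor.

```python
from nemo_rl.models.generation.vllm import VllmGenerationWorker

# First worker of a tied group on node 0 spanning local bundles 0 and 1
# (illustrative values).
resources, env_vars, init_kwargs = VllmGenerationWorker.configure_worker(
    num_gpus=1.0,
    bundle_indices=(0, [0, 1]),
)

# 'resources' feeds the Ray actor's resource request, 'env_vars' its runtime
# environment, and 'init_kwargs' is merged into the worker's constructor call.
```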

post_init()#

async post_init_async()#

init_collective(
rank_prefix: int,
ip: str,
port: int,
world_size: int,
) → None#

async init_collective_async(
rank_prefix: int,
ip: str,
port: int,
world_size: int,
) → None#

llm()#

is_alive()#

Check if the worker is alive.

_merge_stop_strings(batch_stop_strings)#

_build_sampling_params(
*,
greedy: bool,
stop_strings,
max_new_tokens: Optional[int] = None,
)#

generate(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]#

Generate a batch of data using vLLM generation.

Parameters:
  • data – BatchedDataDict containing input_ids and input_lengths tensors

  • greedy – Whether to use greedy decoding instead of sampling

Returns:

  • output_ids: input + generated token IDs with proper padding

  • logprobs: Log probabilities for tokens

  • generation_lengths: Lengths of each response

  • unpadded_sequence_lengths: Lengths of each input + generated sequence

Return type:

BatchedDataDict conforming to GenerationOutputSpec
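The sketch below shows the expected input layout: right-padded input_ids plus per-row input_lengths. It assumes a BatchedDataDict can be populated like an ordinary dict and that `worker` is an already-initialized VllmGenerationWorker; the token IDs are arbitrary.

```python
import torch
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

# Two right-padded prompts of lengths 3 and 5 (token IDs are arbitrary).
data = BatchedDataDict(
    {
        "input_ids": torch.tensor(
            [[11, 12, 13, 0, 0],
             [21, 22, 23, 24, 25]]
        ),
        "input_lengths": torch.tensor([3, 5]),
    }
)

outputs = worker.generate(data, greedy=False)
# outputs["output_ids"], outputs["logprobs"], outputs["generation_lengths"],
# and outputs["unpadded_sequence_lengths"] follow GenerationOutputSpec.
```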

async generate_async(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None]#

Generate a batch of data using vLLM’s AsyncLLMEngine, yielding results as they are ready.

Parameters:
  • data – BatchedDataDict with input_ids and input_lengths

  • greedy – Whether to use greedy decoding instead of sampling

Yields:

Tuple of (original_index, BatchedDataDict conforming to GenerationOutputSpec for the single sequence)

generate_text(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]#

Generate text responses using vLLM generation.

Parameters:
  • data – BatchedDataDict containing prompts with text strings

  • greedy – Whether to use greedy decoding instead of sampling

Returns:

  • texts: List of generated text responses

Return type:

BatchedDataDict containing the generated texts

async generate_text_async(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None]#

Generate text responses asynchronously, yielding results as they are ready.

Parameters:
  • data – BatchedDataDict containing prompts with text strings

  • greedy – Whether to use greedy decoding instead of sampling

Yields:

Tuple of (original_index, BatchedDataDict containing single text response)

shutdown() → bool#

Clean up vLLM resources.

report_device_id() → list[str]#

Report device ID from the vLLM worker.

async report_device_id_async() → list[str]#

Async version of report_device_id.

prepare_refit_info(state_dict_info: dict[str, Any]) → None#

Prepare the info for refit.

async prepare_refit_info_async(
state_dict_info: dict[str, Any],
) → None#

Async version of prepare_refit_info.

update_weights_from_ipc_handles(
ipc_handles: dict[str, Any],
) → bool#

Update weights from IPC handles by delegating to the vLLM Worker implementation.

Parameters:

ipc_handles (dict) – Dictionary mapping device UUIDs (str) to parameter IPC handles.

Returns:

True if weights were successfully updated, False otherwise.

Return type:

bool

async update_weights_from_ipc_handles_async(
ipc_handles: dict[str, Any],
) → bool#

Async version of update_weights_from_ipc_handles.

Parameters:

ipc_handles (dict) – Dictionary mapping device UUIDs (str) to parameter IPC handles.

Returns:

True if weights were successfully updated, False otherwise.

Return type:

bool

update_weights_from_collective() → bool#

Update the model weights from collective communication.

async update_weights_from_collective_async() → bool#

Async version of update_weights_from_collective.

reset_prefix_cache()#

Reset the prefix cache of the vLLM engine.

async reset_prefix_cache_async()#

Async version of reset_prefix_cache.

sleep()#

Put the vLLM engine to sleep.

async sleep_async()#

Async version of sleep.

wake_up(**kwargs)#

Wake up the vLLM engine.

async wake_up_async(**kwargs)#

Async version of wake_up.

start_gpu_profiling() → None#

Start GPU profiling.

stop_gpu_profiling() → None#

Stop GPU profiling.

class nemo_rl.models.generation.vllm.VllmGeneration(
cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
config: nemo_rl.models.generation.vllm.VllmConfig,
name_prefix: str = 'vllm_policy',
workers_per_node: Optional[Union[int, list[int]]] = None,
)#

Bases: nemo_rl.models.generation.interfaces.GenerationInterface
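A minimal construction sketch. The RayVirtualCluster setup is not documented on this page and is left as a placeholder; `config` is assumed to be a VllmConfig like the one sketched earlier.

```python
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.models.generation.vllm import VllmGeneration

cluster: RayVirtualCluster = ...  # constructed elsewhere; see nemo_rl.distributed.virtual_cluster

vllm_generation = VllmGeneration(
    cluster=cluster,
    config=config,          # a VllmConfig as sketched above
    name_prefix="vllm_policy",
    workers_per_node=None,  # let the cluster determine the worker layout
)
```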

_get_tied_worker_bundle_indices(
cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
) → list[tuple[int, list[int]]]#

Calculate bundle indices for tensor and pipeline parallel workers.

Handles both unified placement groups (for cross-node model parallelism) and per-node placement groups (for node-local model parallelism).

_report_device_id() → list[list[str]]#

Report the device IDs of the vLLM workers.

_post_init()#

init_collective(
ip: str,
port: int,
world_size: int,
) → list[ray.ObjectRef]#

Initialize the collective communication.

generate(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]#

Generate a batch of data using vLLM.

generate_text(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]#

Generate text responses using vLLM.

async _async_generate_base(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
method_name: str,
data_validation_fn,
greedy: bool = False,
) → AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None]#

Base async generation method that handles common worker management logic.

Parameters:
  • data – Input data for generation

  • method_name – Name of the worker method to call (‘generate_async’ or ‘generate_text_async’)

  • data_validation_fn – Function to validate input data

  • greedy – Whether to use greedy decoding

Yields:

Tuple of (original_index, BatchedDataDict containing generation result)

async generate_text_async(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None]#

Generate text responses asynchronously, yielding results as they are ready.

Parameters:
  • data – BatchedDataDict containing prompts with text strings

  • greedy – Whether to use greedy decoding instead of sampling

Yields:

Tuple of (original_index, BatchedDataDict containing single text response)

async generate_async(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None]#

Generate responses asynchronously, yielding individual samples as they complete.

This method provides per-sample streaming across all workers, yielding each sample result as soon as it’s ready, regardless of which worker processed it.
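A consumption sketch, assuming `vllm_generation` and `data` are set up as above and that the async engine is enabled in vllm_cfg (async_engine=True). Each result is yielded as soon as it completes, tagged with the sample's original batch index.

```python
import asyncio

async def consume_generations(vllm_generation, data):
    # Each yielded item is (original_index, single-sample BatchedDataDict).
    async for original_index, sample_output in vllm_generation.generate_async(data):
        print(original_index, sample_output["generation_lengths"])

# asyncio.run(consume_generations(vllm_generation, data))
```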

prepare_for_generation(
*args: Any,
**kwargs: Any,
) → bool#

Wake workers up for colocated inference.

finish_generation(*args: Any, **kwargs: Any) → bool#

Sleep workers and reset prefix cache.
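In colocated setups (training and generation sharing GPUs), the call order might look like the sketch below: wake the workers before generating, then sleep them and reset the prefix cache afterwards so the trainer can reclaim GPU memory. Exact integration depends on the surrounding training loop.

```python
# Colocated-inference sketch; vllm_generation and data as above.
vllm_generation.prepare_for_generation()
try:
    outputs = vllm_generation.generate(data, greedy=False)
finally:
    vllm_generation.finish_generation()  # sleeps workers, resets prefix cache
```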

shutdown() → bool#

Shut down all vLLM workers and clean up resources.

prepare_refit_info(state_dict_info: dict[str, Any]) → None#

Prepare the info for refit.

update_weights_from_ipc_handles(
ipc_handles: dict[str, Any],
) → bool#

Update weights of the policy using IPC handles, considering tensor parallelism.

For tp > 1, only the leader in each tensor parallel tied worker group will update weights.

Parameters:

ipc_handles (dict) – Dictionary mapping device UUIDs (str) to parameter IPC handles.

Returns:

True if weights were successfully updated, False otherwise.

Return type:

bool
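A hedged refit sketch. How `state_dict_info` and `ipc_handles` are produced is outside this module (they come from the training policy); this only shows the call order against the VllmGeneration API.

```python
# One-time: describe the parameters that will be refit.
vllm_generation.prepare_refit_info(state_dict_info)

# Each refit step: push updated weights via CUDA IPC handles.
if not vllm_generation.update_weights_from_ipc_handles(ipc_handles):
    raise RuntimeError("vLLM weight update from IPC handles failed")
```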

update_weights_from_collective() → list[ray.ObjectRef]#

Update weights of the policy using collective communication.

start_gpu_profiling() → None#

Start GPU profiling.

stop_gpu_profiling() → None#

Stop GPU profiling.

__del__() → None#

Shuts down the worker groups when the object is deleted or is garbage collected.

This is an extra safety net in case the user forgets to call shutdown() and the pointer to the object is lost due to leaving a function scope. It’s always recommended that the user calls shutdown().