nemo_rl.models.generation.vllm.vllm_generation
Module Contents
Classes
API
- class nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration(
- cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
- config: nemo_rl.models.generation.vllm.config.VllmConfig,
- name_prefix: str = 'vllm_policy',
- workers_per_node: Optional[Union[int, list[int]]] = None,
)
Bases: nemo_rl.models.generation.interfaces.GenerationInterface
- _get_tied_worker_bundle_indices( ) → list[tuple[int, list[int]]]
Calculate bundle indices for tensor and pipeline parallel workers.
Handles both unified placement groups (for cross-node model parallelism) and per-node placement groups (for node-local model parallelism).
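As a rough illustration of the per-node case, the sketch below assumes that every tied worker group takes `tp_size * pp_size` consecutive bundles inside a single placement group and that leftover bundles go unused. `tied_bundle_indices` and its arguments are hypothetical. This is not NeMo-RL's implementation, and it ignores how a unified placement group maps bundles onto physical nodes.

```python
def tied_bundle_indices(
    bundles_per_group: list[int], tp_size: int, pp_size: int
) -> list[tuple[int, list[int]]]:
    """Split each placement group's bundles into tied model-parallel groups.

    Hypothetical helper: returns (placement_group_index, bundle_indices) pairs,
    where each tied group holds tp_size * pp_size consecutive bundles.
    """
    model_parallel_size = tp_size * pp_size
    tied_groups: list[tuple[int, list[int]]] = []
    for pg_idx, bundle_count in enumerate(bundles_per_group):
        # Only whole groups fit; any remainder bundles are left unused.
        for group_idx in range(bundle_count // model_parallel_size):
            start = group_idx * model_parallel_size
            tied_groups.append((pg_idx, list(range(start, start + model_parallel_size))))
    if not tied_groups:
        raise ValueError("no placement group can hold a full model-parallel group")
    return tied_groups


# Two per-node placement groups with 8 bundles each, TP=4, PP=1.
print(tied_bundle_indices([8, 8], tp_size=4, pp_size=1))
# [(0, [0, 1, 2, 3]), (0, [4, 5, 6, 7]), (1, [0, 1, 2, 3]), (1, [4, 5, 6, 7])]
```

Under the same assumptions, a unified placement group would appear as a single entry. For example, `tied_bundle_indices([16], tp_size=8, pp_size=2)` yields one tied group spanning bundles 0 through 15.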
- _report_device_id() → list[list[str]]
Report the device IDs of the vLLM workers.
- _post_init()
- init_collective(
- ip: str,
- port: int,
- world_size: int,
)
Initialize the collective communication.
- generate(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
)
Generate a batch of data using vLLM.
- generate_text(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
)
Generate text responses using vLLM.
- async _async_generate_base(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- method_name: str,
- data_validation_fn,
- greedy: bool = False,
)
Base async generation method that handles common worker management logic.
- Parameters:
data – Input data for generation
method_name – Name of the worker method to call (‘generate_async’ or ‘generate_text_async’)
data_validation_fn – Function to validate input data
greedy – Whether to use greedy decoding
- Yields:
Tuple of (original_index, BatchedDataDict containing generation result)
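The (original_index, result) tuples may arrive in completion order rather than input order. A caller that needs the input ordering can slot each result back by its index. In this minimal sketch, `fake_stream` is a hypothetical async generator that stands in for the real method, and plain strings stand in for BatchedDataDict results.

```python
import asyncio
from typing import Any, AsyncIterator


async def fake_stream() -> AsyncIterator[tuple[int, str]]:
    """Hypothetical stand-in: yields (original_index, result) out of order."""
    for idx, text in [(2, "c"), (0, "a"), (1, "b")]:
        await asyncio.sleep(0)  # hand control back to the event loop, as real I/O would
        yield idx, text


async def collect_in_order(stream: AsyncIterator[tuple[int, Any]], n: int) -> list[Any]:
    """Place each streamed result at its original index."""
    results: list[Any] = [None] * n
    async for original_index, result in stream:
        results[original_index] = result
    return results


print(asyncio.run(collect_in_order(fake_stream(), 3)))  # ['a', 'b', 'c']
```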
- async generate_text_async(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
)
Generate text responses asynchronously, yielding results as they are ready.
- Parameters:
data – BatchedDataDict containing prompts with text strings
greedy – Whether to use greedy decoding instead of sampling
- Yields:
Tuple of (original_index, BatchedDataDict containing single text response)
- async generate_async(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
)
Generate responses asynchronously, yielding individual samples as they complete.
This method provides per-sample streaming across all workers, yielding each sample result as soon as it’s ready, regardless of which worker processed it.
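Per-sample streaming across workers amounts to fanning several per-worker async streams into one. The following is a minimal asyncio sketch of that fan-in using a queue. `merge_streams` and `worker` are hypothetical stand-ins, not NeMo-RL functions, and the real method talks to Ray workers rather than local coroutines.

```python
import asyncio
from typing import AsyncIterator, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel: one worker stream is exhausted


async def merge_streams(*streams: AsyncIterator[T]) -> AsyncIterator[T]:
    """Yield items from several async iterators in completion order."""
    queue: asyncio.Queue = asyncio.Queue()  # unbounded, so put_nowait never blocks

    async def drain(stream: AsyncIterator[T]) -> None:
        try:
            async for item in stream:
                queue.put_nowait(item)
        finally:
            queue.put_nowait(_DONE)  # always signal completion, even on error

    tasks = [asyncio.create_task(drain(s)) for s in streams]
    remaining = len(tasks)
    try:
        while remaining:
            item = await queue.get()
            if item is _DONE:
                remaining -= 1
            else:
                yield item
        await asyncio.gather(*tasks)  # re-raise any error a worker stream hit
    finally:
        for task in tasks:
            task.cancel()  # no-op for finished tasks; stops stragglers on early exit


async def worker(samples: list[tuple[int, float]]) -> AsyncIterator[tuple[int, str]]:
    """Hypothetical worker: yields (original_index, result) after a simulated delay."""
    for idx, delay in samples:
        await asyncio.sleep(delay)
        yield idx, f"sample-{idx}"


async def main() -> list[int]:
    slow = worker([(0, 0.2), (1, 0.2)])
    fast = worker([(2, 0.05), (3, 0.05)])
    return [idx async for idx, _ in merge_streams(slow, fast)]


print(asyncio.run(main()))  # typically [2, 3, 0, 1]
```

The faster worker's samples surface first even though they have higher original indices. The caller then uses the index in each tuple to restore input order if it needs to.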
- prepare_for_generation(
- *args: Any,
- **kwargs: Any,
)
Wake workers up for colocated inference.
- finish_generation(*args: Any, **kwargs: Any) → bool
Sleep workers and reset prefix cache.
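Together, prepare_for_generation() and finish_generation() imply a wake, generate, sleep cycle for colocated inference. The minimal sketch below wraps that cycle in try/finally so the workers go back to sleep even when generation raises. `generate_colocated` and `FakeGeneration` are hypothetical. The stub records calls instead of running vLLM, and it takes a plain list instead of a BatchedDataDict.

```python
from typing import Any, Protocol


class SupportsColocatedGeneration(Protocol):
    """The subset of VllmGeneration methods this sketch relies on."""

    def prepare_for_generation(self, *args: Any, **kwargs: Any) -> Any: ...
    def generate(self, data: Any, greedy: bool = False) -> Any: ...
    def finish_generation(self, *args: Any, **kwargs: Any) -> Any: ...


def generate_colocated(gen: SupportsColocatedGeneration, data: Any, greedy: bool = False) -> Any:
    """Wake the workers, generate, and always put them back to sleep."""
    gen.prepare_for_generation()
    try:
        return gen.generate(data, greedy=greedy)
    finally:
        # Sleep workers (and reset the prefix cache) even if generate() raised.
        gen.finish_generation()


class FakeGeneration:
    """Hypothetical stub that records the call order instead of running vLLM."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool:
        self.calls.append("wake")
        return True

    def generate(self, data: Any, greedy: bool = False) -> Any:
        self.calls.append("generate")
        return [f"out-{x}" for x in data]

    def finish_generation(self, *args: Any, **kwargs: Any) -> bool:
        self.calls.append("sleep")
        return True


gen = FakeGeneration()
print(generate_colocated(gen, [1, 2]))  # ['out-1', 'out-2']
print(gen.calls)  # ['wake', 'generate', 'sleep']
```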
- shutdown() → bool
Shut down all vLLM workers and clean up resources.
- prepare_refit_info(state_dict_info: dict[str, Any]) → None
Prepare the information needed for refit.
- update_weights_from_ipc_handles(
- ipc_handles: dict[str, Any],
) → bool
Update the policy weights using IPC handles, accounting for tensor parallelism.
When tp > 1, only the leader of each tensor-parallel tied worker group updates weights.
- Parameters:
ipc_handles (dict) – Dictionary mapping device UUIDs (str) to parameter IPC handles.
- Returns:
True if weights were successfully updated, False otherwise.
- Return type:
bool
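Here is a minimal sketch of the leader-only rule. It assumes that the first worker in each tied group acts as that group's leader and that a per-worker update call reports success as a bool. `update_via_leaders`, `update_fn`, and `fake_update` are hypothetical, and the real method dispatches to Ray workers.

```python
from typing import Any, Callable


def update_via_leaders(
    tied_groups: list[list[int]],
    ipc_handles: dict[str, Any],
    update_fn: Callable[[int, dict[str, Any]], bool],
) -> bool:
    """Send the IPC handles to one leader per tied worker group.

    Hypothetical helper: the first worker in each group is treated as its leader,
    and the update succeeds only if every leader reports success.
    """
    leaders = [group[0] for group in tied_groups if group]
    # A list (not a generator) so every leader is attempted even after a failure.
    results = [update_fn(leader, ipc_handles) for leader in leaders]
    return bool(results) and all(results)


called: list[int] = []


def fake_update(worker_id: int, handles: dict[str, Any]) -> bool:
    """Hypothetical per-worker update that records who was called."""
    called.append(worker_id)
    return True


# TP=2: workers 0 and 2 lead groups [0, 1] and [2, 3].
ok = update_via_leaders([[0, 1], [2, 3]], {"GPU-uuid-0": "handle"}, fake_update)
print(ok, called)  # True [0, 2]
```

With this design, an empty group list counts as failure rather than vacuous success, which matches the "False otherwise" contract above.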
- update_weights_from_collective() → list[ray.ObjectRef]
Update weights of the policy using collective communication.
- start_gpu_profiling() → None
Start GPU profiling.
- stop_gpu_profiling() → None
Stop GPU profiling.
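Because start_gpu_profiling() and stop_gpu_profiling() are separate calls, a small context manager is one way to keep them paired. This sketch is not part of the documented API. `gpu_profiling` and `FakeProfiler` are hypothetical, and any object that exposes the two methods would work.

```python
from contextlib import contextmanager
from typing import Iterator, Protocol


class SupportsGpuProfiling(Protocol):
    def start_gpu_profiling(self) -> None: ...
    def stop_gpu_profiling(self) -> None: ...


@contextmanager
def gpu_profiling(gen: SupportsGpuProfiling) -> Iterator[None]:
    """Bracket a block of work with start/stop profiling calls."""
    gen.start_gpu_profiling()
    try:
        yield
    finally:
        # Stop profiling even if the profiled block raises.
        gen.stop_gpu_profiling()


class FakeProfiler:
    """Hypothetical stub that records calls instead of profiling GPUs."""

    def __init__(self) -> None:
        self.events: list[str] = []

    def start_gpu_profiling(self) -> None:
        self.events.append("start")

    def stop_gpu_profiling(self) -> None:
        self.events.append("stop")


prof = FakeProfiler()
with gpu_profiling(prof):
    prof.events.append("work")
print(prof.events)  # ['start', 'work', 'stop']
```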
- __del__() → None
Shut down the worker groups when the object is deleted or garbage collected.
This is an extra safety net in case the user forgets to call shutdown() and the last reference to the object is dropped, for example when it goes out of scope at the end of a function. Calling shutdown() explicitly is always recommended.
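A __del__ safety net like this works cleanly only if shutdown() is idempotent, so that a later __del__ call after an explicit shutdown() does nothing. The following is a minimal sketch of that pattern. `WorkerGroupOwner` and `teardowns` are hypothetical names. The __del__ timing shown relies on CPython's reference counting and may differ on other Python implementations.

```python
teardowns: list[str] = []  # records which objects actually ran their teardown


class WorkerGroupOwner:
    """Hypothetical owner of worker groups with an idempotent shutdown()."""

    def __init__(self, name: str) -> None:
        self._is_shut_down = False
        self.name = name

    def shutdown(self) -> bool:
        if not self._is_shut_down:
            # Real code would stop worker processes and free resources here.
            teardowns.append(self.name)
            self._is_shut_down = True
        return True

    def __del__(self) -> None:
        # Exceptions raised in __del__ are ignored (only reported on stderr),
        # so swallow them rather than emit noise during garbage collection.
        try:
            self.shutdown()
        except Exception:
            pass


explicit = WorkerGroupOwner("explicit")
explicit.shutdown()
del explicit  # __del__ runs, but shutdown() is now a no-op


def leaky() -> None:
    WorkerGroupOwner("forgotten")  # last reference dropped without shutdown()


leaky()
print(teardowns)  # ['explicit', 'forgotten'] on CPython
```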