nemo_rl.models.generation.vllm#
Module Contents#
Classes#
API#
- class nemo_rl.models.generation.vllm.VllmSpecificArgs[source]#
Bases: typing.TypedDict
- tensor_parallel_size: int#
- gpu_memory_utilization: float#
- max_model_len: int#
- skip_tokenizer_init: bool#
- class nemo_rl.models.generation.vllm.VllmConfig[source]#
Bases: nemo_rl.models.generation.interfaces.GenerationConfig
- vllm_cfg: nemo_rl.models.generation.vllm.VllmSpecificArgs#
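Since both classes are typed dictionaries, a configuration is an ordinary nested dict. A minimal sketch follows; the values are illustrative, and the remaining GenerationConfig fields (defined on the base interface) are omitted:

```python
from nemo_rl.models.generation.vllm import VllmSpecificArgs

# Illustrative values; choose them to match your model and hardware.
vllm_args: VllmSpecificArgs = {
    "tensor_parallel_size": 2,      # shard the model across 2 GPUs
    "gpu_memory_utilization": 0.9,  # fraction of GPU memory vLLM may reserve
    "max_model_len": 4096,          # longest sequence vLLM will serve
    "skip_tokenizer_init": False,   # let vLLM initialize its own tokenizer
}

# VllmConfig extends GenerationConfig with a single extra key, `vllm_cfg`.
config = {
    # ... GenerationConfig fields from nemo_rl.models.generation.interfaces ...
    "vllm_cfg": vllm_args,
}
```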
- class nemo_rl.models.generation.vllm.VllmGenerationWorker(
- config: nemo_rl.models.generation.vllm.VllmConfig,
- bundle_indices: Optional[list] = None,
- fraction_of_gpus: float = 1.0,
- seed: Optional[int] = None,
- )[source]#
Initialization
Initialize a vLLM worker for distributed inference.
- Parameters:
config – Configuration dictionary for the policy
bundle_indices – List of local bundle indices within a node for tensor parallelism. Only needed for the first worker in each tied worker group.
- __repr__()[source]#
Customizes the actor’s prefix in the Ray logs.
This makes it easier to identify which worker is producing specific log messages.
- static configure_worker(
- num_gpus: int | float,
- bundle_indices: Optional[tuple] = None,
- )[source]#
Provides complete worker configuration for vLLM tensor parallelism.
This method configures the worker based on its role in tensor parallelism, which is determined directly from the bundle_indices parameter.
- Parameters:
num_gpus – Original GPU allocation for this worker based on the placement group
bundle_indices – Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)
- Returns:
'resources': Resource allocation (e.g., num_gpus)
'env_vars': Environment variables for this worker
'init_kwargs': Parameters to pass to the worker's initializer
- Return type:
tuple with complete worker configuration
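A hedged sketch of consuming the result; it assumes the returned tuple unpacks into the three documented entries in the order listed above:

```python
from nemo_rl.models.generation.vllm import VllmGenerationWorker

# Assumption: the tuple is (resources, env_vars, init_kwargs), matching the
# three entries named in the docstring.
resources, env_vars, init_kwargs = VllmGenerationWorker.configure_worker(
    num_gpus=1.0,
    bundle_indices=(0, [0, 1]),  # node 0, local bundle indices 0 and 1
)

# `resources` (e.g. {"num_gpus": ...}) feeds the Ray actor's resource request,
# `env_vars` is applied to the worker's environment, and `init_kwargs` is
# forwarded to VllmGenerationWorker's initializer.
```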
- generate(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
- )[source]#
Generate a batch of data using vLLM generation.
- Parameters:
data – BatchedDataDict containing input_ids and input_lengths tensors
greedy – Whether to use greedy decoding instead of sampling
- Returns:
output_ids: input + generated token IDs with proper padding
logprobs: Log probabilities for tokens
generation_lengths: Lengths of each response
unpadded_sequence_lengths: Lengths of each input + generated sequence
- Return type:
BatchedDataDict conforming to GenerationOutputSpec
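An illustrative call, assuming `worker` is an existing VllmGenerationWorker (in practice workers are created and driven by VllmGeneration through Ray) and that a BatchedDataDict can be built from a plain dict:

```python
import torch
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

# Two padded prompts; input_lengths holds the unpadded prompt lengths.
data = BatchedDataDict({
    "input_ids": torch.tensor([[101, 2023, 2003, 0], [101, 7592, 0, 0]]),
    "input_lengths": torch.tensor([3, 2]),
})

out = worker.generate(data, greedy=True)
out["output_ids"]                 # prompt + generated tokens, padded
out["logprobs"]                   # per-token log probabilities
out["generation_lengths"]         # length of each generated response
out["unpadded_sequence_lengths"]  # prompt + response length per sample
```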
- generate_text(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
- )[source]#
Generate text responses using vLLM generation.
- Parameters:
data – BatchedDataDict containing prompts with text strings
greedy – Whether to use greedy decoding instead of sampling
- Returns:
texts: List of generated text responses
- Return type:
BatchedDataDict containing the generated texts
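The text path looks similar; continuing the previous sketch and assuming the prompts are passed under a "prompts" key as described above:

```python
data = BatchedDataDict({"prompts": ["Hello, world!", "Summarize RLHF in one sentence."]})
out = worker.generate_text(data, greedy=False)
print(out["texts"])  # one generated string per input prompt
```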
- update_weights_from_ipc_handles(ipc_handles)[source]#
Update weights from IPC handles by delegating to the vLLM Worker implementation.
- Parameters:
ipc_handles (dict) – Dictionary mapping device UUIDs to parameter IPC handles.
- Returns:
True if weights were successfully updated, False otherwise.
- Return type:
bool
- class nemo_rl.models.generation.vllm.VllmGeneration(
- cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
- config: nemo_rl.models.generation.vllm.VllmConfig,
- name_prefix: str = 'vllm_policy',
- workers_per_node: Optional[Union[int, List[int]]] = None,
- )[source]#
Bases: nemo_rl.models.generation.interfaces.GenerationInterface
- _get_tied_worker_bundle_indices(cluster)[source]#
Calculate bundle indices for tensor parallel workers.
- generate(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
- )[source]#
Generate a batch of data using vLLM.
- generate_text(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- greedy: bool = False,
- )[source]#
Generate text responses using vLLM.
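Putting the pieces together, a hedged end-to-end sketch; the RayVirtualCluster arguments are illustrative placeholders, and `config` and `data` are built as in the earlier sketches:

```python
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.models.generation.vllm import VllmGeneration

# Hypothetical cluster sizing; see RayVirtualCluster for its actual arguments.
cluster = RayVirtualCluster(bundle_ct_per_node_list=[2])

vllm_generation = VllmGeneration(
    cluster=cluster,
    config=config,              # VllmConfig dict from the earlier sketch
    name_prefix="vllm_policy",
)

results = vllm_generation.generate_text(data, greedy=False)
print(results["texts"])
```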
- prepare_for_generation(*args, **kwargs)[source]#
Abstract method that must be implemented by subclasses.
- update_weights(ipc_handles)[source]#
Update weights of the policy using IPC handles, considering tensor parallelism.
For tp > 1, only the leader in each tensor parallel tied worker group will update weights.
- Parameters:
ipc_handles (dict) – Dictionary mapping device UUIDs to parameter IPC handles.
- Returns:
True if weights were successfully updated, False otherwise.
- Return type:
bool
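A small usage sketch; producing ipc_handles is outside this module (it is a dict mapping device UUIDs to parameter IPC handles, typically exported by the co-located training policy):

```python
# `ipc_handles`: dict mapping device UUIDs to parameter IPC handles,
# produced elsewhere (e.g. by the training policy's weight-export path).
ok = vllm_generation.update_weights(ipc_handles)
if not ok:
    raise RuntimeError("vLLM weight update via IPC handles failed")
```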