nemo_rl.models.generation.vllm_backend#

Module Contents#

Classes#

VllmInternalWorkerExtension

Functions#

_patch_gemma3_mm

Patch gemma3_mm.py to support new HF multimodal format (post transformers v4.52).

API#

nemo_rl.models.generation.vllm_backend._patch_gemma3_mm()#

Patch gemma3_mm.py to support new HF multimodal format (post transformers v4.52).

Patch taken from: https://github.com/vllm-project/vllm/pull/19151/files#diff-5890909300e4e6c3160444e4587ec3fd80498bb83f598b22ce81337f75992b06

class nemo_rl.models.generation.vllm_backend.VllmInternalWorkerExtension#
init_collective(
rank_prefix: int,
ip: str,
port: int,
world_size: int,
) None#

Initialize the collective communication.

report_device_id() str#
prepare_refit_info(
state_dict_info: Optional[dict[str, Any]] = None,
) None#

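Prepare the info for refit. The expected form of state_dict_info depends on the policy worker and on whether inference is colocated: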

DtensorPolicyWorker:

    colocated inference: state_dict_info is None

    non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)}

MegatronPolicyWorker:

    colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)}

    non-colocated inference: not implemented yet
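The two dictionary layouts described above can be illustrated with a minimal sketch. It builds only the metadata, not the worker call itself. The parameter names, the `params` table, and both helper functions are hypothetical, and plain strings stand in for the dtype objects a real policy worker would pass.

```python
from math import prod

# Hypothetical parameter metadata: name -> (shape, dtype).
# The dtype strings are stand-ins for real dtype objects (an assumption).
params = {
    "embed.weight": ((1024, 64), "bfloat16"),
    "lm_head.weight": ((64, 1024), "bfloat16"),
}


def dtensor_non_colocated_info(params):
    """Layout described for DtensorPolicyWorker, non-colocated inference:
    {tensor_name: (shape, dtype)}."""
    return {name: (shape, dtype) for name, (shape, dtype) in params.items()}


def megatron_colocated_info(params):
    """Layout described for MegatronPolicyWorker, colocated inference:
    {tensor_name: (shape, dtype, numel)}, where numel is the product of the shape."""
    return {
        name: (shape, dtype, prod(shape))
        for name, (shape, dtype) in params.items()
    }


dtensor_info = dtensor_non_colocated_info(params)
megatron_info = megatron_colocated_info(params)
```

For DtensorPolicyWorker with colocated inference, no dictionary is built at all: state_dict_info is left as None.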

update_weights_from_global_ipc_handles(global_device_ipc_handles)#

Update weights from global IPC handles.

Parameters:

global_device_ipc_handles (dict) – Dictionary mapping device UUIDs to parameter IPC handles.

Returns:

True if weights were successfully updated.

Return type:

bool
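As a rough illustration of the mapping described above, the sketch below picks one device's entry out of a dictionary keyed by device UUID. The helper `select_local_handles`, the UUID strings, and the handle payloads are all hypothetical placeholders. That a worker looks up its own entry using the value returned by report_device_id() is an assumption, not something this page states.

```python
from typing import Any


def select_local_handles(
    global_device_ipc_handles: dict[str, Any], device_uuid: str
) -> Any:
    """Return the entry for one device from a {device_uuid: handles} mapping.

    Hypothetical helper; not part of the documented class.
    """
    try:
        return global_device_ipc_handles[device_uuid]
    except KeyError:
        raise KeyError(f"no IPC handles for device {device_uuid!r}") from None


# Placeholder UUIDs and payloads; real IPC handles are opaque objects.
global_handles = {
    "GPU-aaaa": {"layer.weight": "<handle-0>"},
    "GPU-bbbb": {"layer.weight": "<handle-1>"},
}

local_handles = select_local_handles(global_handles, "GPU-bbbb")
```

Under that assumption, the selected entry has the shape expected by the local_device_ipc_handles parameter of update_weights_from_local_ipc_handles.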

update_weights_from_local_ipc_handles(local_device_ipc_handles)#

Update weights from local IPC handles.

Parameters:

local_device_ipc_handles (dict) – Parameter IPC handles for the local device.

Returns:

True if weights were successfully updated.

Return type:

bool

update_weights_from_collective() bool#

Update the model weights from collective communication.