nemo_rl.models.generation.vllm_backend
Module Contents
Classes
VllmInternalWorkerExtension
Functions
_patch_gemma3_mm | Patch gemma3_mm.py to support the new HF multimodal format (post transformers v4.52).
API
- nemo_rl.models.generation.vllm_backend._patch_gemma3_mm()
Patch gemma3_mm.py to support the new HF multimodal format (post transformers v4.52).
Patch taken from: https://github.com/vllm-project/vllm/pull/19151/files#diff-5890909300e4e6c3160444e4587ec3fd80498bb83f598b22ce81337f75992b06
- class nemo_rl.models.generation.vllm_backend.VllmInternalWorkerExtension
- init_collective(rank_prefix: int, ip: str, port: int, world_size: int)
Initialize the collective communication.
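init_collective receives a rank offset, a rendezvous address (ip, port), and the group size. The following is a minimal sketch of one plausible way these arguments combine. It assumes rank_prefix is added to each worker's local rank to give the worker's global rank. CollectiveConfig, build_collective_config, and its extra local_rank parameter are hypothetical. The real method sets up an actual communication group rather than returning a config object.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CollectiveConfig:
    """Hypothetical rendezvous settings for one inference worker."""

    init_method: str  # "tcp://host:port", the URL form torch.distributed accepts
    rank: int  # this worker's global rank in the weight-update group
    world_size: int  # total number of ranks in the group


def build_collective_config(
    rank_prefix: int, local_rank: int, ip: str, port: int, world_size: int
) -> CollectiveConfig:
    """Offset the local rank by rank_prefix (assumed semantics, not confirmed)."""
    rank = rank_prefix + local_rank
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    return CollectiveConfig(f"tcp://{ip}:{port}", rank, world_size)


# Example: the trainer holds rank 0, so a two-worker vLLM engine starts at prefix 1.
for local_rank in range(2):
    print(build_collective_config(1, local_rank, "10.0.0.5", 29500, 3))
```

The explicit range check catches a misconfigured prefix before any worker blocks waiting on a rendezvous that can never complete.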
- report_device_id() → str
- prepare_refit_info(state_dict_info: Optional[dict[str, Any]] = None)
Prepare the info for refit.
DtensorPolicyWorker:
- colocated inference: state_dict_info is None.
- non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)}.
MegatronPolicyWorker:
- colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)}.
- non-colocated inference: not implemented yet.
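The two policy workers pass differently shaped state_dict_info values. The following sketch builds each layout from a state dict.

So that it runs without PyTorch, it uses a hypothetical FakeTensor stand-in. With real tensors, shape would be a torch.Size, dtype a torch.dtype, and numel would come from tensor.numel(). The helper names dtensor_refit_info and megatron_refit_info are illustrative only.

```python
import math
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: only shape and dtype matter here."""

    shape: tuple[int, ...]
    dtype: str


def dtensor_refit_info(
    state_dict: dict[str, FakeTensor], colocated: bool
) -> Optional[dict[str, Any]]:
    """DtensorPolicyWorker layout: None if colocated, else {name: (shape, dtype)}."""
    if colocated:
        return None
    return {name: (t.shape, t.dtype) for name, t in state_dict.items()}


def megatron_refit_info(
    state_dict: dict[str, FakeTensor], colocated: bool
) -> dict[str, Any]:
    """MegatronPolicyWorker layout: {name: (shape, dtype, numel)}, colocated only."""
    if not colocated:
        raise NotImplementedError("non-colocated inference is not implemented yet")
    return {
        name: (t.shape, t.dtype, math.prod(t.shape)) for name, t in state_dict.items()
    }


weights = {"lm_head.weight": FakeTensor((32000, 4096), "bfloat16")}
print(dtensor_refit_info(weights, colocated=False))
print(megatron_refit_info(weights, colocated=True))
```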
- update_weights_from_global_ipc_handles(global_device_ipc_handles)
Update weights from global IPC handles.
- Parameters:
global_device_ipc_handles (dict) – Dictionary mapping device UUIDs to parameter IPC handles.
- Returns:
True if weights were successfully updated.
- Return type:
bool
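Because the global dictionary is keyed by device UUID, a natural implementation has each worker look up its own entry and delegate to update_weights_from_local_ipc_handles. The following sketch assumes exactly that. It also assumes report_device_id returns the same UUID string used as the key.

SketchWorker and the "GPU-…" strings are hypothetical. The local step here only records the handles instead of rebuilding tensors.

```python
from typing import Any


class SketchWorker:
    """Hypothetical worker showing how a global handle map could be narrowed."""

    def __init__(self, device_uuid: str) -> None:
        self.device_uuid = device_uuid
        self.received: dict[str, Any] = {}

    def report_device_id(self) -> str:
        # Assumed to return the key used in the global handle dictionary.
        return self.device_uuid

    def update_weights_from_local_ipc_handles(self, local_handles: dict[str, Any]) -> bool:
        # Placeholder: a real worker would rebuild tensors from the handles here.
        self.received = dict(local_handles)
        return True

    def update_weights_from_global_ipc_handles(
        self, global_handles: dict[str, dict[str, Any]]
    ) -> bool:
        # Select only the handles produced for this worker's device.
        local_handles = global_handles[self.report_device_id()]
        return self.update_weights_from_local_ipc_handles(local_handles)


handles = {"GPU-aaaa": {"w": "handle-a"}, "GPU-bbbb": {"w": "handle-b"}}
worker = SketchWorker("GPU-bbbb")
print(worker.update_weights_from_global_ipc_handles(handles), worker.received)
```

Keying by UUID lets a single dictionary be sent to every worker. Each worker then extracts its own slice without the sender needing to know the worker-to-device mapping.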
- update_weights_from_local_ipc_handles(local_device_ipc_handles)
Update weights from local IPC handles.
- Parameters:
local_device_ipc_handles (dict) – Parameter IPC handles for the local device.
- Returns:
True if weights were successfully updated.
- Return type:
bool
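The local update can be pictured as two steps:

1. Turn each IPC handle back into a tensor that shares the sender's memory.
2. Hand the resulting (name, tensor) pairs to the model's weight loader.

The following sketch injects both steps as callables (rebuild, load_weights) so it runs without a GPU. In PyTorch, CUDA tensors can be shared across processes through torch.multiprocessing. Reporting failure by returning False is an assumption about the error path. apply_local_ipc_handles is a hypothetical name.

```python
from typing import Any, Callable, Iterable


def apply_local_ipc_handles(
    local_handles: dict[str, Any],
    rebuild: Callable[[Any], Any],
    load_weights: Callable[[Iterable[tuple[str, Any]]], None],
) -> bool:
    """Hypothetical flow: rebuild every shared tensor, then load them in one call."""
    try:
        # Turn each IPC handle back into a tensor that shares the sender's memory.
        weights = [(name, rebuild(handle)) for name, handle in local_handles.items()]
        load_weights(weights)
        return True
    except Exception as exc:  # assumed: surface failure through the bool return
        print(f"weight update failed: {exc!r}")
        return False


# Fakes so the sketch runs anywhere: "handles" are ints and "rebuild" doubles them.
loaded: list[tuple[str, Any]] = []
ok = apply_local_ipc_handles({"w": 3}, rebuild=lambda h: h * 2, load_weights=loaded.extend)
print(ok, loaded)
```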
- update_weights_from_collective() → bool
Update the model weights from collective communication.
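The docstring does not say how the weights arrive over the collective. The following sketch assumes a common pattern for non-colocated refit:

- The training side broadcasts each parameter from rank 0.
- Each inference worker allocates a receive buffer sized from the (shape, dtype) metadata that prepare_refit_info records for non-colocated inference.

FakeCommunicator and update_from_collective are hypothetical. Plain lists stand in for GPU tensors, and a real implementation would use a collective backend such as NCCL.

```python
import math
from collections import deque
from typing import Callable

Buffer = list[float]


class FakeCommunicator:
    """Hypothetical stand-in for a broadcast group; replays payloads in send order."""

    def __init__(self, payloads: list[Buffer]) -> None:
        self._pending = deque(payloads)

    def broadcast(self, buffer: Buffer, src: int) -> None:
        # A real broadcast would fill `buffer` in place with the data held by rank `src`.
        payload = self._pending.popleft()
        if len(payload) != len(buffer):
            raise ValueError("sender and receiver disagree on tensor size")
        buffer[:] = payload


def update_from_collective(
    state_dict_info: dict[str, tuple[tuple[int, ...], str]],
    comm: FakeCommunicator,
    load_weights: Callable[[list[tuple[str, Buffer]]], None],
) -> bool:
    received: list[tuple[str, Buffer]] = []
    # Sender and receivers must walk the parameters in the same order.
    for name, (shape, _dtype) in state_dict_info.items():
        buffer = [0.0] * math.prod(shape)  # receive buffer sized from the metadata
        comm.broadcast(buffer, src=0)
        received.append((name, buffer))
    load_weights(received)
    return True


info = {"bias": ((2,), "float32"), "scale": ((1,), "float32")}
loaded: list[tuple[str, Buffer]] = []
ok = update_from_collective(info, FakeCommunicator([[0.1, 0.2], [3.0]]), loaded.extend)
print(ok, loaded)
```

Sending metadata ahead of time means receivers can preallocate buffers. The broadcast then only has to move raw data, which is why both sides must agree on parameter order.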