nemo_rl.weight_sync.interfaces

Weight synchronization interface for NeMo-RL.

WeightSynchronizer is a dedicated abstraction that decouples weight transfer logic from both PolicyInterface and GenerationInterface. It owns the transfer of model weights from the training (policy) side to the generation side.

Transport-specific implementations (IPC/ZMQ, HTTP, NCCL collectives) each encapsulate the transfer lifecycle, so algorithm code never branches on backend type.

Colocated transports (IPC, HTTP) own GPU phase transitions internally (offload, prepare_for_generation, restore) as part of their sync_weights() implementation. The NCCL collective transport is a pure data mover; the orchestrator handles phase transitions externally since policy and generation run on separate GPU clusters.

This interface assumes global weight updates: all generation workers are updated atomically and are always at the same weight version. Per-worker updates (where different replicas could be at different versions) are not supported. In async GRPO, heterogeneous weight ages are handled at the sample level (via replay buffer target_weight_versions tracking), not at the synchronizer level.
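The global-update contract can be illustrated with a toy model. The sketch below is hypothetical: `GlobalWeightStateSketch`, its integer `version` counter, and the sample dictionaries are illustrative names, not NeMo-RL code. A sync moves every generation worker to the same version at once, and each generated sample records the version that produced it, which is where weight age is tracked.

```python
from dataclasses import dataclass, field


@dataclass
class GlobalWeightStateSketch:
    """Hypothetical model of the global-update contract (not NeMo-RL code)."""

    num_workers: int
    version: int = 0
    # Current weight version of each generation worker; entries are always equal.
    worker_versions: list[int] = field(default_factory=list, init=False)

    def __post_init__(self) -> None:
        self.worker_versions = [self.version] * self.num_workers

    def sync(self) -> None:
        # One global step: every worker moves to the same new version together.
        # There is deliberately no method for updating a single worker.
        self.version += 1
        self.worker_versions = [self.version] * self.num_workers

    def generate(self, prompt: str) -> dict[str, object]:
        # Weight age is recorded on each sample rather than per worker.
        return {"prompt": prompt, "weight_version": self.version}


state = GlobalWeightStateSketch(num_workers=4)
old_sample = state.generate("first prompt")  # produced with version 0
state.sync()
new_sample = state.generate("second prompt")  # produced with version 1
```

Because there is no per-worker update path, a consumer such as a replay buffer only needs the version stored on each sample to reason about staleness.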

Module Contents

Classes

  • WeightSynchronizer – Abstract base class for weight synchronization between policy and generation.

API

class nemo_rl.weight_sync.interfaces.WeightSynchronizer

Bases: abc.ABC

Abstract base class for weight synchronization between policy and generation.

Implementations handle the weight transfer for a specific transport mechanism (ZMQ IPC, HTTP, NCCL collectives). The orchestrator calls sync_weights() and mark_stale() without knowing which transport is being used or whether components are colocated.

Colocated transports (IPC, HTTP) own phase transitions internally (offload_before_refit, prepare_for_generation, offload_after_refit). The NCCL collective transport is a pure data mover; the orchestrator handles phases externally.
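To show what "without knowing which transport is being used" means for calling code, here is a minimal sketch. It assumes a local stand-in for the abstract class (so it runs without NeMo-RL installed, with `timer` typed loosely), a hypothetical `RecordingSynchronizer` transport, and a hypothetical `run_steps` loop in which generation and training are placeholders. The loop only calls `is_stale`, `sync_weights()`, and `mark_stale()`.

```python
import abc
from typing import Any, Optional


class WeightSynchronizer(abc.ABC):
    """Local stand-in mirroring the documented interface. The real `timer`
    parameter is Optional[nemo_rl.utils.timer.Timer]."""

    @abc.abstractmethod
    def sync_weights(
        self, *, timer: Optional[Any] = None, kv_scales: Optional[dict[str, float]] = None
    ) -> None: ...

    @property
    @abc.abstractmethod
    def is_stale(self) -> bool: ...

    @abc.abstractmethod
    def mark_stale(self) -> None: ...

    @abc.abstractmethod
    def init_communicator(self) -> None: ...

    @abc.abstractmethod
    def shutdown(self) -> None: ...


class RecordingSynchronizer(WeightSynchronizer):
    """Hypothetical transport that only records syncs, for illustration."""

    def __init__(self, log: list[str]) -> None:
        self.log = log
        self._stale = False  # assumption: weights are in sync right after setup

    def sync_weights(self, *, timer=None, kv_scales=None) -> None:
        self.log.append("sync")
        self._stale = False

    @property
    def is_stale(self) -> bool:
        return self._stale

    def mark_stale(self) -> None:
        self._stale = True

    def init_communicator(self) -> None:
        pass  # nothing to set up for an in-memory stub

    def shutdown(self) -> None:
        pass  # nothing to release for an in-memory stub


def run_steps(sync: WeightSynchronizer, num_steps: int, log: list[str]) -> None:
    """Hypothetical orchestrator loop: it never checks which transport it holds."""
    for step in range(num_steps):
        if sync.is_stale:
            sync.sync_weights()  # refresh generation weights before generating
        log.append(f"generate {step}")  # placeholder for rollout generation
        log.append(f"train {step}")  # placeholder for a policy update
        sync.mark_stale()  # the training step made generation weights stale


log: list[str] = []
run_steps(RecordingSynchronizer(log), num_steps=2, log=log)
```

Swapping in an IPC, HTTP, or NCCL implementation would leave `run_steps` unchanged, which is the point of the abstraction.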

abstractmethod sync_weights(
    *,
    timer: Optional[nemo_rl.utils.timer.Timer] = None,
    kv_scales: Optional[dict[str, float]] = None,
) → None

Transfer the latest policy weights to the generation backend.

This method encapsulates the full sync lifecycle:

  1. Prepare the policy side (e.g., offload optimizer state to free GPU memory)

  2. Prepare the generation side (e.g., allocate weight buffers)

  3. Transfer weights via the transport mechanism

  4. Verify the transfer succeeded

  5. Restore both sides to their ready state

Steps 1-2 and 5 (phase transitions) are only performed by colocated transports (IPC, HTTP). The NCCL collective transport skips them since policy and generation run on separate GPUs.

Step 4 (verification) is performed explicitly by IPC and NCCL transports, which check update_success and raise on failure. The HTTP transport relies on ray.get() to propagate any server-side errors (matching the existing grpo.py behavior).
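The five-step lifecycle for a colocated transport can be sketched as follows. This is not the NeMo-RL implementation. Several details are assumptions: `offload_before_refit`/`offload_after_refit` are placed on the policy stub and `prepare_for_generation` on the generation stub, while `export_weights`, `load_weights`, and `finish_refit` are hypothetical names. Each stub records the method names it receives, so the call order is visible.

```python
class RecordingPolicy:
    """Hypothetical policy stub; method placement and names are illustrative."""

    def __init__(self, log: list[str]) -> None:
        self.log = log

    def offload_before_refit(self) -> None:
        self.log.append("offload_before_refit")

    def export_weights(self) -> dict[str, list[float]]:
        self.log.append("export_weights")
        return {"layer.weight": [0.1, 0.2]}

    def offload_after_refit(self) -> None:
        self.log.append("offload_after_refit")


class RecordingGeneration:
    """Hypothetical generation stub; load_weights returns update_success."""

    def __init__(self, log: list[str], succeed: bool = True) -> None:
        self.log = log
        self.succeed = succeed

    def prepare_for_generation(self) -> None:
        self.log.append("prepare_for_generation")

    def load_weights(self, weights: dict[str, list[float]]) -> bool:
        self.log.append("load_weights")
        return self.succeed

    def finish_refit(self) -> None:
        self.log.append("finish_refit")


class ColocatedSyncSketch:
    """Hypothetical colocated transport that owns its phase transitions."""

    def __init__(self, policy: RecordingPolicy, generation: RecordingGeneration) -> None:
        self.policy = policy
        self.generation = generation

    def sync_weights(self, *, timer=None, kv_scales=None) -> None:
        del timer, kv_scales  # colocated transports ignore kv_scales; no profiling here
        self.policy.offload_before_refit()  # 1. free policy GPU memory
        self.generation.prepare_for_generation()  # 2. allocate weight buffers
        weights = self.policy.export_weights()  # 3. transfer
        update_success = self.generation.load_weights(weights)
        if not update_success:  # 4. verify
            raise RuntimeError("weight transfer failed")
        self.policy.offload_after_refit()  # 5. restore both sides
        self.generation.finish_refit()


log: list[str] = []
ColocatedSyncSketch(RecordingPolicy(log), RecordingGeneration(log)).sync_weights()
```

In this sketch a failed verification raises before step 5 runs. Whether a real transport should restore state on failure is a design choice the interface leaves to each implementation.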

Parameters:
  • timer – Optional Timer for profiling individual phases.

  • kv_scales – Optional KV cache scales for FP8 quantization. Only honored by the NCCL collective transport, which forwards them to policy.broadcast_weights_for_collective(); the IPC and HTTP transports ignore this parameter.

Raises:
  • RuntimeError – If the weight transfer fails.
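By contrast, the NCCL collective transport is a pure data mover that forwards kv_scales. The sketch below assumes hypothetical stubs. The keyword signature of `broadcast_weights_for_collective` and the `receive_broadcast` method are assumptions, not documented APIs. In a real collective both sides take part in the broadcast concurrently, whereas these stubs run sequentially.

```python
from typing import Optional


class BroadcastingPolicyStub:
    """Hypothetical stand-in for the policy side of an NCCL transport."""

    def __init__(self) -> None:
        self.received_kv_scales: Optional[dict[str, float]] = None

    def broadcast_weights_for_collective(
        self, kv_scales: Optional[dict[str, float]] = None
    ) -> None:
        # Assumed signature; the stub just remembers what it was given.
        self.received_kv_scales = kv_scales


class ReceivingGenerationStub:
    """Hypothetical generation side; returns update_success for the broadcast."""

    def __init__(self, succeed: bool = True) -> None:
        self.succeed = succeed

    def receive_broadcast(self) -> bool:
        return self.succeed


class CollectiveSyncSketch:
    """Hypothetical pure data mover: no offload or prepare calls, because the
    orchestrator handles phases on the separate policy/generation clusters."""

    def __init__(
        self, policy: BroadcastingPolicyStub, generation: ReceivingGenerationStub
    ) -> None:
        self.policy = policy
        self.generation = generation

    def sync_weights(
        self, *, timer=None, kv_scales: Optional[dict[str, float]] = None
    ) -> None:
        del timer  # profiling is omitted from this sketch
        # Only the collective transport forwards the FP8 KV-cache scales.
        self.policy.broadcast_weights_for_collective(kv_scales=kv_scales)
        if not self.generation.receive_broadcast():  # explicit verification
            raise RuntimeError("collective weight broadcast failed")


policy = BroadcastingPolicyStub()
CollectiveSyncSketch(policy, ReceivingGenerationStub()).sync_weights(
    kv_scales={"layers.0.k_scale": 0.5}  # illustrative scale name and value
)
```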

abstract property is_stale: bool

Whether the generation backend’s weights are out of date.

Returns True after mark_stale() is called and before the next successful sync_weights() completes.

abstractmethod mark_stale() → None

Mark weights as stale after a training step.

Call this after every training step so the orchestrator knows a sync is needed before the next generation phase. Staleness is global: all generation workers are considered stale and are updated atomically on the next sync_weights() call.
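The is_stale/mark_stale contract amounts to a single flag that only a successful sync clears. The sketch below is a hypothetical bookkeeping class, not NeMo-RL code. The `transfer` callable standing in for the transport, and the assumption that weights start in sync, are both illustrative.

```python
from typing import Callable


class StaleTrackingSketch:
    """Hypothetical bookkeeping for is_stale / mark_stale (not NeMo-RL code)."""

    def __init__(self, transfer: Callable[[], bool]) -> None:
        self._transfer = transfer  # stands in for the transport; returns update_success
        self._stale = False  # assumption: weights are in sync after setup

    @property
    def is_stale(self) -> bool:
        return self._stale

    def mark_stale(self) -> None:
        # Global flag: covers every generation worker at once.
        self._stale = True

    def sync_weights(self) -> None:
        if not self._transfer():
            # The flag stays set, so the next attempt still knows a sync is needed.
            raise RuntimeError("weight transfer failed")
        self._stale = False  # cleared only after a successful transfer


flaky = StaleTrackingSketch(transfer=lambda: False)
flaky.mark_stale()
try:
    flaky.sync_weights()
except RuntimeError:
    pass  # transfer failed, so the weights must still be reported as stale
```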

abstractmethod init_communicator() → None

Initialize any communication channels needed for weight transfer.

Called once during setup, after the policy and generation workers are constructed. For the colocated IPC/HTTP transports, this may prepare refit metadata; for NCCL collectives, it initializes the process group.

abstractmethod shutdown() → None

Release all communication resources.
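Since init_communicator() runs once and shutdown() should release resources even when training fails, calling code can pair them in a context manager. `weight_sync_session`, `SupportsCommunicator`, and `RecordingComms` below are hypothetical helpers, not part of NeMo-RL. The sketch only shows that `shutdown()` runs after an exception.

```python
import contextlib
from typing import Iterator, Protocol


class SupportsCommunicator(Protocol):
    def init_communicator(self) -> None: ...
    def shutdown(self) -> None: ...


@contextlib.contextmanager
def weight_sync_session(sync: SupportsCommunicator) -> Iterator[SupportsCommunicator]:
    """Hypothetical helper: initialize once, and always release resources."""
    sync.init_communicator()  # call after policy and generation workers exist
    try:
        yield sync
    finally:
        sync.shutdown()  # runs even if the body raises


class RecordingComms:
    """Hypothetical synchronizer stub that records lifecycle calls."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def init_communicator(self) -> None:
        self.calls.append("init_communicator")

    def shutdown(self) -> None:
        self.calls.append("shutdown")


comms = RecordingComms()
try:
    with weight_sync_session(comms):
        raise ValueError("training step failed")
except ValueError:
    pass  # the error propagates, but shutdown() has already run
```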