nemo_rl.data_plane.worker_mixin

TransferQueue awareness for policy workers, isolated from the base class.

Mix into a worker class to add per-rank TQ-mediated entrypoints (train_presharded, get_logprobs_presharded, get_reference_policy_logprobs_presharded) without touching BasePolicyWorker. Subclasses that don’t need TQ keep their bare inheritance and stay zero-cost.

Subclasses must implement _get_replica_group (returns the NCCL group of TP×CP×PP siblings within this DP rank, or None for TP=CP=PP=1) and _local_coords (this worker’s axis → local-rank mapping), and inherit train / get_logprobs / get_reference_policy_logprobs from the worker base.
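A minimal sketch of how the mixin composes with a worker class. BasePolicyWorker and TQWorkerMixinSketch below are hypothetical stand-ins, not the real classes, and this train_presharded takes rows directly instead of fetching them from TQ. It shows the mixin listed before the base, so the TQ entrypoints resolve through the mixin while train() still comes from the base worker, and a worker without the mixin is left untouched.

```python
from typing import Any, Optional


class BasePolicyWorker:
    """Hypothetical stand-in for the real base worker."""

    def train(self, data: Any) -> dict:
        return {"loss": 0.0, "n": len(data)}


class TQWorkerMixinSketch:
    """Hypothetical stand-in for TQWorkerMixin; not the real class."""

    def _get_replica_group(self) -> Optional[Any]:
        raise NotImplementedError("subclasses must override")

    def _local_coords(self) -> dict[str, int]:
        raise NotImplementedError("subclasses must override")

    def train_presharded(self, rows: list) -> dict:
        # The real entrypoint fetches its slice from TQ; here rows are passed in.
        return self.train(rows)


class DTensorWorker(TQWorkerMixinSketch, BasePolicyWorker):
    """Mixin listed first: its entrypoints resolve first, train() comes from the base."""

    def _get_replica_group(self) -> Optional[Any]:
        return None  # TP=CP=PP=1, so no siblings

    def _local_coords(self) -> dict[str, int]:
        return {"dp": 0, "tp": 0, "cp": 0, "pp": 0}


class PlainWorker(BasePolicyWorker):
    """A worker that does not need TQ keeps its bare inheritance."""
```

Because the mixin adds methods without overriding anything on the base, a PlainWorker pays nothing for the mixin’s existence.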

Module Contents

Classes

TQWorkerMixin

Adds TransferQueue per-rank fetch/write-back to a policy worker.

Functions

_broadcast_batched_data_dict

Broadcast a BatchedDataDict from src to all ranks in group.

Data

FetchPolicy

API

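nemo_rl.data_plane.worker_mixin.FetchPolicy

Fetch strategy accepted by the fetch_policy argument of TQWorkerMixin._fetch: "auto", "independent", or "leader_broadcast".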

nemo_rl.data_plane.worker_mixin._broadcast_batched_data_dict(
data: Optional[nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]],
*,
is_leader: bool,
src: int,
group: Any,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]

Broadcast a BatchedDataDict from src to all ranks in group.

The broadcast runs in two phases to avoid pickling tensor payloads on the hot path: first a small descriptor (per-key dtype and shape) is sent via broadcast_object_list, then each tensor’s data is sent via broadcast on its current device. The leader supplies data; non-leaders pass None and receive an empty BatchedDataDict that is filled in place.
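The two-phase pattern can be sketched without a GPU as an in-process simulation. Everything here is hypothetical: InProcessGroup stands in for a torch.distributed/NCCL group, array.array stands in for tensors (its typecode plays the role of dtype, its length the role of shape), and ranks run one after another instead of concurrently. Only the protocol shape matches the description: descriptor first, then payloads copied into receiver-allocated buffers.

```python
from array import array
from typing import Optional


class InProcessGroup:
    """Hypothetical stand-in for a process group; ranks run sequentially."""

    def __init__(self) -> None:
        self.descriptor: Optional[dict] = None
        self.payloads: dict[str, array] = {}


def broadcast_dict(data: Optional[dict[str, array]], *, is_leader: bool,
                   group: InProcessGroup) -> dict[str, array]:
    if is_leader:
        if data is None:
            raise ValueError("the leader must supply data")
        # Phase 1: ship only per-key (dtype, length); no payload is pickled.
        group.descriptor = {k: (v.typecode, len(v)) for k, v in data.items()}
        # Phase 2: ship each buffer's raw contents.
        for k, v in data.items():
            group.payloads[k] = v
        return data
    if data is not None:
        raise ValueError("non-leaders pass None")
    out: dict[str, array] = {}
    # Phase 1 (receive): allocate zeroed buffers from the descriptor alone.
    for k, (typecode, length) in group.descriptor.items():
        out[k] = array(typecode, bytes(length * array(typecode).itemsize))
    # Phase 2 (receive): fill each preallocated buffer in place.
    for k, buf in out.items():
        buf[:] = group.payloads[k]
    return out
```

Receivers never need the payload to know what to allocate, which is why only the small descriptor goes through the object (pickling) channel.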

class nemo_rl.data_plane.worker_mixin.TQWorkerMixin

Adds TransferQueue per-rank fetch/write-back to a policy worker.

The driver-side TQPolicy fans out per-rank KVBatchMeta; each worker calls self._fetch(meta, ...) to pull its slice from TQ and runs the existing per-rank method body.

_dp_client: Optional[nemo_rl.data_plane.interfaces.DataPlaneClient] = None

setup_data_plane(cfg: nemo_rl.data_plane.DataPlaneConfig) → None

Connect this worker process’s client to the existing TQ controller.

Called once by the driver after worker construction. Idempotent.
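A minimal sketch of an idempotent setup, assuming a hypothetical FakeClient in place of a real DataPlaneClient and a plain dict in place of DataPlaneConfig. The guard on the existing client makes a repeated call a no-op. The _require_dp_client behavior shown (raising when no client is set) is an assumption based on its name; its real behavior is not documented here.

```python
from typing import Optional


class FakeClient:
    """Hypothetical stand-in for a DataPlaneClient connection."""
    connections = 0

    def __init__(self, cfg: dict) -> None:
        FakeClient.connections += 1
        self.cfg = cfg


class WorkerSketch:
    # Class-level default, mirroring the documented _dp_client attribute.
    _dp_client: Optional[FakeClient] = None

    def setup_data_plane(self, cfg: dict) -> None:
        # Idempotent: once a client exists, later calls do nothing.
        if self._dp_client is not None:
            return
        self._dp_client = FakeClient(cfg)

    def _require_dp_client(self) -> FakeClient:
        # Assumed behavior: fail loudly if setup_data_plane was never called.
        if self._dp_client is None:
            raise RuntimeError("setup_data_plane() has not been called")
        return self._dp_client
```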

_require_dp_client() → nemo_rl.data_plane.interfaces.DataPlaneClient

_get_replica_group() → Optional[Any]

NCCL group of TP×CP×PP siblings within this DP rank.

None means “no siblings” (TP=CP=PP=1). Subclasses must override this using their parallelism state (DTensor device_mesh, Megatron parallel_state). Under fetch_policy="auto", returning None makes _fetch use independent fetch, while returning a group makes it use leader-fetch + NCCL broadcast.

_pad_value_dict() → dict[str, Any]

Per-field pad value used by materialize to detile the jagged wire format.

Token-id fields use the tokenizer pad id.
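A minimal sketch of per-field pad values and how they are applied when jagged rows are padded back into a rectangle. TOKEN_ID_FIELDS, pad_value_dict, and pad_jagged are hypothetical names; padding non-token fields with 0 is an assumption (the source only fixes the token-id case), and plain lists stand in for tensors.

```python
from typing import Any

TOKEN_ID_FIELDS = {"input_ids"}  # hypothetical: which fields hold token ids


def pad_value_dict(pad_token_id: int, fields: list[str]) -> dict[str, Any]:
    # Token-id fields pad with the tokenizer pad id; 0 elsewhere is an assumption.
    return {f: pad_token_id if f in TOKEN_ID_FIELDS else 0 for f in fields}


def pad_jagged(rows: list[list[Any]], pad_value: Any) -> list[list[Any]]:
    # Right-pad every variable-length row to the longest row in the batch.
    width = max((len(r) for r in rows), default=0)
    return [r + [pad_value] * (width - len(r)) for r in rows]
```

Padding input_ids with the tokenizer pad id (rather than 0, which may be a real token) keeps padded positions distinguishable downstream.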

_forward_pad_seqlen(meta: nemo_rl.data_plane.KVBatchMeta) → int

Cross-DP forward pad target, minted by TQPolicy._stamp_pad_seqlen.
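A minimal sketch of the mint-then-read split, assuming the driver picks the longest sequence across all DP shards as the shared target (the real TQPolicy._stamp_pad_seqlen may round or cap it). MetaSketch is a hypothetical stand-in for KVBatchMeta, and the string value of GLOBAL_FORWARD_PAD_SEQLEN is invented for illustration.

```python
from dataclasses import dataclass, field

GLOBAL_FORWARD_PAD_SEQLEN = "global_forward_pad_seqlen"  # hypothetical key value


@dataclass
class MetaSketch:
    """Hypothetical stand-in for a per-rank KVBatchMeta."""
    seq_lens: list[int]
    extra_info: dict = field(default_factory=dict)


def stamp_pad_seqlen(shards: list[MetaSketch]) -> int:
    # Driver side: one target shared by every DP rank so forward shapes agree.
    target = max(max(s.seq_lens, default=0) for s in shards)
    for s in shards:
        s.extra_info[GLOBAL_FORWARD_PAD_SEQLEN] = target
    return target


def forward_pad_seqlen(meta: MetaSketch) -> int:
    # Worker side: read the driver-minted value instead of the local maximum.
    return meta.extra_info[GLOBAL_FORWARD_PAD_SEQLEN]
```

A rank whose local sequences are all short still pads to the global target, which is the point: every DP rank sees the same sequence dimension.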

_fetch(
meta: nemo_rl.data_plane.KVBatchMeta,
*,
layout: nemo_rl.data_plane.schema.Layout = 'padded',
fetch_policy: nemo_rl.data_plane.worker_mixin.FetchPolicy = 'auto',
preprocess: Optional[Any] = None,
dp_aligned_seq_len: bool = True,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]

Fetch this rank’s slice from TQ and return a BatchedDataDict.

Parameters:
  • meta – Per-rank KVBatchMeta from shard_meta_for_dp. The forward-pass pad target is read from meta.extra_info[GLOBAL_FORWARD_PAD_SEQLEN], minted by TQPolicy._stamp_pad_seqlen.

  • layout – Materialization layout ("padded" or "jagged").

  • fetch_policy – "auto" uses leader-fetch + NCCL broadcast when _get_replica_group returns a group, and independent fetch otherwise (cheapest for TP=CP=PP=1). "independent" forces every sibling to fetch. "leader_broadcast" forces the broadcast path and asserts that a replica group exists.

  • preprocess – Optional callable (worker, td) -> td, applied after materialize and before the result is returned.

  • dp_aligned_seq_len – When True (default), right-pad the sequence dimension to the cross-DP forward pad target. Disabled in tests that want to observe per-rank local-pad behavior.

Returns:

BatchedDataDict of this rank’s slice.
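The fetch_policy rules above reduce to a small decision function. resolve_fetch_path is a hypothetical helper written for illustration; it returns the path a rank would take, and raises ValueError where the real method asserts.

```python
from typing import Any, Literal, Optional

FetchPolicy = Literal["auto", "independent", "leader_broadcast"]


def resolve_fetch_path(policy: FetchPolicy, replica_group: Optional[Any]) -> str:
    """Return "independent" or "leader_broadcast" for one rank (hypothetical helper)."""
    if policy == "auto":
        # No siblings (TP=CP=PP=1): each rank fetching its own slice is cheapest.
        return "leader_broadcast" if replica_group is not None else "independent"
    if policy == "independent":
        return "independent"
    if policy == "leader_broadcast":
        # The real method asserts here; a ValueError keeps the check under -O.
        if replica_group is None:
            raise ValueError("leader_broadcast requires a replica group")
        return "leader_broadcast"
    raise ValueError(f"unknown fetch_policy: {policy!r}")
```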

_apply_packing_prep(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]

Re-derive micro_batch_indices / micro_batch_lengths on the local slice.

Uses shard_by_batch_size(shards=1, ...). The legacy DP path computes those as a side effect of the DP-shard call; the TQ presharded path receives a per-rank slice without them set, so we recompute here using self.cfg.

_attach_or_repack_pack_metadata(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
meta: nemo_rl.data_plane.KVBatchMeta,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]

Trust driver-supplied packing metadata or re-derive locally.

When the driver pre-balanced packing across DP ranks it ships micro_batch_indices / micro_batch_lengths (and optionally elem_counts_per_gb) in meta.extra_info. Locally re-packing produces variable bin counts across DP groups and desyncs Megatron’s per-microbatch collectives — trust the driver when it provided the metadata.
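A minimal sketch of the trust-or-repack decision. attach_or_repack is a hypothetical helper, plain dicts stand in for BatchedDataDict and meta.extra_info, and the repack callable stands in for the local re-derivation (shard_by_batch_size(shards=1, ...) in the real code). Requiring both micro_batch_indices and micro_batch_lengths before trusting the driver is an assumption.

```python
from typing import Any, Callable

PACK_KEYS = ("micro_batch_indices", "micro_batch_lengths")
OPTIONAL_KEYS = ("elem_counts_per_gb",)


def attach_or_repack(data: dict[str, Any], extra_info: dict[str, Any],
                     repack: Callable[[dict[str, Any]], dict[str, Any]]) -> dict[str, Any]:
    if all(k in extra_info for k in PACK_KEYS):
        # Driver pre-balanced packing across DP ranks: copy it verbatim so every
        # DP group runs the same number of micro-batches.
        out = dict(data)
        for k in PACK_KEYS + OPTIONAL_KEYS:
            if k in extra_info:
                out[k] = extra_info[k]
        return out
    # No driver metadata: re-derive packing on the local slice.
    return repack(data)
```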

abstractmethod _local_coords() → dict[str, int]

This worker’s (axis -> local-rank) mapping.

Subclasses MUST override: DTensor reads device_mesh, Megatron reads parallel_state. There’s no honest default — a missing impl would silently make every rank a writeback leader and re-create the -601 ILLEGAL_CLIENT duplicate-write bug.

_is_replica_leader() → bool

True iff this rank should perform per-DP-rank-unique side-effects.

Examples include TQ write-back. It uses the same predicate the driver uses to gate dispatch (NamedSharding.is_axis_zero), fed by the per-worker _local_coords instead of NamedSharding.get_worker_coords; both give the same answer.
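A minimal sketch of an axis-zero leader predicate, assuming the replica axes are named "tp", "cp", and "pp" (hypothetical names) and that the leader is the sibling at index 0 on all of them. It indexes coords[a] rather than using a default, so a missing coordinate raises instead of silently making the rank a leader, the failure mode the _local_coords note warns about.

```python
REPLICA_AXES = ("tp", "cp", "pp")  # hypothetical axis names for TP×CP×PP siblings


def is_axis_zero(coords: dict[str, int], axes: tuple[str, ...]) -> bool:
    # True iff this worker sits at index 0 along every listed axis.
    # coords[a] (not coords.get(a, 0)) so a missing axis raises KeyError.
    return all(coords[a] == 0 for a in axes)


def is_replica_leader(local_coords: dict[str, int]) -> bool:
    # One leader per DP rank: the DP coordinate is deliberately ignored.
    return is_axis_zero(local_coords, REPLICA_AXES)
```

On a DP=2 × TP=2 grid this yields exactly two leaders, one per DP rank, so each slice is written back once.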

_write_back(
meta: nemo_rl.data_plane.KVBatchMeta,
fields: dict[str, torch.Tensor],
) → None

Leader-only put_samples(meta.sample_ids, fields=...).

Per-token fields are jagged-packed via maybe_pack_jagged so they land with the same row lengths as the initial put; without this, a rectangular [N, S] worker write-back would mismatch the jagged input_ids on the next read.

Parameters:
  • meta – Per-rank KVBatchMeta for this slice.

  • fields – Map of field name to tensor to write back.
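A minimal sketch of why the write-back is re-jagged: a padded [N, S] result is trimmed back to the row lengths already stored for input_ids. pack_jagged is a hypothetical helper (the real maybe_pack_jagged signature is not documented here), and nested lists stand in for tensors.

```python
def pack_jagged(rect: list[list[float]], row_lengths: list[int]) -> list[list[float]]:
    """Trim a rectangular [N, S] write-back to the per-row lengths already in the store."""
    if len(rect) != len(row_lengths):
        raise ValueError("one length per row is required")
    out = []
    for row, n in zip(rect, row_lengths):
        if n > len(row):
            raise ValueError(f"row shorter than its recorded length {n}")
        # Drop right padding so this row matches its jagged input_ids row.
        out.append(row[:n])
    return out
```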

_write_back_result_field(
meta: nemo_rl.data_plane.KVBatchMeta,
result: Any,
*,
result_key: str,
tq_field: str,
) → None

Single chokepoint for *_presharded write-backs.

result is checked against the Mapping ABC rather than dict, because BatchedDataDict is a UserDict and therefore not a dict subclass.

Parameters:
  • meta – Per-rank KVBatchMeta for this slice.

  • result – Worker output containing result_key.

  • result_key – Key into result for the tensor to write back.

  • tq_field – Field name on the TQ side.
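The Mapping check can be demonstrated with the standard library alone. BatchedDataDictLike is a hypothetical stand-in that, like BatchedDataDict, subclasses collections.UserDict; extract_result_field is likewise a hypothetical helper.

```python
from collections import UserDict
from collections.abc import Mapping
from typing import Any


class BatchedDataDictLike(UserDict):
    """Hypothetical stand-in: like BatchedDataDict, it subclasses UserDict."""


def extract_result_field(result: Any, result_key: str) -> Any:
    # isinstance(result, dict) would reject UserDict subclasses;
    # Mapping accepts both plain dicts and UserDicts.
    if not isinstance(result, Mapping):
        raise TypeError(f"expected a Mapping, got {type(result).__name__}")
    if result_key not in result:
        raise KeyError(result_key)
    return result[result_key]
```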

train_presharded(
meta: nemo_rl.data_plane.KVBatchMeta,
loss_fn: Any,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) → dict[str, Any]

Per-rank training entrypoint. Fetch → packing prep → delegate.
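A minimal sketch of the call order only. TrainFlowSketch is hypothetical: _fetch, _apply_packing_prep, and train are stubs that record their names, and the meta argument is a plain list of rows rather than a KVBatchMeta.

```python
from typing import Any, Optional


class TrainFlowSketch:
    """Hypothetical worker showing the call order only; every step is a stub."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def _fetch(self, meta: Any) -> dict:
        self.calls.append("fetch")
        return {"rows": meta}

    def _apply_packing_prep(self, data: dict) -> dict:
        self.calls.append("packing_prep")
        return data

    def train(self, data: dict, loss_fn: Any, eval_mode: bool = False,
              gbs: Optional[int] = None, mbs: Optional[int] = None) -> dict[str, Any]:
        self.calls.append("train")
        return {"num_rows": len(data["rows"]), "eval_mode": eval_mode}

    def train_presharded(self, meta: Any, loss_fn: Any, eval_mode: bool = False,
                         gbs: Optional[int] = None, mbs: Optional[int] = None) -> dict[str, Any]:
        data = self._apply_packing_prep(self._fetch(meta))
        # Delegate to the unchanged per-rank train(); its metrics return to the driver.
        return self.train(data, loss_fn, eval_mode=eval_mode, gbs=gbs, mbs=mbs)
```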

get_logprobs_presharded(
meta: nemo_rl.data_plane.KVBatchMeta,
micro_batch_size: Optional[int] = None,
) → None

Per-rank logprob entrypoint. Fetch → packing prep → run → write back.

Returns None: the per-token tensor is committed to TQ via _write_back_result_field under prev_logprobs. Callers fetch it through TQPolicy.read_from_dataplane, which skips the Ray plasma round trip for the (B, S) tensor. del result drops the local reference before returning so the worker doesn’t carry the tensor into the next dispatch.
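A minimal sketch of the write-back contract. LogprobFlowSketch is hypothetical: a plain dict keyed by (sample_id, field) stands in for TQ, the fetch and packing-prep steps are collapsed into one stub line, and the "logprobs" are a toy function of the input. It shows that only the leader writes, the method returns None, and the local result is dropped before returning.

```python
from typing import Optional


class LogprobFlowSketch:
    """Hypothetical worker: fetch, prep, run, write back, return None."""

    def __init__(self, store: dict, is_leader: bool) -> None:
        self.store = store  # stands in for TQ
        self.is_leader = is_leader

    def get_logprobs(self, data: dict, micro_batch_size: Optional[int] = None) -> dict:
        # Toy computation standing in for the real forward pass.
        return {"logprobs": [-0.5 * x for x in data["rows"]]}

    def get_logprobs_presharded(self, meta: dict, micro_batch_size: Optional[int] = None) -> None:
        data = {"rows": meta["rows"]}  # stub for _fetch + packing prep
        result = self.get_logprobs(data, micro_batch_size)
        if self.is_leader:  # only one sibling per DP rank writes
            for sid, lp in zip(meta["sample_ids"], result["logprobs"]):
                self.store[(sid, "prev_logprobs")] = lp
        del result  # drop the local reference before the next dispatch
        return None
```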

get_reference_policy_logprobs_presharded(
meta: nemo_rl.data_plane.KVBatchMeta,
micro_batch_size: Optional[int] = None,
) → None

Per-rank reference-policy logprob entrypoint.

See get_logprobs_presharded for the contract. The tensor lives in TQ under reference_policy_logprobs.