nemo_rl.data_plane.worker_mixin
TransferQueue awareness for policy workers, isolated from the base class.
Mix into a worker class to add per-rank TQ-mediated entrypoints (:meth:`train_presharded`, :meth:`get_logprobs_presharded`, :meth:`get_reference_policy_logprobs_presharded`) without touching `BasePolicyWorker`. Subclasses that don’t need TQ keep their bare inheritance and stay zero-cost.
Subclasses must implement :meth:`_get_replica_group` (returns the NCCL group of TP×CP×PP siblings within this DP rank, or `None` for TP=CP=PP=1) and inherit `train` / `get_logprobs` / `get_reference_policy_logprobs` from the worker base.
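As a rough illustration of the mixin pattern above, this toy sketch shows a mixin whose `*_presharded` entrypoint fetches and then delegates to `self.train`, which Python's MRO resolves to the worker base. Every class here (`BaseWorker`, `TQMixinSketch`, `PlainWorker`, `TQWorker`) is a hypothetical stand-in, not the nemo_rl API, and `_fetch` simply reads a dict instead of talking to TQ.

```python
from typing import Any, Optional


class BaseWorker:
    """Hypothetical stand-in for the worker base that owns train()."""

    def train(self, data: list[int]) -> int:
        return sum(data)


class TQMixinSketch:
    """Adds a *_presharded entrypoint that fetches, then delegates."""

    def _get_replica_group(self) -> Optional[Any]:
        raise NotImplementedError  # subclasses supply their sibling group

    def _fetch(self, meta: dict) -> list[int]:
        # Hypothetical: the real mixin pulls this rank's slice from TQ.
        return meta["rows"]

    def train_presharded(self, meta: dict) -> int:
        data = self._fetch(meta)
        # self.train resolves through the MRO to BaseWorker.train.
        return self.train(data)


class PlainWorker(BaseWorker):
    """No TQ: bare inheritance, no extra methods or state."""


class TQWorker(TQMixinSketch, BaseWorker):
    def _get_replica_group(self) -> Optional[Any]:
        return None  # TP=CP=PP=1: no siblings
```

Listing the mixin before the base keeps the base class untouched, while workers that never mix it in carry none of its methods.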
Module Contents
Classes
| `TQWorkerMixin` | Adds TransferQueue per-rank fetch/write-back to a policy worker. |
Functions
| `_broadcast_batched_data_dict` | Broadcast a `BatchedDataDict` from `src` to all ranks in `group`. |
Data
API
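- nemo_rl.data_plane.worker_mixin.FetchPolicy
Fetch strategy accepted by :meth:`TQWorkerMixin._fetch`: `"auto"`, `"independent"`, or `"leader_broadcast"`.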
- nemo_rl.data_plane.worker_mixin._broadcast_batched_data_dict(
    data: Optional[nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]],
    *,
    is_leader: bool,
    src: int,
    group: Any,
)
Broadcast a `BatchedDataDict` from `src` to all ranks in `group`.
The broadcast runs in two phases to avoid pickling tensor payloads on the hot path: a small descriptor (per-key dtype/shape) ships via `broadcast_object_list` first, then each tensor’s data ships via `broadcast` on its current device. The leader supplies `data`; non-leaders pass `None` and receive an empty `BatchedDataDict` that is filled in place.
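A minimal sketch of the two-phase idea, assuming plain `dict[str, torch.Tensor]` payloads instead of `BatchedDataDict` and a caller-supplied `device` for non-leader buffers (both assumptions; with NCCL that device would have to be the rank's CUDA device). `broadcast_tensor_dict` is a hypothetical name, not the function documented here. Only the descriptor is pickled; tensor bytes travel through `dist.broadcast`.

```python
from typing import Any, Optional

import torch
import torch.distributed as dist


def broadcast_tensor_dict(
    data: Optional[dict[str, torch.Tensor]],
    *,
    is_leader: bool,
    src: int,
    group: Any = None,
    device: torch.device = torch.device("cpu"),
) -> dict[str, torch.Tensor]:
    # Phase 1: ship a small picklable descriptor (key -> (dtype, shape)).
    if is_leader:
        assert data is not None, "the leader must supply data"
        desc = {k: (v.dtype, tuple(v.shape)) for k, v in data.items()}
    else:
        desc = None
    box = [desc]
    dist.broadcast_object_list(box, src=src, group=group)
    desc = box[0]

    # Phase 2: broadcast raw tensor storage key by key. Every rank iterates
    # the same unpickled descriptor, so the per-key broadcasts line up.
    out: dict[str, torch.Tensor] = {}
    for key, (dtype, shape) in desc.items():
        if is_leader:
            buf = data[key].contiguous()
        else:
            buf = torch.empty(shape, dtype=dtype, device=device)
        dist.broadcast(buf, src=src, group=group)
        out[key] = buf
    return out
```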
- class nemo_rl.data_plane.worker_mixin.TQWorkerMixin
Adds TransferQueue per-rank fetch/write-back to a policy worker.
The driver-side `TQPolicy` fans out per-rank `KVBatchMeta`; each worker calls `self._fetch(meta, ...)` to pull its slice from TQ and then runs the existing per-rank method body.
- _dp_client: Optional[nemo_rl.data_plane.interfaces.DataPlaneClient]
None
- setup_data_plane(cfg: nemo_rl.data_plane.DataPlaneConfig) -> None
Connect this worker process’s client to the existing TQ controller.
Called once by the driver after worker construction. Idempotent.
- _require_dp_client() -> nemo_rl.data_plane.interfaces.DataPlaneClient
- _get_replica_group() -> Optional[Any]
NCCL group of TP×CP×PP siblings within this DP rank.
`None` means “no siblings” (TP=CP=PP=1). Subclasses must override this using their parallelism state (DTensor `device_mesh`, Megatron `parallel_state`). Returning `None` makes :meth:`_fetch` use independent fetch; returning a group makes it use leader-fetch + NCCL broadcast.
- _pad_value_dict() -> dict[str, Any]
Per-field pad value used by :func:`materialize` to detile the jagged wire format. Token-id fields use the tokenizer pad id.
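A minimal sketch of how per-field pad values could drive detiling, assuming the jagged wire format is a flat list of concatenated row values plus per-row lengths (an assumption; the real `materialize` format may differ). The field names and the `pad_value_dict`/`detile` helpers are hypothetical, and lists stand in for tensors.

```python
from typing import Any


def pad_value_dict(pad_token_id: int) -> dict[str, Any]:
    # Hypothetical field names: token-id fields pad with the tokenizer pad id,
    # other fields with a neutral value.
    return {"input_ids": pad_token_id, "token_mask": 0, "prev_logprobs": 0.0}


def detile(flat: list[Any], lengths: list[int], pad_value: Any) -> list[list[Any]]:
    """Expand a jagged field (flat values + per-row lengths) into padded rows."""
    if sum(lengths) != len(flat):
        raise ValueError("lengths do not cover the flat buffer")
    width = max(lengths, default=0)
    rows: list[list[Any]] = []
    offset = 0
    for n in lengths:
        row = flat[offset:offset + n]
        rows.append(row + [pad_value] * (width - n))  # right-pad to the longest row
        offset += n
    return rows
```

Padding `input_ids` with the tokenizer pad id rather than 0 matters because 0 is usually a real token id.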
- _forward_pad_seqlen(meta: nemo_rl.data_plane.KVBatchMeta) -> int
Cross-DP forward pad target, minted by :meth:`TQPolicy._stamp_pad_seqlen`.
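If every DP rank right-pads to the same sequence length, all ranks run the forward pass on identically shaped batches. The sketch below assumes the target is the longest sequence on any DP rank, rounded up to a multiple; that rule and both helper names are hypothetical, since the logic of `TQPolicy._stamp_pad_seqlen` is not documented here.

```python
def cross_dp_pad_target(per_rank_max_seqlen: list[int], multiple: int = 1) -> int:
    """Assumed rule: longest sequence on any DP rank, rounded up to `multiple`."""
    longest = max(per_rank_max_seqlen)
    return -(-longest // multiple) * multiple  # integer ceil-div


def right_pad(row: list[int], target: int, pad_value: int) -> list[int]:
    """Right-pad one sequence to the shared target length."""
    if len(row) > target:
        raise ValueError(f"row of length {len(row)} exceeds pad target {target}")
    return row + [pad_value] * (target - len(row))
```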
- _fetch(
    meta: nemo_rl.data_plane.KVBatchMeta,
    *,
    layout: nemo_rl.data_plane.schema.Layout = 'padded',
    fetch_policy: nemo_rl.data_plane.worker_mixin.FetchPolicy = 'auto',
    preprocess: Optional[Any] = None,
    dp_aligned_seq_len: bool = True,
)
Fetch this rank’s slice from TQ and return a BatchedDataDict.
- Parameters:
meta – Per-rank `KVBatchMeta` from :func:`shard_meta_for_dp`. The forward-pass pad target is read from `meta.extra_info[GLOBAL_FORWARD_PAD_SEQLEN]`, minted by :meth:`TQPolicy._stamp_pad_seqlen`.
layout – Materialization layout (`"padded"` or `"jagged"`).
fetch_policy – `"auto"` uses leader-fetch + NCCL broadcast when :meth:`_get_replica_group` returns a group, else independent fetch (cheapest for TP=CP=PP=1). `"independent"` forces every sibling to fetch. `"leader_broadcast"` forces the broadcast path and asserts that a replica group exists.
preprocess – Optional `(worker, td) -> td` callable applied between materialize and return.
dp_aligned_seq_len – When True (default), right-pad the seq dim to the forward-pass pad target. Disabled in tests that want to observe per-rank local-pad behavior.
- Returns:
`BatchedDataDict` of this rank’s slice.
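The `fetch_policy` rules above amount to a small decision function. This sketch restates them in plain Python; `resolve_fetch_policy` and `FetchPolicySketch` are hypothetical names that return the path `_fetch` would take, not part of the module.

```python
from typing import Any, Literal, Optional

FetchPolicySketch = Literal["auto", "independent", "leader_broadcast"]


def resolve_fetch_policy(policy: FetchPolicySketch, replica_group: Optional[Any]) -> str:
    """Return the concrete fetch path for a policy and replica group."""
    if policy == "auto":
        # No siblings (TP=CP=PP=1): each rank fetches for itself.
        return "independent" if replica_group is None else "leader_broadcast"
    if policy == "leader_broadcast":
        if replica_group is None:
            # Explicit raise so the check survives `python -O`.
            raise AssertionError("leader_broadcast requires a replica group")
        return policy
    if policy == "independent":
        return policy  # every sibling fetches, even when a group exists
    raise ValueError(f"unknown fetch_policy: {policy!r}")
```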
- _apply_packing_prep( ) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]
Re-derive `micro_batch_indices`/`micro_batch_lengths` on the local slice.
Uses `shard_by_batch_size(shards=1, ...)`. The legacy DP path computes these as a side effect of the DP-shard call; the TQ presharded path receives a per-rank slice without them set, so they are recomputed here using `self.cfg`.
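To make "re-derive micro-batch metadata" concrete, here is a stand-in for what a single-shard packing pass might compute: contiguous samples are grouped greedily under a token budget, and each micro-batch records its `[start, end)` sample range and token total. The greedy rule, the `[start, end)` encoding, and `derive_micro_batches` itself are assumptions; the real `shard_by_batch_size` algorithm and metadata format are not described here.

```python
def derive_micro_batches(
    seq_lens: list[int], max_tokens: int
) -> tuple[list[list[int]], list[int]]:
    """Greedily group contiguous samples so each micro-batch fits max_tokens."""
    indices: list[list[int]] = []  # [start, end) sample range per micro-batch
    lengths: list[int] = []        # total tokens per micro-batch
    start, used = 0, 0
    for i, n in enumerate(seq_lens):
        if n > max_tokens:
            raise ValueError(f"sample {i} has {n} tokens > max_tokens={max_tokens}")
        if used + n > max_tokens:
            # Close the current micro-batch and start a new one at sample i.
            indices.append([start, i])
            lengths.append(used)
            start, used = i, 0
        used += n
    if start < len(seq_lens):
        indices.append([start, len(seq_lens)])
        lengths.append(used)
    return indices, lengths
```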
- _attach_or_repack_pack_metadata(
    data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
    meta: nemo_rl.data_plane.KVBatchMeta,
)
Trust driver-supplied packing metadata or re-derive it locally.
When the driver has pre-balanced packing across DP ranks, it ships `micro_batch_indices`/`micro_batch_lengths` (and optionally `elem_counts_per_gb`) in `meta.extra_info`. Re-packing locally produces variable bin counts across DP groups and desyncs Megatron’s per-microbatch collectives, so the driver’s metadata is trusted whenever it is provided.
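The trust-the-driver rule can be sketched as a small branch. This assumes driver metadata counts as "provided" only when both `micro_batch_indices` and `micro_batch_lengths` are present in `extra_info`, and it uses plain dicts and a hypothetical `repack` callback in place of `BatchedDataDict` and the local packing pass.

```python
from typing import Any, Callable

REQUIRED_KEYS = ("micro_batch_indices", "micro_batch_lengths")
OPTIONAL_KEYS = ("elem_counts_per_gb",)


def attach_or_repack(
    data: dict[str, Any],
    extra_info: dict[str, Any],
    repack: Callable[[dict[str, Any]], dict[str, Any]],
) -> dict[str, Any]:
    if all(key in extra_info for key in REQUIRED_KEYS):
        # Driver balanced packing across DP ranks: copy it verbatim so every
        # DP group runs the same number of micro-batches.
        for key in REQUIRED_KEYS + OPTIONAL_KEYS:
            if key in extra_info:
                data[key] = extra_info[key]
        return data
    # No driver metadata: fall back to local packing.
    return repack(data)
```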
- abstractmethod _local_coords() -> dict[str, int]
This worker’s (axis -> local-rank) mapping.
Subclasses MUST override: DTensor reads `device_mesh`, Megatron reads `parallel_state`. There is no honest default: a missing implementation would silently make every rank a write-back leader and re-create the `-601 ILLEGAL_CLIENT` duplicate-write bug.
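Declaring `_local_coords` abstract, rather than giving it a default, means Python refuses to instantiate a subclass that forgot it. The sketch assumes the class uses `abc.ABCMeta` (here via `abc.ABC`), which is what makes `@abstractmethod` enforced; the class names are hypothetical.

```python
import abc


class LeaderAwareSketch(abc.ABC):
    @abc.abstractmethod
    def _local_coords(self) -> dict[str, int]:
        """Axis name -> this worker's local rank on that axis."""


class MeshWorker(LeaderAwareSketch):
    def __init__(self, coords: dict[str, int]) -> None:
        # Hypothetical: a real worker would read device_mesh / parallel_state.
        self._coords = dict(coords)

    def _local_coords(self) -> dict[str, int]:
        return dict(self._coords)


class ForgotToOverride(LeaderAwareSketch):
    pass  # no _local_coords: instantiation raises TypeError
```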
- _is_replica_leader() -> bool
True iff this rank should perform per-DP-rank-unique side effects, such as TQ write-back.
It uses the same predicate the driver uses to gate dispatch (:meth:`NamedSharding.is_axis_zero`), fed by the per-worker :meth:`_local_coords` instead of `NamedSharding.get_worker_coords`; both give the same answer.
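A sketch of the leader predicate, assuming hypothetical axis names `dp`, `tp`, `cp`, `pp` and that a replica leader is the rank at index 0 on every non-DP axis. The plain function `is_axis_zero` below only imitates `NamedSharding.is_axis_zero`. It indexes `coords[axis]` directly so a missing axis raises `KeyError` instead of silently defaulting to 0, the same failure mode described for `_local_coords`.

```python
REPLICA_AXES = ("tp", "cp", "pp")  # hypothetical: every axis except "dp"


def is_axis_zero(coords: dict[str, int], axes: tuple[str, ...]) -> bool:
    """True iff this worker sits at index 0 on every listed axis."""
    # coords[axis] (not .get) so a missing axis fails loudly.
    return all(coords[axis] == 0 for axis in axes)


def is_replica_leader(coords: dict[str, int]) -> bool:
    return is_axis_zero(coords, REPLICA_AXES)
```

With DP=2 and TP=2, exactly one of the two TP siblings in each DP rank is a leader, so each DP rank writes back once.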
- _write_back(
    meta: nemo_rl.data_plane.KVBatchMeta,
    fields: dict[str, torch.Tensor],
)
Leader-only `put_samples(meta.sample_ids, fields=...)`.
Per-token fields are jagged-packed via :func:`maybe_pack_jagged` so they land with the same row lengths as the initial put; without this, a worker write-back (rectangular `[N, S]`) would mismatch the jagged `input_ids` on the next read.
- Parameters:
meta – Per-rank `KVBatchMeta` for this slice.
fields – Map of field name to tensor to write back.
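A stand-in for the jagged-packing step, assuming each row's true length is known from the original `input_ids` and that packing just drops the right padding. `pack_to_lengths` is hypothetical (not `maybe_pack_jagged`), and lists replace tensors.

```python
def pack_to_lengths(rect: list[list[float]], lengths: list[int]) -> list[list[float]]:
    """Trim each padded row back to its original length (jagged rows)."""
    if len(rect) != len(lengths):
        raise ValueError("one length per row is required")
    packed: list[list[float]] = []
    for row, n in zip(rect, lengths):
        if n > len(row):
            raise ValueError(f"length {n} exceeds row width {len(row)}")
        packed.append(row[:n])  # drop right padding beyond the true length
    return packed
```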
- _write_back_result_field(
    meta: nemo_rl.data_plane.KVBatchMeta,
    result: Any,
    *,
    result_key: str,
    tq_field: str,
)
Single chokepoint for `*_presharded` write-backs. `result` is checked via the `Mapping` ABC because `BatchedDataDict` is a `UserDict` (not a `dict`).
- Parameters:
meta – Per-rank `KVBatchMeta` for this slice.
result – Worker output containing `result_key`.
result_key – Key into `result` for the tensor to write back.
tq_field – Field name on the TQ side.
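Why the `Mapping` check matters: `collections.UserDict` subclasses `MutableMapping`, not `dict`, so an `isinstance(result, dict)` guard would reject a `BatchedDataDict`. This sketch uses a hypothetical `BatchedDataDictLike` stand-in and a hypothetical `extract_result_field` helper.

```python
from collections import UserDict
from collections.abc import Mapping
from typing import Any


class BatchedDataDictLike(UserDict):
    """Hypothetical stand-in: like BatchedDataDict, a UserDict, not a dict."""


def extract_result_field(result: Any, result_key: str) -> Any:
    # Mapping accepts dict, UserDict subclasses, and other mapping types.
    if not isinstance(result, Mapping):
        raise TypeError(f"expected a Mapping result, got {type(result).__name__}")
    if result_key not in result:
        raise KeyError(result_key)
    return result[result_key]
```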
- train_presharded(
    meta: nemo_rl.data_plane.KVBatchMeta,
    loss_fn: Any,
    eval_mode: bool = False,
    gbs: Optional[int] = None,
    mbs: Optional[int] = None,
)
Per-rank training entrypoint. Fetch → packing prep → delegate.
- get_logprobs_presharded(
    meta: nemo_rl.data_plane.KVBatchMeta,
    micro_batch_size: Optional[int] = None,
)
Per-rank logprob entrypoint. Fetch → packing prep → run → write back.
Returns `None`: the per-token tensor is committed to TQ via :meth:`_write_back_result_field` under `prev_logprobs`. Callers fetch it through :meth:`TQPolicy.read_from_dataplane`, which skips the Ray plasma round trip for the (B, S) tensor. `del result` drops the local reference before returning so the worker doesn’t carry the tensor into the next dispatch.
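A toy sketch of the fetch → run → write-back → return `None` flow, with packing prep omitted. `FakeStore` and `LogprobWorkerSketch` are hypothetical; the fake store rejects a second write to the same field to show why only the replica leader writes back.

```python
from typing import Any


class FakeStore:
    """Hypothetical TQ stand-in that rejects a second write to the same field."""

    def __init__(self) -> None:
        self.rows: dict[int, dict[str, Any]] = {}

    def put_samples(self, sample_ids: list[int], fields: dict[str, list[Any]]) -> None:
        for i, sid in enumerate(sample_ids):
            row = self.rows.setdefault(sid, {})
            for name, values in fields.items():
                if name in row:
                    raise RuntimeError(f"duplicate write of {name!r} for sample {sid}")
                row[name] = values[i]


class LogprobWorkerSketch:
    def __init__(self, store: FakeStore, is_leader: bool) -> None:
        self.store = store
        self.is_leader = is_leader

    def _fetch(self, meta: dict[str, Any]) -> dict[str, Any]:
        return meta["data"]  # real code pulls this rank's slice from TQ

    def get_logprobs(self, data: dict[str, Any]) -> dict[str, Any]:
        # Toy "logprobs": negate each token id.
        return {"logprobs": [[-t for t in row] for row in data["input_ids"]]}

    def get_logprobs_presharded(self, meta: dict[str, Any]) -> None:
        data = self._fetch(meta)
        result = self.get_logprobs(data)
        if self.is_leader:  # one write per DP rank, never one per sibling
            self.store.put_samples(meta["sample_ids"], {"prev_logprobs": result["logprobs"]})
        del result  # drop the local reference before returning
        return None  # callers read prev_logprobs from the store instead
```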
- get_reference_policy_logprobs_presharded(
    meta: nemo_rl.data_plane.KVBatchMeta,
    micro_batch_size: Optional[int] = None,
)
Per-rank reference-policy logprob entrypoint.
See :meth:`get_logprobs_presharded` for the contract. The tensor lives in TQ under `reference_policy_logprobs`.