nemo_rl.data_plane.interfaces
Stable boundary between NeMo-RL and data-plane implementations.
Wire shape that adapters must support:

- `fields`: `TensorDict` with tensor leaves and optional `NonTensorStack` / `NonTensorData` leaves (TQ-native non-tensor passthrough). TQ's storage backends handle encoding per backend (`simple` keeps Python objects; `mooncake_client` pickles internally).
- `tags`: `list[dict[str, Any]]` of per-sample primitives, kept separate from `fields` so non-tensor metadata such as `input_lengths` doesn't pollute the leaf-level schema.
- `keys`: per-sample string uids.
- `partition_id`: string-named address spaces with declared `consumer_tasks` and `fields` schemas.
All call sites in `nemo_rl/algorithms`, `nemo_rl/experience` and `nemo_rl/models` go through `DataPlaneClient` and never import `transfer_queue` directly. This is what makes the implementation swappable.
See `nemo_rl/data_plane/README.md` for the full design.
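A minimal sketch of the call-site rule, assuming a hypothetical two-method slice of the boundary: `MiniDataPlane`, `DictDataPlane` and `write_rewards` are illustrative names, not NeMo-RL code, and plain `dict[str, list]` stands in for `TensorDict`. The call site depends only on the abstract class, so any adapter can be swapped in.

```python
from abc import ABC, abstractmethod
from typing import Any


class MiniDataPlane(ABC):
    """Hypothetical two-method slice of the data-plane boundary (illustration only)."""

    @abstractmethod
    def put_samples(self, sample_ids: list[str], partition_id: str,
                    fields: dict[str, list[Any]]) -> None: ...

    @abstractmethod
    def get_samples(self, sample_ids: list[str], partition_id: str,
                    select_fields: list[str]) -> dict[str, list[Any]]: ...


class DictDataPlane(MiniDataPlane):
    """In-memory adapter: one row dict per (partition_id, sample_id)."""

    def __init__(self) -> None:
        self._rows: dict[tuple[str, str], dict[str, Any]] = {}

    def put_samples(self, sample_ids, partition_id, fields):
        for i, uid in enumerate(sample_ids):
            row = self._rows.setdefault((partition_id, uid), {})
            for name, column in fields.items():
                row[name] = column[i]

    def get_samples(self, sample_ids, partition_id, select_fields):
        return {name: [self._rows[(partition_id, uid)][name] for uid in sample_ids]
                for name in select_fields}


def write_rewards(client: MiniDataPlane, uids: list[str], partition_id: str) -> None:
    """A call site: it sees only the abstract boundary, never a backend module."""
    lengths = client.get_samples(uids, partition_id, select_fields=["response_len"])["response_len"]
    # Toy reward rule (hypothetical): short responses score 1.0.
    client.put_samples(uids, partition_id, fields={"reward": [1.0 if n < 10 else 0.0 for n in lengths]})


client = DictDataPlane()
client.put_samples(["u0", "u1"], "rollout", fields={"response_len": [5, 20]})
write_rewards(client, ["u0", "u1"], "rollout")
```

Replacing `DictDataPlane` with another subclass would require no change to `write_rewards`.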
Module Contents
Classes

| Class | Description |
|---|---|
| `DataPlaneConfig` | Feature-gated config; defaults to disabled. |
| `ObservabilityConfig` | Optional middleware that records per-op metrics on the client. |
| `KVBatchMeta` | Per-batch metadata for data-plane KV operations. |
| `DataPlaneClient` | Stable, swappable data-plane boundary. |
API
- class nemo_rl.data_plane.interfaces.DataPlaneConfig
Bases: `typing.TypedDict`

Feature-gated config; defaults to disabled.

`backend` is the storage backend inside TransferQueue; it is owned by the TQ adapter, not by NeMo-RL. `impl` selects which adapter requests go through.

Required keys (always set in the exemplar YAML and never defaulted in code): `enabled`, `impl`, `backend`, `storage_capacity`, `num_storage_units`, `claim_meta_poll_interval_s`, `global_segment_size`, `local_buffer_size`.

`global_segment_size` and `local_buffer_size` are only read when `backend == "mooncake_cpu"`; the `simple` backend ignores them. They are still required (not `NotRequired`) so the YAML carries the full schema and there are no hidden Python defaults.
- enabled: bool
- impl: Literal["transfer_queue"]
- backend: Literal["simple", "mooncake_cpu"]
- storage_capacity: int
- num_storage_units: int
- claim_meta_poll_interval_s: float
- global_segment_size: int
- local_buffer_size: int
- controller_address: NotRequired[str]
- ack_timeout_ms: NotRequired[int]
- observability: NotRequired[ObservabilityConfig]
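A sketch of a config that carries every required key, written as a plain dict. The numeric values are placeholders chosen for illustration, not recommended settings, and `missing_required_keys` is a hypothetical helper that reports gaps instead of filling in defaults, matching the "no hidden Python defaults" rule.

```python
from typing import Any

# Keys the docs list as required; every one must appear in the YAML.
REQUIRED_KEYS = (
    "enabled", "impl", "backend", "storage_capacity", "num_storage_units",
    "claim_meta_poll_interval_s", "global_segment_size", "local_buffer_size",
)

example_cfg: dict[str, Any] = {
    "enabled": True,
    "impl": "transfer_queue",
    "backend": "simple",
    "storage_capacity": 4096,            # placeholder value
    "num_storage_units": 2,              # placeholder value
    "claim_meta_poll_interval_s": 0.05,  # placeholder value
    # Only read when backend == "mooncake_cpu", but still present (placeholders).
    "global_segment_size": 0,
    "local_buffer_size": 0,
    # NotRequired keys (controller_address, ack_timeout_ms, observability) omitted.
}


def missing_required_keys(cfg: dict[str, Any]) -> list[str]:
    """Hypothetical check: list required keys absent from cfg, never defaulting them."""
    return [key for key in REQUIRED_KEYS if key not in cfg]


incomplete = {k: v for k, v in example_cfg.items() if k != "local_buffer_size"}
```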
- class nemo_rl.data_plane.interfaces.ObservabilityConfig
Bases: `typing.TypedDict`

Optional middleware that records per-op metrics on the client.

Off by default. When `enabled=True`, the factory wraps the chosen adapter with `MetricsDataPlaneClient`. `callback` is injected programmatically (callables don't round-trip through YAML): set `cfg["observability"]["callback"] = my_fn` before calling `build_data_plane_client` to plug into wandb, a file, or a log. The default callback prints one line per op for debugging.
- enabled: bool
- callback: NotRequired[Callable[[dict[str, Any]], None]]
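A sketch of the callback-injection pattern, assuming a generic forwarding proxy. `TimingProxy` and `EchoClient` are hypothetical, this is not necessarily how `MetricsDataPlaneClient` is written, and the record keys (`op`, `elapsed_s`) are assumptions rather than the real metrics schema.

```python
import time
from typing import Any, Callable


class TimingProxy:
    """Hypothetical metrics middleware: forwards method calls, emits one record per op."""

    def __init__(self, inner: Any, callback: Callable[[dict[str, Any]], None]) -> None:
        self._inner = inner
        self._callback = callback

    def __getattr__(self, name: str) -> Any:
        # Only reached for names not defined on the proxy, i.e. the wrapped client's API.
        attr = getattr(self._inner, name)
        if not callable(attr):
            return attr

        def timed(*args: Any, **kwargs: Any) -> Any:
            start = time.perf_counter()
            try:
                return attr(*args, **kwargs)
            finally:
                # Record keys are assumptions for this sketch.
                self._callback({"op": name, "elapsed_s": time.perf_counter() - start})

        return timed


records: list[dict[str, Any]] = []
cfg: dict[str, Any] = {"observability": {"enabled": True}}
# Callables don't round-trip through YAML, so the callback is attached in code.
cfg["observability"]["callback"] = records.append


class EchoClient:
    """Stand-in for a real adapter."""

    def close(self) -> str:
        return "closed"


client = TimingProxy(EchoClient(), cfg["observability"]["callback"])
result = client.close()
```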
- class nemo_rl.data_plane.interfaces.KVBatchMeta
Per-batch metadata for data-plane KV operations.

Carries the per-sample IDs (`sample_ids`) that address rows in the KV store, plus per-row metadata (`fields`, `sequence_lengths`, `tags`) needed for downstream routing without fetching tensor data. The vocabulary is intentionally NeMo-RL-native rather than 1:1 with any specific backend; the adapter translates at the boundary.

Two roles:

1. Result type returned by `DataPlaneClient.claim_meta`: callers extract `.sample_ids` / `.partition_id` and pass them to `get_samples` / `get_data`.
2. Argument type for the per-DP-rank fetch entrypoints. `sequence_lengths` lets the driver compute a balanced per-rank shard from metadata only (control plane), without ever materializing tensor data.
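A sketch of balancing from metadata alone, assuming a greedy longest-processing-time heuristic. `balance_by_length` is hypothetical and not necessarily what `shard_meta_for_dp` does; it only reads `sequence_lengths` and returns row indices per rank.

```python
import heapq


def balance_by_length(sequence_lengths: list[int], dp_size: int) -> list[list[int]]:
    """Assign row indices to DP ranks so per-rank token totals stay close."""
    # Min-heap of (token_load, rank): the lightest rank always takes the next row.
    heap = [(0, rank) for rank in range(dp_size)]
    heapq.heapify(heap)
    shards: list[list[int]] = [[] for _ in range(dp_size)]
    # Longest rows first; ties keep their original order (sorted is stable).
    for idx in sorted(range(len(sequence_lengths)), key=lambda i: -sequence_lengths[i]):
        load, rank = heapq.heappop(heap)
        shards[rank].append(idx)
        heapq.heappush(heap, (load + sequence_lengths[idx], rank))
    return shards


lengths = [900, 100, 450, 450, 300, 200]
shards = balance_by_length(lengths, dp_size=2)
loads = [sum(lengths[i] for i in shard) for shard in shards]
```

Each index list could then be turned into a per-rank meta with `subset`, so no tensor data moves until the rank fetches its own rows.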
- partition_id: str
- task_name: str | None
- sample_ids: list[str]
- fields: list[str] | None
- sequence_lengths: list[int] | None
- extra_info: dict[str, Any] = field(…)
- tags: list[dict[str, Any]] | None
- __post_init__() -> None
- property size: int
- stamp_tags(scalars: dict[str, Sequence[Any]]) -> None

Mirror per-row scalar columns onto `tags`.

Each entry in `scalars` is a length-`size` sequence (list, tensor, ndarray) whose elements are written to `tags[i][name]`. Initializes `tags` to a list of empty dicts if it is currently `None`.
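A standalone sketch of the documented `stamp_tags` semantics on a plain list of dicts. Unlike the method, which updates `self.tags` in place and returns `None`, this hypothetical function returns the list; the length check is an assumption of the sketch.

```python
from __future__ import annotations

from typing import Any, Sequence


def stamp_tags(tags: list[dict[str, Any]] | None, size: int,
               scalars: dict[str, Sequence[Any]]) -> list[dict[str, Any]]:
    """Write scalars[name][i] into tags[i][name] for every row i."""
    if tags is None:
        tags = [{} for _ in range(size)]  # one fresh dict per row, never shared
    for name, column in scalars.items():
        if len(column) != size:  # assumption: reject misaligned columns
            raise ValueError(f"{name!r} has {len(column)} values, expected {size}")
        for i in range(size):
            tags[i][name] = column[i]
    return tags


tags = stamp_tags(None, 3, {"reward": [1.0, 0.0, 0.5], "group": [0, 0, 1]})
```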
- _replace(*, sample_ids: list[str], sequence_lengths: list[int] | None, tags: list[dict[str, Any]] | None = None)

Return a copy with new `sample_ids` / `sequence_lengths` / `tags` and otherwise the same metadata.

- subset(indices: Sequence[int])

Return a new meta with only the rows at `indices` (in any order).

- slice(start: int, stop: int)

Return a new meta with the rows in the contiguous range `[start, stop)`.

- concat(...) -> nemo_rl.data_plane.interfaces.KVBatchMeta

Append `others` to `self`. All metas must share `partition_id`.
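A sketch of the row operations on a hypothetical cut-down `MiniMeta` dataclass (not the real `KVBatchMeta`). The point it illustrates is that every per-row column is gathered with the same indices, so `sample_ids`, `sequence_lengths` and `tags` stay aligned. How the real `concat` treats a mix of present and `None` columns is not stated; dropping the column is a choice of this sketch.

```python
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Any, Sequence


@dataclass
class MiniMeta:
    """Hypothetical reduced KVBatchMeta: only the per-row columns the row ops touch."""

    partition_id: str
    sample_ids: list[str]
    sequence_lengths: list[int] | None = None
    tags: list[dict[str, Any]] | None = None

    def subset(self, indices: Sequence[int]) -> MiniMeta:
        def pick(column: list | None) -> list | None:
            return None if column is None else [column[i] for i in indices]

        # Gather every per-row column with the same indices so rows stay aligned.
        return replace(self, sample_ids=pick(self.sample_ids),
                       sequence_lengths=pick(self.sequence_lengths), tags=pick(self.tags))

    def slice(self, start: int, stop: int) -> MiniMeta:
        return self.subset(range(start, stop))

    def concat(self, *others: MiniMeta) -> MiniMeta:
        metas = (self, *others)
        if any(m.partition_id != self.partition_id for m in metas):
            raise ValueError("all metas must share partition_id")

        def join(name: str) -> list | None:
            columns = [getattr(m, name) for m in metas]
            # Sketch choice: an optional column survives only if every input has it.
            if any(c is None for c in columns):
                return None
            return [row for c in columns for row in c]

        return replace(self, sample_ids=join("sample_ids"),
                       sequence_lengths=join("sequence_lengths"), tags=join("tags"))


meta = MiniMeta("step0", ["a", "b", "c", "d"], sequence_lengths=[5, 6, 7, 8])
joined = meta.slice(0, 2).concat(meta.subset([3, 2]))
```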
- class nemo_rl.data_plane.interfaces.DataPlaneClient
Bases: `abc.ABC`

Stable, swappable data-plane boundary.

The methods are split into three groups by intent. Argument order mirrors the underlying `transfer_queue` API 1:1, so a future adapter (e.g. `nv-dataplane`) can also be a thin pass-through.

A. Task-mediated, used by stages that wait for upstream production via the per-task consumer counter: `register_partition`, `claim_meta`, `get_data`, `check_consumption_status`.

B. Direct-by-key, used by stages that already know the exact uids (e.g. driver-side fan-out to DP ranks): `put_samples`, `get_samples`, `clear_samples`.

C. Lifecycle: `close`.
Stage-completion signal: there is intentionally no `mark_consumed`. The authoritative signal in TransferQueue is field production: when a stage calls `put_samples` for a new field, the controller sets `production_status[sample, field] = 1`. Downstream consumers waiting on that field see those samples only once they have been produced.

- abstractmethod register_partition(partition_id: str, fields: list[str], num_samples: int, consumer_tasks: list[str], grpo_group_size: int | None = None, enums: dict[str, list[str]] | None = None)

Declare the partition schema and consumer tasks.
- Parameters:
partition_id – Partition name.
fields – Superset of fields any producer may write here.
num_samples – Expected total samples; sizes controller arrays.
consumer_tasks – Named tasks; each gets its own consumption cursor.
grpo_group_size – Group size for GRPO balanced sampling.
enums – Per-field fixed-vocab string codec, shipped once at register.
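The docs describe `enums` only as a per-field fixed-vocabulary string codec shipped once at registration. A plausible reading, assumed here, is an index codec: after registration, each string value can travel as its position in the declared vocabulary. `EnumCodec` is hypothetical and not the adapter's actual encoding.

```python
class EnumCodec:
    """Hypothetical per-field codec built once from the `enums` mapping."""

    def __init__(self, enums: dict[str, list[str]]) -> None:
        self._vocab = {field: list(vocab) for field, vocab in enums.items()}
        self._index = {field: {s: i for i, s in enumerate(vocab)}
                       for field, vocab in enums.items()}

    def encode(self, field: str, values: list[str]) -> list[int]:
        table = self._index[field]
        return [table[v] for v in values]  # KeyError for strings outside the declared vocab

    def decode(self, field: str, codes: list[int]) -> list[str]:
        vocab = self._vocab[field]
        return [vocab[c] for c in codes]


codec = EnumCodec({"task_type": ["math", "code", "chat"]})
codes = codec.encode("task_type", ["code", "math", "code"])
```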
- abstractmethod claim_meta(partition_id: str, task_name: str, required_fields: list[str], batch_size: int, dp_rank: int | None = None, blocking: bool = True, timeout_s: float = 60.0)

Discover and claim up to `batch_size` ready samples.

Advances `task_name`'s per-sample consumption cursor (TQ's `mode='fetch'`); claimed uids won't be returned again. Samples stay readable via `get_samples` until `clear_samples`.

- Parameters:
partition_id – Partition to claim from.
task_name – Consumer task whose cursor is advanced.
required_fields – Fields that must be produced for a sample to be claimable.
batch_size – Max samples to claim.
dp_rank – Reserved; driver-side balancing via `shard_meta_for_dp` is used today.
blocking – Block until the batch can be claimed.
timeout_s – Max blocking time before raising.
- Returns:
`KVBatchMeta` for the claimed batch; pass it to `get_data`.
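A single-process toy model of what `claim_meta` tracks, assuming non-blocking claims and insertion-order selection. `ToyController` is hypothetical: the production set plays the role of `production_status`, and each task has its own cursor, so a claim by `train` does not affect `eval`.

```python
class ToyController:
    """Hypothetical in-memory model of production status plus per-task claim cursors."""

    def __init__(self, sample_ids: list[str], consumer_tasks: list[str]) -> None:
        self.sample_ids = list(sample_ids)
        self.produced: set[tuple[str, str]] = set()  # (sample_id, field) pairs written so far
        self.claimed: dict[str, set[str]] = {t: set() for t in consumer_tasks}

    def put(self, sample_ids: list[str], fields: list[str]) -> None:
        # Writing a field is the "stage finished" signal for (sample, field).
        for uid in sample_ids:
            for f in fields:
                self.produced.add((uid, f))

    def claim(self, task: str, required_fields: list[str], batch_size: int) -> list[str]:
        ready = [uid for uid in self.sample_ids
                 if uid not in self.claimed[task]
                 and all((uid, f) in self.produced for f in required_fields)]
        batch = ready[:batch_size]
        self.claimed[task].update(batch)  # advance only this task's cursor
        return batch


ctl = ToyController(["s0", "s1", "s2"], consumer_tasks=["train", "eval"])
ctl.put(["s0", "s1"], ["prompt", "response"])
ctl.put(["s2"], ["prompt"])  # s2 has no response yet
first = ctl.claim("train", ["response"], batch_size=8)
again = ctl.claim("train", ["response"], batch_size=8)
other_task = ctl.claim("eval", ["response"], batch_size=1)
ctl.put(["s2"], ["response"])
late = ctl.claim("train", ["response"], batch_size=8)
```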
- abstractmethod get_data(meta: nemo_rl.data_plane.interfaces.KVBatchMeta, select_fields: list[str] | None = None)

Resolve a meta to tensor data.

Field-set resolution: (1) explicit `select_fields`; (2) `meta.fields` if non-`None`; (3) otherwise fail loudly. It never silently fetches all fields.

- Parameters:
meta – From `claim_meta`, or hand-built with explicit keys.
select_fields – Subset of fields to fetch.
- Returns:
`TensorDict` keyed by field name, batched along `meta.sample_ids`.
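The three-step field-set resolution can be sketched as a small pure function. `resolve_fields` is a hypothetical helper, and the exception type is an assumption; the documented requirement is only that the third case fails loudly instead of fetching everything.

```python
from __future__ import annotations


def resolve_fields(select_fields: list[str] | None, meta_fields: list[str] | None) -> list[str]:
    """Precedence: explicit select_fields > meta.fields > error."""
    if select_fields is not None:
        return list(select_fields)
    if meta_fields is not None:
        return list(meta_fields)
    # Never fall back to "every field in the partition".
    raise ValueError("need select_fields or meta.fields; refusing to fetch all fields")
```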
- abstractmethod check_consumption_status(partition_id: str, task_names: list[str])

Return `True` iff every task has consumed all samples in the partition.

Authoritative across workers: it uses TQ's controller-side counter, not the per-process client cache.
- Parameters:
partition_id – Partition to check.
task_names – Tasks whose consumption cursors are inspected.
- Returns:
`True` iff every task in `task_names` has consumed all samples.
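The predicate itself is a subset check per task. In this sketch the controller-side counter is modeled, as an assumption, by a plain `dict` mapping each task to the uids it has consumed; `all_tasks_consumed` is a hypothetical name.

```python
def all_tasks_consumed(partition_sample_ids: list[str],
                       consumed_by_task: dict[str, set[str]],
                       task_names: list[str]) -> bool:
    """True iff every named task has consumed every sample in the partition."""
    expected = set(partition_sample_ids)
    # A task with no entry has consumed nothing.
    return all(expected <= consumed_by_task.get(task, set()) for task in task_names)


ids = ["s0", "s1"]
consumed = {"train": {"s0", "s1"}, "eval": {"s0"}}
```

A driver might poll a check like this before calling `clear_samples` at step end.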
- abstractmethod put_samples(sample_ids: list[str], partition_id: str, fields: tensordict.TensorDict | None = None, tags: list[dict[str, Any]] | None = None)

Write fields for `sample_ids`; this is the producer entrypoint.

Writing a field flips the controller's `production_status` bit for `(sample, field)`; that flip is the "stage finished" signal downstream consumers wait on. Tensor and `NonTensorStack` leaves both pass through to TQ; non-tensor encoding is per-backend.

- Parameters:
sample_ids – Per-sample uids being written.
partition_id – Partition these samples belong to.
fields – Tensor / `NonTensorStack` leaves to write.
tags – Optional per-sample primitive metadata.
- Returns:
`KVBatchMeta` covering `sample_ids`, usable for direct `get_samples`.
- abstractmethod get_samples(sample_ids: list[str], partition_id: str, select_fields: list[str])

Direct fetch by uids.

Used by per-DP-rank slice fetches. Does NOT advance any per-task consumption cursor; that only happens via `claim_meta`. `select_fields` is required (there is no implicit "fetch every field" fallback): bulk schemas are wide, and silent over-fetch is the most expensive shape the wire can take. Callers must name the fields they read.

- Parameters:
sample_ids – Uids to fetch.
partition_id – Partition the samples live in.
select_fields – Subset of fields to fetch.
- Returns:
`TensorDict` keyed by field name, batched along `sample_ids`.
- abstractmethod clear_samples(sample_ids: list[str] | None, partition_id: str) -> None

Drop key-value pairs.

Explicit form (`sample_ids=[...]`): drops exactly those uids. Callers should use this form whenever they have the meta in hand, including sync GRPO callers (the driver passes `meta.sample_ids`) and future async-RL data-loader actors that don't share a process-local registry with the producer.

Convenience form (`sample_ids=None`): drops "everything this process knows it produced in this partition". Adapters implement this via a local registry populated by `put_samples`, with a fallback query to the underlying store. It is useful for step-end teardown when the caller is the producer (the driver in sync GRPO). Workers and loader actors that didn't produce the samples should pass explicit IDs: the `None` form may silently no-op for them, and adapters are expected to warn when that happens.

- Parameters:
sample_ids – Uids to drop; `None` clears every uid this process produced in the partition.
partition_id – Partition the samples live in.
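A sketch of the two `clear_samples` forms, assuming a hypothetical adapter that keeps a process-local registry next to a stand-in store. It omits the documented fallback query to the underlying store, so a process that produced nothing hits the warn-and-no-op path.

```python
from __future__ import annotations

import warnings


class ClearRegistry:
    """Hypothetical adapter-side bookkeeping for the two clear_samples forms."""

    def __init__(self) -> None:
        self.store: dict[str, set[str]] = {}          # stands in for the shared KV store
        self.produced_here: dict[str, set[str]] = {}  # local registry filled by put()

    def put(self, sample_ids: list[str], partition_id: str) -> None:
        self.store.setdefault(partition_id, set()).update(sample_ids)
        self.produced_here.setdefault(partition_id, set()).update(sample_ids)

    def clear(self, sample_ids: list[str] | None, partition_id: str) -> None:
        if sample_ids is None:
            # Convenience form: only what this process itself produced.
            sample_ids = sorted(self.produced_here.pop(partition_id, set()))
            if not sample_ids:
                warnings.warn(f"clear(None) found nothing produced here in {partition_id!r}")
                return
        self.store.get(partition_id, set()).difference_update(sample_ids)
        self.produced_here.get(partition_id, set()).difference_update(sample_ids)


producer = ClearRegistry()
producer.put(["a", "b", "c"], "step0")
producer.clear(["a"], "step0")  # explicit form: exactly these uids
after_explicit = set(producer.store["step0"])
producer.clear(None, "step0")   # convenience form: the rest this process wrote
```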
- abstractmethod close() -> None
Release controller / storage handles. Idempotent.