nemo_rl.data_plane.interfaces

Stable boundary between NeMo-RL and data-plane implementations.

Wire shape that adapters must support:

  • fields: TensorDict with tensor leaves and optional NonTensorStack / NonTensorData leaves (TransferQueue-native non-tensor passthrough). TransferQueue (TQ) storage backends handle encoding per backend (simple keeps Python objects; mooncake_client pickles internally).

  • tags: list[dict[str, Any]] per-sample primitives (kept separate from fields so non-tensor metadata like input_lengths doesn’t pollute the leaf-level schema).

  • keys: per-sample string uids.

  • partition_id: string-named address spaces with declared consumer_tasks and fields schemas.

All call sites in nemo_rl/algorithms, nemo_rl/experience and nemo_rl/models go through DataPlaneClient and never import transfer_queue directly. This is what makes the implementation swappable.

See nemo_rl/data_plane/README.md for the full design.

Module Contents

Classes

DataPlaneConfig

Feature-gated config; defaults to disabled.

ObservabilityConfig

Optional middleware that records per-op metrics on the client.

KVBatchMeta

Per-batch metadata for data-plane KV operations.

DataPlaneClient

Stable, swappable data-plane boundary.

API

class nemo_rl.data_plane.interfaces.DataPlaneConfig

Bases: typing.TypedDict

Feature-gated config; defaults to disabled.

backend is the storage backend inside TransferQueue; it is owned by the TQ adapter, not by NeMo-RL. impl selects which adapter we go through.

Required keys (always set in exemplar YAML — never defaulted in code): enabled, impl, backend, storage_capacity, num_storage_units, claim_meta_poll_interval_s, global_segment_size, local_buffer_size.

global_segment_size / local_buffer_size are only read when backend == "mooncake_cpu"; the simple backend ignores them. They are required (not NotRequired) so the YAML carries the full schema and there are no hidden Python defaults.


enabled: bool

impl: Literal["transfer_queue"]

backend: Literal["simple", "mooncake_cpu"]

storage_capacity: int

num_storage_units: int

claim_meta_poll_interval_s: float

global_segment_size: int

local_buffer_size: int

controller_address: NotRequired[str]

ack_timeout_ms: NotRequired[int]

observability: NotRequired[ObservabilityConfig]
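As a rough illustration of the "every required key is set, no hidden Python defaults" rule, here is a plain-dict config with all required keys plus a hypothetical check_data_plane_config helper (not part of NeMo-RL) that rejects missing keys or unsupported literals. The numeric values are placeholders, not recommended settings; in practice they come from the exemplar YAML.

```python
from typing import Any

# Required keys from the DataPlaneConfig docstring; none has a Python-side default.
REQUIRED_KEYS = (
    "enabled",
    "impl",
    "backend",
    "storage_capacity",
    "num_storage_units",
    "claim_meta_poll_interval_s",
    "global_segment_size",
    "local_buffer_size",
)


def check_data_plane_config(cfg: dict[str, Any]) -> None:
    """Hypothetical validator: raise if a required key is missing or a literal is unsupported."""
    missing = [k for k in REQUIRED_KEYS if k not in cfg]
    if missing:
        raise KeyError(f"data_plane config missing required keys: {missing}")
    if cfg["impl"] != "transfer_queue":
        raise ValueError(f"unsupported impl: {cfg['impl']!r}")
    if cfg["backend"] not in ("simple", "mooncake_cpu"):
        raise ValueError(f"unsupported backend: {cfg['backend']!r}")


# Placeholder values only.
example_cfg: dict[str, Any] = {
    "enabled": True,
    "impl": "transfer_queue",
    "backend": "simple",
    "storage_capacity": 65536,
    "num_storage_units": 2,
    "claim_meta_poll_interval_s": 0.05,
    # Only read when backend == "mooncake_cpu", but still required in the YAML.
    "global_segment_size": 1024,
    "local_buffer_size": 256,
}

check_data_plane_config(example_cfg)
```

Keeping the check strict mirrors the docstring's intent: a YAML that forgets a key fails at startup instead of silently picking up a default.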

class nemo_rl.data_plane.interfaces.ObservabilityConfig

Bases: typing.TypedDict

Optional middleware that records per-op metrics on the client.

Off by default. When enabled=True, the factory wraps the chosen adapter with MetricsDataPlaneClient. callback is injected programmatically, because callables don't round-trip through YAML: set cfg["observability"]["callback"] = my_fn before calling build_data_plane_client to plug into wandb, a file, or a log. The default callback prints one line per op for debugging.


enabled: bool

callback: NotRequired[Callable[[dict[str, Any]], None]]
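The sketch below shows the general shape of per-op metrics middleware and of injecting the callback in code rather than YAML. MetricsWrapper is a hypothetical, generic forwarding wrapper, not MetricsDataPlaneClient, and the record keys ("op", "duration_s", "ok") are assumptions chosen for illustration.

```python
import time
from typing import Any, Callable

MetricsCallback = Callable[[dict[str, Any]], None]


class MetricsWrapper:
    """Hypothetical middleware: forwards calls to `inner`, reports one record per op."""

    def __init__(self, inner: Any, callback: MetricsCallback) -> None:
        self._inner = inner
        self._callback = callback

    def __getattr__(self, name: str) -> Any:
        # Only reached for attributes not defined on the wrapper itself.
        attr = getattr(self._inner, name)
        if not callable(attr):
            return attr

        def timed(*args: Any, **kwargs: Any) -> Any:
            start = time.perf_counter()
            ok = False
            try:
                result = attr(*args, **kwargs)
                ok = True
                return result
            finally:
                # Report even when the op raises, then let the exception propagate.
                self._callback(
                    {"op": name, "duration_s": time.perf_counter() - start, "ok": ok}
                )

        return timed


records: list[dict[str, Any]] = []


def my_fn(record: dict[str, Any]) -> None:
    records.append(record)  # e.g. forward to wandb or a log file instead


cfg: dict[str, Any] = {"observability": {"enabled": True}}
cfg["observability"]["callback"] = my_fn  # set in code; callables can't live in YAML


class DummyClient:
    """Stand-in adapter with a single op."""

    def close(self) -> None:
        pass


client = MetricsWrapper(DummyClient(), cfg["observability"]["callback"])
client.close()
```

Wrapping at the boundary keeps metrics out of every adapter: any implementation behind the same interface gets the same per-op records.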

class nemo_rl.data_plane.interfaces.KVBatchMeta

Per-batch metadata for data-plane KV operations.

Carries the per-sample IDs (sample_ids) that address rows in the KV store plus per-row metadata (fields, sequence_lengths, tags) needed for downstream routing without fetching tensor data. Vocabulary is intentionally NeMo-RL-native rather than 1:1 with any specific backend — the adapter translates at the boundary.

Two roles:

  • Result type returned by DataPlaneClient.claim_meta: callers pass the meta to get_data, or extract .sample_ids / .partition_id and pass them to get_samples.

  • Argument type for the per-DP-rank fetch entrypoints. sequence_lengths lets the driver compute a balanced per-rank shard from metadata only (control plane), without ever materializing tensor data.
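To show why sequence_lengths is enough for balancing, here is a minimal sketch that assigns rows to DP ranks using lengths alone. The greedy longest-first heuristic and the balance_by_length name are illustrative assumptions; this is not necessarily what shard_meta_for_dp does.

```python
import heapq


def balance_by_length(sequence_lengths: list[int], dp_size: int) -> list[list[int]]:
    """Assign row indices to dp_size ranks so per-rank token totals stay close.

    Greedy longest-first: visit rows from longest to shortest and give each
    one to the currently least-loaded rank. Only lengths are read.
    """
    if dp_size <= 0:
        raise ValueError("dp_size must be positive")
    heap = [(0, rank) for rank in range(dp_size)]  # (token_load, rank)
    shards: list[list[int]] = [[] for _ in range(dp_size)]
    order = sorted(range(len(sequence_lengths)), key=lambda i: -sequence_lengths[i])
    for i in order:
        load, rank = heapq.heappop(heap)
        shards[rank].append(i)
        heapq.heappush(heap, (load + sequence_lengths[i], rank))
    # Keep each rank's rows in original order.
    return [sorted(s) for s in shards]


lengths = [900, 100, 500, 500, 300, 700]
shards = balance_by_length(lengths, dp_size=2)  # [[0, 1, 3], [2, 4, 5]], 1500 tokens each
```

Each index list could then be turned into a per-rank meta with subset(indices), so every rank fetches only its own rows.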

partition_id: str

task_name: str | None

sample_ids: list[str]

fields: list[str] | None

sequence_lengths: list[int] | None

extra_info: dict[str, Any] = field(…)

tags: list[dict[str, Any]] | None

__post_init__() -> None

property size: int

stamp_tags(scalars: dict[str, Sequence[Any]]) -> None

Mirror per-row scalar columns onto tags.

Each entry in scalars is a sequence of length size (list, tensor, or ndarray) whose elements are written to tags[i][name]. If tags is currently None, it is first initialized to a list of size empty dicts.
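A free-function sketch of the stamp_tags semantics on plain lists. It is hypothetical: the real method mutates self.tags and returns None, whereas this version returns the list; the length check and storing elements as-is (no conversion of tensor elements) are assumptions.

```python
from __future__ import annotations

from typing import Any, Sequence


def stamp_tags(
    tags: list[dict[str, Any]] | None,
    size: int,
    scalars: dict[str, Sequence[Any]],
) -> list[dict[str, Any]]:
    """Write scalars[name][i] into tags[i][name] for every row i."""
    if tags is None:
        # One empty dict per row.
        tags = [{} for _ in range(size)]
    for name, column in scalars.items():
        if len(column) != size:
            raise ValueError(f"column {name!r} has length {len(column)}, expected {size}")
        for i in range(size):
            tags[i][name] = column[i]
    return tags


tags = stamp_tags(None, 3, {"input_lengths": [12, 7, 30]})
tags = stamp_tags(tags, 3, {"reward": [1.0, 0.0, 0.5]})
# tags[2] == {"input_lengths": 30, "reward": 0.5}
```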

_replace(
*,
sample_ids: list[str],
sequence_lengths: list[int] | None,
tags: list[dict[str, Any]] | None = None,
) -> nemo_rl.data_plane.interfaces.KVBatchMeta

Return a copy with new sample_ids/sequence_lengths/tags, same metadata otherwise.

subset(
indices: Sequence[int],
) -> nemo_rl.data_plane.interfaces.KVBatchMeta

Return a new meta with only the rows at indices (any order).

slice(
start: int,
stop: int,
) -> nemo_rl.data_plane.interfaces.KVBatchMeta

Return a new meta with rows in the contiguous range [start, stop).

concat(
*others: nemo_rl.data_plane.interfaces.KVBatchMeta,
) -> nemo_rl.data_plane.interfaces.KVBatchMeta

Append others to self. All metas must share partition_id.
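For readers who want the row semantics of subset, slice and concat spelled out, here is a cut-down, hypothetical MiniMeta (not the real class). Every row-aligned column is reindexed together, and concat refuses metas from different partitions. How the real class merges extra_info, task_name and fields is not documented here; the sketch keeps self's extra_info and drops any column that some input lacks.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Sequence


@dataclass
class MiniMeta:
    """Hypothetical stand-in for KVBatchMeta (row-aligned columns only)."""

    partition_id: str
    sample_ids: list[str]
    sequence_lengths: list[int] | None = None
    tags: list[dict[str, Any]] | None = None
    extra_info: dict[str, Any] = field(default_factory=dict)

    @property
    def size(self) -> int:
        return len(self.sample_ids)

    def subset(self, indices: Sequence[int]) -> MiniMeta:
        # Reindex every row-aligned column with the same indices, in the given order.
        def pick(col):
            return None if col is None else [col[i] for i in indices]

        return MiniMeta(
            partition_id=self.partition_id,
            sample_ids=pick(self.sample_ids),
            sequence_lengths=pick(self.sequence_lengths),
            tags=pick(self.tags),
            extra_info=dict(self.extra_info),
        )

    def slice(self, start: int, stop: int) -> MiniMeta:
        # Contiguous half-open range [start, stop).
        return self.subset(range(start, stop))

    def concat(self, *others: MiniMeta) -> MiniMeta:
        metas = (self, *others)
        if any(m.partition_id != self.partition_id for m in metas):
            raise ValueError("all metas must share partition_id")

        def join(cols):
            # A column survives only if every input carries it.
            if any(c is None for c in cols):
                return None
            return [x for c in cols for x in c]

        return MiniMeta(
            partition_id=self.partition_id,
            sample_ids=join([m.sample_ids for m in metas]),
            sequence_lengths=join([m.sequence_lengths for m in metas]),
            tags=join([m.tags for m in metas]),
            extra_info=dict(self.extra_info),
        )


m = MiniMeta("train", ["a", "b", "c", "d"], [5, 9, 2, 7])
picked = m.subset([3, 0])                  # rows "d", "a"
rejoined = m.slice(0, 2).concat(m.slice(2, 4))  # back to "a".."d"
```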

class nemo_rl.data_plane.interfaces.DataPlaneClient

Bases: abc.ABC

Stable, swappable data-plane boundary.

The methods are split into three groups by intent. Argument order mirrors the underlying transfer_queue API 1:1 so a future adapter (e.g. nv-dataplane) is a thin pass-through too.

  • A. Task-mediated: used by stages that wait for upstream production via the per-task consumer counter. Methods: register_partition, claim_meta, get_data, check_consumption_status.

  • B. Direct-by-key: used by stages that already know the exact uids (e.g. driver-side fan-out to DP ranks). Methods: put_samples, get_samples, clear_samples.

  • C. Lifecycle: close.

Stage-completion signal: there is intentionally no mark_consumed. The authoritative signal in TransferQueue is field production: when a stage calls put_samples for a new field, the controller flips production_status[sample, field] = 1. Downstream consumers waiting on that field only see those samples once they are produced.
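To make the production-status signal concrete, here is a hypothetical in-memory ToyController (not TransferQueue's controller). It keeps one bit per (sample, field) and reports a sample as ready only once every required field has been written, so writing a field is itself the "stage finished" signal.

```python
class ToyController:
    """Hypothetical model of per-(sample, field) production bits."""

    def __init__(self) -> None:
        # production_status[sample_id][field] == 1 once a producer wrote that field.
        self.production_status: dict[str, dict[str, int]] = {}

    def put_samples(self, sample_ids: list[str], fields: list[str]) -> None:
        # Writing a field flips its bit; there is no separate mark_consumed call.
        for sid in sample_ids:
            row = self.production_status.setdefault(sid, {})
            for f in fields:
                row[f] = 1

    def ready(self, sample_ids: list[str], required_fields: list[str]) -> list[str]:
        # Visible to a consumer only when all of its required fields are produced.
        return [
            sid
            for sid in sample_ids
            if all(self.production_status.get(sid, {}).get(f) == 1 for f in required_fields)
        ]


ctrl = ToyController()
ctrl.put_samples(["s0", "s1"], ["input_ids"])         # rollout stage done
before = ctrl.ready(["s0", "s1"], ["input_ids", "logprobs"])  # []
ctrl.put_samples(["s1"], ["logprobs"])                # logprob stage done for s1 only
after = ctrl.ready(["s0", "s1"], ["input_ids", "logprobs"])   # ["s1"]
```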

abstractmethod register_partition(
partition_id: str,
fields: list[str],
num_samples: int,
consumer_tasks: list[str],
grpo_group_size: int | None = None,
enums: dict[str, list[str]] | None = None,
) -> None

Declare the partition schema and consumer tasks.

Parameters:
  • partition_id – Partition name.

  • fields – Superset of fields any producer may write here.

  • num_samples – Expected total samples; sizes controller arrays.

  • consumer_tasks – Named tasks; each gets its own consumption cursor.

  • grpo_group_size – Group size for GRPO balanced sampling.

  • enums – Per-field fixed-vocab string codec, shipped once at register.

abstractmethod claim_meta(
partition_id: str,
task_name: str,
required_fields: list[str],
batch_size: int,
dp_rank: int | None = None,
blocking: bool = True,
timeout_s: float = 60.0,
) -> nemo_rl.data_plane.interfaces.KVBatchMeta

Discover and claim up to batch_size ready samples.

Advances task_name’s per-sample consumption cursor (TQ’s mode='fetch'); claimed uids won’t be returned again. Samples stay readable via get_samples until clear_samples drops them.

Parameters:
  • partition_id – Partition to claim from.

  • task_name – Consumer task whose cursor is advanced.

  • required_fields – Fields that must be produced for a sample to be claimable.

  • batch_size – Max samples to claim.

  • dp_rank – Reserved; driver-side balancing via shard_meta_for_dp is used today.

  • blocking – Block until the batch can be claimed.

  • timeout_s – Max blocking time before raising.

Returns:

KVBatchMeta for the claimed batch; pass it to get_data.
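The per-task cursor semantics of claim_meta, together with the consumer tasks declared by register_partition, can be modeled with a toy ToyPartition class. It is an illustrative assumption, not the adapter: it is non-blocking (the real call blocks up to timeout_s), lives in one process instead of going through the controller, and returns bare uids instead of a KVBatchMeta.

```python
class ToyPartition:
    """Hypothetical model: one claimed-uid set (cursor) per consumer task."""

    def __init__(self, sample_ids: list[str], consumer_tasks: list[str]) -> None:
        self.sample_ids = list(sample_ids)
        self.produced: dict[str, set[str]] = {sid: set() for sid in sample_ids}
        self.claimed: dict[str, set[str]] = {task: set() for task in consumer_tasks}

    def put(self, sample_ids: list[str], fields: list[str]) -> None:
        for sid in sample_ids:
            self.produced[sid].update(fields)

    def claim(self, task_name: str, required_fields: list[str], batch_size: int) -> list[str]:
        # Return up to batch_size ready, not-yet-claimed uids and advance this task's cursor.
        cursor = self.claimed[task_name]
        batch: list[str] = []
        for sid in self.sample_ids:
            if len(batch) == batch_size:
                break
            if sid not in cursor and set(required_fields) <= self.produced[sid]:
                batch.append(sid)
        cursor.update(batch)  # claiming hides uids from this task; data stays readable
        return batch


p = ToyPartition(["s0", "s1", "s2"], consumer_tasks=["train", "ref_logprobs"])
p.put(["s0", "s1", "s2"], ["input_ids"])
first = p.claim("train", ["input_ids"], batch_size=2)         # ["s0", "s1"]
second = p.claim("train", ["input_ids"], batch_size=2)        # ["s2"]
other = p.claim("ref_logprobs", ["input_ids"], batch_size=5)  # independent cursor
```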

abstractmethod get_data(
meta: nemo_rl.data_plane.interfaces.KVBatchMeta,
select_fields: list[str] | None = None,
) -> tensordict.TensorDict

Resolve a meta to tensor data.

Field-set resolution: (1) explicit select_fields; (2) meta.fields if non-None; (3) fail loudly — never silently fetch all fields.

Parameters:
  • meta – From claim_meta or hand-built with explicit keys.

  • select_fields – Subset of fields to fetch.

Returns:

TensorDict keyed by field name, batched along meta.sample_ids.
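The three-step field-set resolution can be written as a small hypothetical helper (resolve_fields is not part of the interface). Note that an explicit empty select_fields counts as an explicit request in this sketch.

```python
from __future__ import annotations


def resolve_fields(select_fields: list[str] | None, meta_fields: list[str] | None) -> list[str]:
    """Field-set resolution order described for get_data (hypothetical helper)."""
    if select_fields is not None:  # (1) explicit request wins
        return list(select_fields)
    if meta_fields is not None:  # (2) fall back to the fields recorded on the meta
        return list(meta_fields)
    # (3) fail loudly instead of silently fetching every field
    raise ValueError("get_data needs select_fields or meta.fields; refusing to fetch all fields")
```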

abstractmethod check_consumption_status(
partition_id: str,
task_names: list[str],
) -> bool

True iff every task has consumed all samples in the partition.

Authoritative across workers — uses TQ’s controller-side counter, not the per-process client cache.

Parameters:
  • partition_id – Partition to check.

  • task_names – Tasks whose consumption cursors are inspected.

Returns:

True iff every task in task_names has consumed all samples.
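A toy version of the check, assuming per-task claimed-uid sets like those a controller-side counter would track (here a plain dict stands in for the controller). all_consumed is a hypothetical name.

```python
def all_consumed(
    partition_sample_ids: list[str],
    claimed_by_task: dict[str, set[str]],
    task_names: list[str],
) -> bool:
    """True iff every listed task has claimed every sample in the partition."""
    everything = set(partition_sample_ids)
    return all(everything <= claimed_by_task.get(task, set()) for task in task_names)


claimed = {"train": {"s0", "s1"}, "ref_logprobs": {"s0"}}
```

Reading this state from one authoritative place, rather than from each process's own cache, is what lets any worker ask the question and get the same answer.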

abstractmethod put_samples(
sample_ids: list[str],
partition_id: str,
fields: tensordict.TensorDict | None = None,
tags: list[dict[str, Any]] | None = None,
) -> nemo_rl.data_plane.interfaces.KVBatchMeta

Write fields for sample_ids — the producer entrypoint.

Writing a field flips the controller’s production_status bit for (sample, field); that flip is the “stage finished” signal downstream consumers wait on. Tensor and NonTensorStack leaves both pass through to TQ; non-tensor encoding is per-backend.

Parameters:
  • sample_ids – Per-sample uids being written.

  • partition_id – Partition these samples belong to.

  • fields – Tensor / NonTensorStack leaves to write.

  • tags – Optional per-sample primitive metadata.

Returns:

KVBatchMeta covering sample_ids, usable for a direct get_samples call.

abstractmethod get_samples(
sample_ids: list[str],
partition_id: str,
select_fields: list[str],
) -> tensordict.TensorDict

Direct fetch by uids.

Used by per-DP-rank slice fetches. Does NOT advance any per-task consumption cursor; that only happens via claim_meta.

select_fields is required (no implicit “fetch every field” fallback): bulk schemas are wide and silent over-fetch is the most expensive shape the wire can take. Callers must name what they read.

Parameters:
  • sample_ids – Uids to fetch.

  • partition_id – Partition the samples live in.

  • select_fields – Subset of fields to fetch.

Returns:

TensorDict keyed by field name, batched along sample_ids.

abstractmethod clear_samples(sample_ids: list[str] | None, partition_id: str) -> None

Drop key-value pairs.

Explicit form (sample_ids=[...]) drops exactly those uids and is the form callers should use whenever they have the meta in hand — both sync GRPO callers (driver passes meta.sample_ids) and future async-RL data-loader actors that don’t share a process-local registry with the producer.

Convenience form (sample_ids=None) drops “everything this process knows it produced in this partition”. Adapters implement this via a local registry populated by put_samples, with a fallback query to the underlying store. Useful for step-end teardown when the caller is the producer (the driver in sync GRPO). Workers and loader actors that didn't produce the samples should pass explicit IDs: the None form may silently no-op for them, and adapters are expected to warn when that happens.

Parameters:
  • sample_ids – Uids to drop; None clears every uid this process produced in the partition.

  • partition_id – Partition the samples live in.
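The difference between the explicit and None forms can be sketched with a hypothetical ToyStore: two instances share one backing dict (standing in for the shared store) but each keeps its own producer registry. The sketch omits the fallback query to the underlying store mentioned above, so a non-producer's None call simply warns and does nothing.

```python
import warnings


class ToyStore:
    """Hypothetical adapter: shared backing data plus a per-process producer registry."""

    def __init__(self, shared: dict) -> None:
        self.data = shared  # (partition_id, sample_id) -> fields
        self._produced: dict[str, set[str]] = {}  # uids put by *this* instance

    def put_samples(self, sample_ids, partition_id, fields):
        for sid in sample_ids:
            self.data[(partition_id, sid)] = dict(fields)
        self._produced.setdefault(partition_id, set()).update(sample_ids)

    def clear_samples(self, sample_ids, partition_id):
        if sample_ids is None:
            # Convenience form: only what this process produced.
            sample_ids = sorted(self._produced.get(partition_id, set()))
            if not sample_ids:
                warnings.warn(
                    f"clear_samples(None, {partition_id!r}) found nothing produced by "
                    "this process; pass explicit sample_ids",
                    stacklevel=2,
                )
                return
        for sid in sample_ids:
            self.data.pop((partition_id, sid), None)
        self._produced.get(partition_id, set()).difference_update(sample_ids)


shared: dict = {}
producer, worker = ToyStore(shared), ToyStore(shared)
producer.put_samples(["s0", "s1"], "step0", {"x": 1})
```

A worker calling clear_samples(None, "step0") here would only warn, while clear_samples(["s0"], "step0") works from any process.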

abstractmethod close() -> None

Release controller / storage handles. Idempotent.
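One common way to make close idempotent is to swap the handle list out before releasing it, so repeated calls find nothing left to do. ToyClient is hypothetical, and appending to a list stands in for closing real controller or storage connections.

```python
from __future__ import annotations


class ToyClient:
    """Hypothetical client showing an idempotent close()."""

    def __init__(self) -> None:
        self._handles: list[str] | None = ["controller", "storage"]
        self.released: list[str] = []

    def close(self) -> None:
        # Later calls see None and return quietly.
        if self._handles is None:
            return
        handles, self._handles = self._handles, None
        for h in handles:
            self.released.append(h)  # stand-in for closing a real handle


c = ToyClient()
c.close()
c.close()  # safe to call again, e.g. from both a finally block and a shutdown hook
```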