nemo_rl.data_plane.observability#

Lean per-op metrics decorator for DataPlaneClient.

Wraps any DataPlaneClient and invokes a single user-provided callback on each operation. Each event is a flat dict with the following keys::

{"op", "partition_id", "n_keys", "n_bytes", "wall_ms", "status"}

To plug in wandb, file logging, or a debug print, pass on_event=<your function> at the call site.

snapshot() returns cumulative totals plus live memory consumption: bytes_outstanding (the bytes currently held in TQ, i.e. bytes put minus bytes cleared) and peak_bytes_outstanding (the high-water mark over the lifetime of the run).
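A minimal sketch of an on_event callback, assuming only the six-key event shape listed above. The OpTotals class and the hand-built sample events are hypothetical, and the status strings "ok" and "error" are placeholders, since the actual EventStatus values are not shown on this page.

```python
from collections import defaultdict
from typing import Any


class OpTotals:
    """Hypothetical on_event callback that aggregates events per op."""

    def __init__(self) -> None:
        self.calls: dict[str, int] = defaultdict(int)
        self.bytes: dict[str, int] = defaultdict(int)
        self.wall_ms: dict[str, float] = defaultdict(float)

    def __call__(self, event: dict[str, Any]) -> None:
        # Only keys from the documented event shape are read here.
        op = event["op"]
        self.calls[op] += 1
        self.bytes[op] += event["n_bytes"]
        self.wall_ms[op] += event["wall_ms"]


totals = OpTotals()

# Hand-built events for illustration; the status values are placeholders.
sample_events = [
    {"op": "put", "partition_id": "p0", "n_keys": 4, "n_bytes": 4096,
     "wall_ms": 1.5, "status": "ok"},
    {"op": "put", "partition_id": "p0", "n_keys": 2, "n_bytes": 2048,
     "wall_ms": 2.5, "status": "ok"},
    {"op": "get", "partition_id": "p0", "n_keys": 4, "n_bytes": 4096,
     "wall_ms": 0.5, "status": "error"},
]
for ev in sample_events:
    totals(ev)
```

In real use, an instance like this would be passed as on_event=totals when constructing MetricsDataPlaneClient, instead of being fed events by hand.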

Module Contents#

Classes#

DataPlaneEvent

DataPlaneStats

MetricsDataPlaneClient

Wrap a DataPlaneClient with a per-op callback hook.

Functions#

_td_bytes

log_event

Data#

EventStatus

logger

API#

nemo_rl.data_plane.observability.EventStatus#

None

class nemo_rl.data_plane.observability.DataPlaneEvent#

Bases: typing.TypedDict

op: str#

None

partition_id: str#

None

n_keys: int#

None

n_bytes: int#

None

wall_ms: float#

None

status: nemo_rl.data_plane.observability.EventStatus#

None

nemo_rl.data_plane.observability.logger#

‘getLogger(…)’

nemo_rl.data_plane.observability._td_bytes(td: tensordict.TensorDict | None) → int#

nemo_rl.data_plane.observability.log_event(
event: nemo_rl.data_plane.observability.DataPlaneEvent,
) → None#

class nemo_rl.data_plane.observability.DataPlaneStats#

total_bytes: int#

0

total_keys: int#

0

total_ops: int#

0

bytes_outstanding: int#

0

peak_bytes_outstanding: int#

0

max_bytes_per_key_seen: int#

0

last_put_bytes_per_key: int#

0

class nemo_rl.data_plane.observability.MetricsDataPlaneClient(
inner: nemo_rl.data_plane.interfaces.DataPlaneClient,
on_event: Callable[[nemo_rl.data_plane.observability.DataPlaneEvent], None] | None = None,
)#

Bases: nemo_rl.data_plane.interfaces.DataPlaneClient

Wrap a DataPlaneClient with a per-op callback hook.

Initialization

snapshot() → dict[str, Any]#

Return cumulative totals plus the live counts of outstanding bytes and keys.

bytes_outstanding_by_partition() → dict[str, int]#

Per-partition breakdown of currently held bytes.

_record_put(partition_id: str, keys: list[str], n_bytes: int) → None#

Attribute the bytes of a put to individual keys so that a later clear_samples can subtract them.

Called after the underlying RPC succeeds, so a failed put never leaves the accounting inflated.

Parameters:
  • partition_id – Partition the keys were written to.

  • keys – Per-sample uids that were written.

  • n_bytes – Total bytes written; distributed evenly across keys.

_record_clear(partition_id: str, keys: list[str] | None) → None#

Reverse the put accounting for keys.

Called after the underlying RPC succeeds so a failed clear keeps the accounting consistent with TQ’s actual state.

Parameters:
  • partition_id – Partition the keys were dropped from.

  • keys – Uids dropped; None means the whole partition was cleared.
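The put and clear accounting described above can be sketched as a small per-key ledger. This is an illustration under stated assumptions, not the library's implementation: the ByteLedger class is hypothetical, the even split uses integer division (any remainder is dropped), and re-putting an existing key is assumed to replace its earlier attribution.

```python
from __future__ import annotations


class ByteLedger:
    """Hypothetical per-key byte ledger mirroring the documented behavior."""

    def __init__(self) -> None:
        # partition_id -> {key -> bytes attributed to that key}
        self._held: dict[str, dict[str, int]] = {}
        self.bytes_outstanding = 0
        self.peak_bytes_outstanding = 0

    def record_put(self, partition_id: str, keys: list[str], n_bytes: int) -> None:
        if not keys:
            return
        per_key = n_bytes // len(keys)  # assumption: remainder is dropped
        part = self._held.setdefault(partition_id, {})
        for key in keys:
            # Assumption: a repeated put replaces the key's old attribution.
            self.bytes_outstanding -= part.get(key, 0)
            part[key] = per_key
            self.bytes_outstanding += per_key
        self.peak_bytes_outstanding = max(
            self.peak_bytes_outstanding, self.bytes_outstanding
        )

    def record_clear(self, partition_id: str, keys: list[str] | None) -> None:
        part = self._held.get(partition_id, {})
        if keys is None:
            # None means the whole partition was cleared.
            self.bytes_outstanding -= sum(part.values())
            self._held.pop(partition_id, None)
            return
        for key in keys:
            self.bytes_outstanding -= part.pop(key, 0)

    def by_partition(self) -> dict[str, int]:
        return {pid: sum(part.values()) for pid, part in self._held.items()}


ledger = ByteLedger()
ledger.record_put("p0", ["a", "b", "c", "d"], 4000)  # 1000 bytes per key
ledger.record_put("p1", ["x", "y"], 600)             # 300 bytes per key
ledger.record_clear("p0", ["a", "b"])                # releases 2000 bytes
after_partial_clear = (ledger.bytes_outstanding, ledger.by_partition())
ledger.record_clear("p1", None)                      # drops all of p1
```

Calling both record methods only after the corresponding RPC has succeeded, as the docstrings describe, keeps the ledger from drifting away from the store's actual contents when an operation fails.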

_run(
op: str,
partition_id: str,
fn: Callable[[], Any],
*,
n_keys: int = 0,
n_bytes: int = 0,
) → Any#

Run fn and emit one observability event with wall-time and status.

Parameters:
  • op – Operation tag ("put", "get", "clear", etc.).

  • partition_id – Partition the op targets.

  • fn – Zero-arg callable that invokes the inner client.

  • n_keys – Key count if known up front; otherwise inferred from the return value (KVBatchMeta.sample_ids).

  • n_bytes – Byte estimate; overridden by _td_bytes when the return is a TensorDict.

Returns:

Whatever fn returned.
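The run-and-emit pattern can be sketched as follows. The helper name run_with_event, the "ok" and "error" status strings, and the choice to let exceptions propagate to the caller are all assumptions made for illustration. The sketch also omits the documented inference of n_keys and n_bytes from the return value.

```python
import time
from typing import Any, Callable

events: list[dict[str, Any]] = []


def run_with_event(
    op: str,
    partition_id: str,
    fn: Callable[[], Any],
    on_event: Callable[[dict[str, Any]], None],
    *,
    n_keys: int = 0,
    n_bytes: int = 0,
) -> Any:
    """Hypothetical helper: run fn, then emit exactly one event."""
    t0 = time.perf_counter()
    status = "error"  # placeholder value; flipped only if fn succeeds
    try:
        result = fn()
        status = "ok"
        return result
    finally:
        # Runs on both success and failure, so one event is always emitted.
        on_event({
            "op": op,
            "partition_id": partition_id,
            "n_keys": n_keys,
            "n_bytes": n_bytes,
            "wall_ms": (time.perf_counter() - t0) * 1000.0,
            "status": status,
        })


def failing() -> None:
    raise ValueError("simulated RPC failure")


result = run_with_event("get", "p0", lambda: 42, events.append, n_keys=1)
try:
    run_with_event("put", "p0", failing, events.append, n_keys=2, n_bytes=128)
except ValueError:
    pass  # in this sketch the exception still reaches the caller
```

Emitting from a finally block means the wall time covers the full call whether it succeeded or failed, and a failing operation still produces exactly one event.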

_emit(
op: str,
partition_id: str,
n_keys: int,
n_bytes: int,
t0: float,
status: nemo_rl.data_plane.observability.EventStatus,
) → None#

register_partition(
partition_id,
fields,
num_samples,
consumer_tasks,
grpo_group_size=None,
enums=None,
)#

claim_meta(
partition_id,
task_name,
required_fields,
batch_size,
dp_rank=None,
blocking=True,
timeout_s=60.0,
)#

get_data(meta, select_fields=None)#

check_consumption_status(partition_id, task_names)#

put_samples(sample_ids, partition_id, fields=None, tags=None)#

get_samples(sample_ids, partition_id, select_fields)#

clear_samples(sample_ids, partition_id)#

close() → None#