nemo_rl.data_plane.codec

Wire <-> trainer codec — jagged-on-the-wire bridge.

  • Writer side: variable-length fields are encoded as torch.nested.nested_tensor with layout=torch.jagged before put_samples. Padding tax is paid only when a consumer needs a rectangular tensor.

  • Reader side: materialize accepts the wire TensorDict and, when layout='padded', calls torch.nested.to_padded_tensor on any nested leaves, using the per-field padding value supplied in pad_value_dict. Trainer code consumes the padded BatchedDataDict unchanged.

  • Worker write-backs that produce response-shaped outputs use response_from_nested to extract the response slice from a (prompt+response) nested tensor.

  • Non-tensor object fields ride as NonTensorStack / NonTensorData leaves (TQ-native passthrough). materialize decodes them back to np.ndarray(dtype=object) for the trainer.

Module Contents

Functions

  • to_nested_by_length: Strip right-padding off a rectangular tensor using per-row lengths.

  • stack_or_nest: Stack equal-shape rows; reconstruct as jagged nested when ragged.

  • unwrap_wire_stripped_payload: Recover the payload of a possibly wire-stripped NonTensorData.

  • maybe_pack_jagged: Convert val to jagged iff it looks like a per-token field.

  • pack_jagged_fields: Pack a column dict into the wire layout expected by put_samples.

  • pack_per_token_field: Force-jaggedize a known per-token field, tolerating SP padding.

  • response_from_nested: Extract the response slice from a (prompt+response) nested tensor.

  • materialize: Convert a wire TensorDict to a BatchedDataDict.

API

nemo_rl.data_plane.codec.to_nested_by_length(
    padded: torch.Tensor,
    lengths: torch.Tensor,
) → torch.Tensor

Strip right-padding off a rectangular tensor using per-row lengths.

Used by the producer side to convert batched_message_log_to_flat_message output (already padded) into the wire format before put_samples.

Parameters:
  • padded – Rectangular tensor of shape (N, S, ...).

  • lengths – Per-row valid lengths, shape (N,). CUDA tensors are moved to CPU once to avoid per-row syncs.

Returns:

A torch.jagged nested tensor whose i-th row is padded[i, :lengths[i], ...].

nemo_rl.data_plane.codec.stack_or_nest(tensors: list[torch.Tensor]) → torch.Tensor

Stack equal-shape rows; reconstruct as jagged nested when ragged.

Parameters:

tensors – Per-row tensors; assumed to share all dims except an optional ragged seq dim. An empty list returns torch.empty(0).

Returns:

A regular tensor when all rows share shape; otherwise a torch.jagged nested tensor.

nemo_rl.data_plane.codec.unwrap_wire_stripped_payload(item: Any) → Any

Recover the payload of a possibly wire-stripped NonTensorData.

TQ’s MsgpackEncoder._encode_tensordict serializes any TensorDictBase via dict(obj.items()), which captures only the tensor-backed dict. NonTensorData stores its payload in _non_tensordict["data"], so it round-trips through ZMQ as an empty TensorDict({}, batch_size=[]). We map only that exact signature to None; any other TensorDictBase (with tensor fields, a non-scalar batch, or a salvageable _non_tensordict payload) passes through unchanged, so real data is never dropped.
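The "exact signature" rule can be illustrated without TQ or tensordict. The sketch below uses a hypothetical stand-in class, WireTD, that models only the three properties the rule inspects (tensor entries, batch size, and any surviving _non_tensordict payload). It is not the tensordict API; the real function inspects actual TensorDictBase objects.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class WireTD:
    """Hypothetical stand-in for a TensorDictBase decoded off the wire."""

    tensors: dict = field(default_factory=dict)  # tensor-backed entries
    batch_size: tuple = ()  # () models batch_size=[]
    non_tensordict: Optional[dict] = None  # models a surviving _non_tensordict


def unwrap_sketch(item: Any) -> Any:
    """Map only the exact wire-stripped signature to None; pass everything else through."""
    if (
        isinstance(item, WireTD)
        and not item.tensors
        and item.batch_size == ()
        and not item.non_tensordict
    ):
        return None  # empty dict + scalar batch: the payload was lost in transit
    return item  # tensor fields, non-scalar batch, or salvageable payload: keep it


print(unwrap_sketch(WireTD()))  # None
print(unwrap_sketch(WireTD(batch_size=(4,))).batch_size)  # (4,)
```

Keeping the match this narrow is the design point: a looser test (for example "no tensor fields") would also discard objects whose payload is still recoverable.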

nemo_rl.data_plane.codec.maybe_pack_jagged(
    val: torch.Tensor,
    lengths: torch.Tensor,
) → torch.Tensor

Convert val to jagged iff it looks like a per-token field.

Used by every write site (initial put, driver delta-write, worker write-back) so that all per-token fields land in TQ as jagged with the same row lengths. Read-time materialization then pads them all to the same target shape, avoiding the shape-mismatch crashes that mixed wire formats would otherwise cause.

Parameters:
  • val – Tensor to consider. Qualifies for jagged conversion only when val.shape == (N, max(lengths), ...) where N == lengths.shape[0].

  • lengths – Per-row valid lengths, shape (N,).

Returns:

A torch.jagged nested tensor when the shape heuristic matches; otherwise val passed through as a rectangular tensor.

nemo_rl.data_plane.codec.pack_jagged_fields(
    fields: dict[str, torch.Tensor | np.ndarray],
    *,
    lengths: torch.Tensor | None,
) → tensordict.TensorDict

Pack a column dict into the wire layout expected by put_samples.

Zero-copy where possible: per-token tensors that match (N, max(lengths), ...) become torch.jagged views via maybe_pack_jagged; non-conforming tensors pass through rectangular; np.ndarray(dtype=object) is forwarded as-is. This is a layout transform, not serialization: the on-wire bytes are produced later by the TQ backend’s msgpack encoder. Centralizing the transform here makes it the single source of truth for both kv_first_write and write_columns.

Parameters:
  • fields – Column name → tensor or object array. Other value types raise TypeError.

  • lengths – Per-row valid lengths used by maybe_pack_jagged to decide whether a tensor qualifies for jagged conversion. None disables jagged conversion entirely (every tensor passes through rectangular).

Returns:

TensorDict with batch_size=[N] (N from lengths if given, else 0) ready for put_samples.

nemo_rl.data_plane.codec.pack_per_token_field(
    val: torch.Tensor,
    lengths: torch.Tensor,
) → torch.Tensor

Force-jaggedize a known per-token field, tolerating SP padding.

Unlike maybe_pack_jagged (which is shape-strict to avoid false positives on 3D extras such as image features), this function is invoked at write-back sites where the caller already knows the field is per-token (e.g. prev_logprobs, reference_policy_logprobs). With sequence parallelism (SP), Megatron-Core (mcore) rounds the forward output’s seq dim up to a multiple of the TP size, so the value can be one or more tokens wider than max(lengths); to_nested_by_length slices each row to its own length and cleanly drops the trailing SP padding.

Parameters:
  • val – Per-token tensor. Falls back to rectangular when it cannot be jaggedized (wrong batch dim, fewer than 2 dims, or a seq dim shorter than max(lengths)).

  • lengths – Per-row valid lengths, shape (N,).

Returns:

A torch.jagged nested tensor when the shape allows; otherwise val passed through as a rectangular tensor.

nemo_rl.data_plane.codec.response_from_nested(
    full: torch.Tensor,
    response_mask: torch.Tensor,
) → torch.Tensor

Extract the response slice from a (prompt+response) nested tensor.

Used on the worker side for logprob / ref-logprob write-back, where downstream consumers need only the response-token slice. The “left-shift by one token” convention is applied, so logprobs at output position i correspond to the prediction of input token i+1.

Parameters:
  • full – Jagged nested tensor of shape (N, prompt_len + response_len).

  • response_mask – Jagged nested tensor of shape (N, response_len); its offsets().diff() gives the per-row response length.

Returns:

Jagged nested tensor of shape (N, response_len) containing the left-shifted response slice.

nemo_rl.data_plane.codec.materialize(
    td: tensordict.TensorDict,
    layout: nemo_rl.data_plane.schema.Layout = 'padded',
    pad_value_dict: dict[str, int | float] | None = None,
    pad_to_seqlen: int = 0,
) → BatchedDataDict[Any]

Convert a wire TensorDict to a BatchedDataDict.

Trainer/worker code expects rectangular tensors; this function is the bridge from the on-wire nested format.

BatchedDataDict is imported lazily inside the function so that import nemo_rl.data_plane stays cheap for unit tests that never call it (BatchedDataDict transitively pulls in multimodal deps such as decord and torchvision).

Parameters:
  • td – Wire TensorDict to materialize.

  • layout – "padded" (default) pads nested-tensor leaves via torch.nested.to_padded_tensor using pad_value_dict[k] (or 0 if unspecified); rectangular leaves pass through. "jagged" passes nested leaves through unchanged; use it only when the caller knows how to consume them.

  • pad_value_dict – Per-field pad value used when layout='padded'.

  • pad_to_seqlen – When > 0, right-pad the seq dim up to this absolute length after to_padded_tensor. Worker-side _fetch passes its forward-pass target here (rounded up to sequence_length_round for Megatron’s microbatch iterator); driver-side read_columns leaves it 0 and consumes the natural-padded shape. Default 0 disables.

Returns:

BatchedDataDict with rectangular tensors for padded layout, nested tensors for jagged layout, and np.ndarray(dtype=object) for NonTensorStack leaves (TQ-native non-tensor passthrough).