nemo_rl.data_plane.codec
Wire <-> trainer codec: the jagged-on-the-wire bridge.
Writer side: variable-length fields are encoded as torch.nested.nested_tensor with layout=torch.jagged before put_samples. The padding tax is paid only when a consumer needs a rectangular tensor.
Reader side: materialize accepts the wire TensorDict and, when layout='padded', calls torch.nested.to_padded_tensor on any nested leaves, using the per-field padding value supplied in pad_value_dict. Trainer code consumes the padded BatchedDataDict unchanged.
Worker write-backs that produce response-shaped outputs use response_from_nested to extract the response slice from a (prompt+response) nested tensor.
Non-tensor object fields ride as NonTensorStack / NonTensorData leaves (TQ-native passthrough). materialize decodes them back to np.ndarray(dtype=object) for the trainer.
Module Contents
Functions
to_nested_by_length | Strip right-padding off a rectangular tensor using per-row lengths.
stack_or_nest | Stack equal-shape rows; reconstruct as jagged nested when ragged.
unwrap_wire_stripped_payload | Recover the payload of a possibly wire-stripped NonTensorData.
maybe_pack_jagged | Convert val to jagged iff it looks like a per-token field.
pack_jagged_fields | Pack a column dict into the wire layout expected by put_samples.
pack_per_token_field | Force-jaggedize a known per-token field, tolerating SP padding.
response_from_nested | Extract the response slice from a (prompt+response) nested tensor.
materialize | Convert a wire TensorDict to a BatchedDataDict.
API
- nemo_rl.data_plane.codec.to_nested_by_length(padded: torch.Tensor, lengths: torch.Tensor)
Strip right-padding off a rectangular tensor using per-row lengths.
Used by the producer side: convert batched_message_log_to_flat_message output (already padded) into the wire format before put_samples.
- Parameters:
padded – Rectangular tensor of shape (N, S, ...).
lengths – Per-row valid lengths, shape (N,). CUDA tensors are moved to CPU once to avoid per-row syncs.
- Returns:
A torch.jagged nested tensor whose i-th row is padded[i, :lengths[i], ...].
- nemo_rl.data_plane.codec.stack_or_nest(tensors: list[torch.Tensor]) → torch.Tensor
Stack equal-shape rows; reconstruct as jagged nested when ragged.
- Parameters:
tensors – Per-row tensors; assumed to share leading dims modulo an optional ragged seq dim. An empty list returns torch.empty(0).
- Returns:
A regular tensor when all rows share a shape; otherwise a torch.jagged nested tensor.
- nemo_rl.data_plane.codec.unwrap_wire_stripped_payload(item: Any) → Any
Recover the payload of a possibly wire-stripped NonTensorData.
TQ's MsgpackEncoder._encode_tensordict serializes any TensorDictBase via dict(obj.items()), which captures only the tensor backing dict. NonTensorData stores its payload in _non_tensordict["data"], so it round-trips through ZMQ as an empty TensorDict({}, batch_size=[]). We map only that exact signature to None; any other TensorDictBase (with tensor fields, a non-scalar batch, or a salvageable _non_tensordict payload) passes through unchanged, so we never drop real data.
- nemo_rl.data_plane.codec.maybe_pack_jagged(val: torch.Tensor, lengths: torch.Tensor)
Convert val to jagged iff it looks like a per-token field.
Used by every write site (initial put, driver delta-write, worker write-back) so that all per-token fields land in TQ as jagged with the same row lengths. Read-time materialization then pads them all to the same target shape, avoiding shape-mismatch crashes between mixed wire formats.
- Parameters:
val – Tensor to consider. Qualifies for jagged conversion only when val.shape == (N, max(lengths), ...) where N == lengths.shape[0].
lengths – Per-row valid lengths, shape (N,).
- Returns:
A torch.jagged nested tensor when the shape heuristic matches; otherwise val passed through as a rectangular tensor.
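The shape rule stated in the parameter description can be written as a small predicate. The name looks_per_token and the example field names are hypothetical; the check itself only restates the documented condition val.shape == (N, max(lengths), ...).

```python
import torch


def looks_per_token(val: torch.Tensor, lengths: torch.Tensor) -> bool:
    """Hypothetical predicate mirroring the documented shape rule."""
    if val.dim() < 2 or lengths.numel() == 0:
        return False
    n, max_len = lengths.shape[0], int(lengths.max())
    # Strict match on both the batch dim and the seq dim.
    return val.shape[0] == n and val.shape[1] == max_len


lengths = torch.tensor([3, 1, 4])
token_ids = torch.zeros(3, 4, dtype=torch.long)  # (N, max(lengths)): qualifies
image_feats = torch.zeros(3, 16, 8)              # seq dim 16 != 4: stays rectangular
rewards = torch.zeros(3)                         # 1D per-sample field: stays rectangular

print(looks_per_token(token_ids, lengths),
      looks_per_token(image_feats, lengths),
      looks_per_token(rewards, lengths))
```

The strict equality on the seq dim is what keeps unrelated 3D extras from being mistaken for per-token data.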
- nemo_rl.data_plane.codec.pack_jagged_fields(fields: dict[str, torch.Tensor | np.ndarray], *, lengths: torch.Tensor | None)
Pack a column dict into the wire layout expected by put_samples.
Zero-copy where possible: per-token tensors that match (N, max(lengths), ...) become torch.jagged views via maybe_pack_jagged; non-conforming tensors pass through rectangular; np.ndarray(dtype=object) is forwarded as-is. This is a layout transform, not serialization: the on-wire bytes are produced later by the TQ backend's msgpack encoder. Centralizing the transform here makes it the single source of truth for both kv_first_write and write_columns.
- Parameters:
fields – Column name → tensor or object array. Other value types raise TypeError.
lengths – Per-row valid lengths used by maybe_pack_jagged to decide whether a tensor qualifies for jagged conversion. None disables jagged conversion entirely (every tensor passes through rectangular).
- Returns:
TensorDict with batch_size=[N] (N from lengths if given, else 0), ready for put_samples.
- nemo_rl.data_plane.codec.pack_per_token_field(val: torch.Tensor, lengths: torch.Tensor)
Force-jaggedize a known per-token field, tolerating SP padding.
Unlike maybe_pack_jagged (which is shape-strict to avoid false positives on 3D extras like image features), this function is invoked at write-back sites where the caller already knows the field is per-token (e.g. prev_logprobs, reference_policy_logprobs). mcore SP rounds the forward output's seq dim up to a multiple of TP, so the value can be one or more tokens wider than max(lengths); to_nested_by_length slices each row to its own length and drops the trailing SP padding cleanly.
- Parameters:
val – Per-token tensor. Falls back to rectangular when it cannot be jaggedized (wrong batch dim, < 2D, or seq dim shorter than max(lengths)).
lengths – Per-row valid lengths, shape (N,).
- Returns:
A torch.jagged nested tensor when the shape allows; otherwise val passed through as a rectangular tensor.
- nemo_rl.data_plane.codec.response_from_nested(full: torch.Tensor, response_mask: torch.Tensor)
Extract the response slice from a (prompt+response) nested tensor.
Used on the worker side for logprob / ref-logprob write-back, where only the response-token slice is of interest downstream. The "left-shift by one token" convention is applied (so logprobs at output position i correspond to the prediction of input token i+1).
- Parameters:
full – Jagged nested tensor of shape (N, prompt_len + response_len).
response_mask – Jagged nested tensor of shape (N, response_len); its offsets().diff() gives the per-row response length.
- Returns:
Jagged nested tensor of shape (N, response_len) containing the left-shifted response slice.
- nemo_rl.data_plane.codec.materialize(td: tensordict.TensorDict, layout: nemo_rl.data_plane.schema.Layout = 'padded', pad_value_dict: dict[str, int | float] | None = None, pad_to_seqlen: int = 0)
Convert a wire TensorDict to a BatchedDataDict.
Trainer/worker code expects rectangular tensors; this is the bridge from the on-wire nested format.
The lazy BatchedDataDict import keeps import nemo_rl.data_plane cheap for unit tests that don't actually call this function (BatchedDataDict transitively pulls multimodal deps like decord / torchvision).
- Parameters:
td – Wire TensorDict to materialize.
layout – "padded" (default) pads nested-tensor leaves via torch.nested.to_padded_tensor using pad_value_dict[k] (or 0 if unspecified); rectangular leaves pass through. "jagged" passes nested leaves through; use only when the caller knows how to consume them.
pad_value_dict – Per-field pad value used when layout='padded'.
pad_to_seqlen – When > 0, right-pad the seq dim up to this absolute length after to_padded_tensor. Worker-side _fetch passes its forward-pass target here (rounded up to sequence_length_round for Megatron's microbatch iterator); driver-side read_columns leaves it 0 and consumes the natural-padded shape. The default of 0 disables this extra padding.
- Returns:
BatchedDataDict with rectangular tensors for the padded layout, nested tensors for the jagged layout, and np.ndarray(dtype=object) for NonTensorStack leaves (TQ-native non-tensor passthrough).