nemo_rl.models.policy.tq_policy

TQ-mediated Policy: meta-driven 1-hop counterpart to Policy.

Exposes train_from_meta / get_logprobs_from_meta / get_reference_policy_logprobs_from_meta, counterparts to Policy.{train, get_logprobs, get_reference_policy_logprobs} that accept a KVBatchMeta instead of a BatchedDataDict. train_from_meta returns an aggregated training dict like Policy.train; the two logprob methods return None and commit their per-token results to TQ instead. The meta names per-sample TQ keys minted once at rollout (:class:nemo_rl.experience.sync_rollout_actor.SyncRolloutActor); each dispatch slices the key list per DP rank via :func:nemo_rl.data_plane.preshard.shard_meta_for_dp (no re-fan-out, no key minting). Workers fetch their slice from TQ via self._fetch(meta) and write deltas back via self._write_back_result_field(...). See nemo_rl/data_plane/README.md for the full design.
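The per-rank slicing described above can be pictured with a toy model. The sketch below is not shard_meta_for_dp itself: ToyMeta and shard_keys_for_dp are hypothetical names, and it assumes contiguous, equal-sized slices. It illustrates the key property: sharding only partitions the existing keys and copies extra_info, without minting new keys.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyMeta:
    """Hypothetical stand-in for KVBatchMeta: per-sample keys plus shared extra_info."""
    keys: list[str]
    extra_info: dict[str, Any] = field(default_factory=dict)


def shard_keys_for_dp(meta: ToyMeta, dp_size: int) -> list[ToyMeta]:
    """Split meta.keys into dp_size contiguous slices; no new keys are minted."""
    n = len(meta.keys)
    if n % dp_size != 0:
        raise ValueError(f"{n} samples do not divide evenly across {dp_size} DP ranks")
    per_rank = n // dp_size
    return [
        ToyMeta(
            keys=meta.keys[r * per_rank : (r + 1) * per_rank],
            # Each shard gets its own shallow copy of extra_info.
            extra_info=dict(meta.extra_info),
        )
        for r in range(dp_size)
    ]


meta = ToyMeta(keys=[f"uid-{i}" for i in range(8)], extra_info={"step": 3})
shards = shard_keys_for_dp(meta, dp_size=4)
print([s.keys for s in shards])
# [['uid-0', 'uid-1'], ['uid-2', 'uid-3'], ['uid-4', 'uid-5'], ['uid-6', 'uid-7']]
```

Because each shard only references keys that already exist, no sample is copied or renamed on dispatch; workers resolve their slice against TQ.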

Module Contents

Classes

TQPolicy

TQ-mediated counterpart to :class:Policy.

Functions

API

nemo_rl.models.policy.tq_policy._aggregate_train_results(
results: list[dict[str, Any]],
) -> dict[str, Any]
class nemo_rl.models.policy.tq_policy.TQPolicy(
*args: Any,
dp_cfg: dict[str, Any],
tq_partition_id: str = 'train',
**kwargs: Any,
)

Bases: nemo_rl.models.policy.lm_policy.Policy

TQ-mediated counterpart to :class:Policy.

The constructor accepts an additional dp_cfg (the master_config["data_plane"] dict). It bootstraps the controller on the driver and forwards setup_data_plane(dp_cfg) to every worker so that each worker attaches as a client (bootstrap=False).

The partition lifecycle (register_partition / clear_samples) is the trainer's responsibility. This class assumes that the partition named self.tq_partition_id (default "train") is already open with a schema covering DP_TRAIN_FIELDS: the bulk schema written by the rollout actor on its first put, plus the deltas later written by the driver and the workers.

Initialization

shutdown() -> bool

Close the TQ client before shutting down the worker group.

prepare_step(
num_samples: int,
group_size: Optional[int] = None,
) -> None

Register the per-step TQ partition.

Sync trainers call this at the start of each step. The static partition id "train" is cleared and reused across steps. The schema is the union of all consumer fields: producers write only the subset they own, and consumers fetch just the fields they need via select_fields.

Parameters:
  • num_samples – Expected total samples this step.

  • group_size – GRPO group size for balanced sampling; None disables grouping.
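A minimal in-memory sketch of the schema-as-union idea: the partition is registered with every consumer's fields, each producer writes only its own columns under a shared key, and consumers read a projection via select_fields. ToyPartitionStore and its methods are hypothetical and are not the TQ client API; the field names are illustrative.

```python
from typing import Any


class ToyPartitionStore:
    """Hypothetical in-memory stand-in for a TQ partition (not the real TQ API)."""

    def __init__(self) -> None:
        self._schemas: dict[str, frozenset[str]] = {}
        self._rows: dict[str, dict[str, dict[str, Any]]] = {}

    def register_partition(self, pid: str, fields: set[str]) -> None:
        # The schema is the union of every consumer's fields.
        self._schemas[pid] = frozenset(fields)
        self._rows.setdefault(pid, {})

    def put(self, pid: str, key: str, columns: dict[str, Any]) -> None:
        unknown = set(columns) - self._schemas[pid]
        if unknown:
            raise KeyError(f"fields {sorted(unknown)} not in schema of {pid!r}")
        # Producers write only the subset they own; later writes merge in as deltas.
        self._rows[pid].setdefault(key, {}).update(columns)

    def get(self, pid: str, keys: list[str], select_fields: list[str]) -> list[dict[str, Any]]:
        # Consumers project out only the fields they need.
        return [{f: self._rows[pid][k][f] for f in select_fields} for k in keys]

    def clear_samples(self, pid: str, keys: list[str]) -> None:
        for k in keys:
            self._rows[pid].pop(k, None)


store = ToyPartitionStore()
store.register_partition("train", {"input_ids", "prev_logprobs", "advantages"})
store.put("train", "uid-0", {"input_ids": [1, 2, 3]})         # rollout actor
store.put("train", "uid-0", {"prev_logprobs": [-0.1, -0.2]})  # logprob worker delta
store.put("train", "uid-0", {"advantages": 0.5})              # driver delta
print(store.get("train", ["uid-0"], select_fields=["input_ids", "advantages"]))
# [{'input_ids': [1, 2, 3], 'advantages': 0.5}]
store.clear_samples("train", ["uid-0"])  # step end; the id "train" is reused next step
```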

prepare_val_partition(
num_samples: int,
*,
partition_id: str = 'val',
) -> None

Register a per-batch val partition (single consumer, no GRPO grouping).

Sync val trainers call this at the start of each val batch. Distinct from :meth:prepare_step because val has its own partition id and a single consumer task.

discard_samples(sample_ids: list[str], partition_id: str) -> None

Drop the samples named by sample_ids (uids) from partition partition_id in TQ.

Used both for step-end teardown (via :meth:finish_step) and mid-step filtering (e.g. dynamic sampling).
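As an example of mid-step filtering, one common dynamic-sampling rule (assumed here, not taken from this module) drops GRPO groups whose rewards are all identical, since group-normalized advantages for such groups are zero. The hypothetical helper below only computes the uids to drop; the commented call shows how they might be passed to discard_samples.

```python
from collections import defaultdict


def zero_variance_uids(uids: list[str], group_ids: list[int], rewards: list[float]) -> list[str]:
    """Return the uids of every group whose rewards are all identical (hypothetical filter)."""
    by_group: dict[int, list[int]] = defaultdict(list)
    for idx, gid in enumerate(group_ids):
        by_group[gid].append(idx)
    dropped: list[str] = []
    for idxs in by_group.values():
        # A single distinct reward means zero variance within the group.
        if len({rewards[i] for i in idxs}) == 1:
            dropped.extend(uids[i] for i in idxs)
    return dropped


uids = ["a0", "a1", "b0", "b1"]
to_drop = zero_variance_uids(uids, group_ids=[0, 0, 1, 1], rewards=[1.0, 1.0, 0.0, 1.0])
print(to_drop)  # ['a0', 'a1']
# policy.discard_samples(to_drop, partition_id="train")
```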

finish_step(meta: nemo_rl.data_plane.KVBatchMeta) -> None

Drop this step’s bulk from TQ. Mirror of :meth:prepare_step.

_stamp_pad_seqlen(meta: nemo_rl.data_plane.KVBatchMeta) -> None

Mint GLOBAL_FORWARD_PAD_SEQLEN onto meta.extra_info (idempotent).

This is the forward-pass pad target shared across all DP ranks. The shards produced by preshard inherit it because each shard copies dict(meta.extra_info).
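A sketch of an idempotent stamp, under two assumptions not stated on this page: the pad target is the longest sequence in the step rounded up to a multiple of round_to, and the key string below is a placeholder for the real GLOBAL_FORWARD_PAD_SEQLEN constant. Idempotency comes from writing only when the key is absent; shards inherit the value because they copy extra_info.

```python
from typing import Any

GLOBAL_FORWARD_PAD_SEQLEN = "global_forward_pad_seqlen"  # placeholder key value


def stamp_pad_seqlen(extra_info: dict[str, Any], seq_lens: list[int], round_to: int = 64) -> int:
    """Stamp a cross-DP pad target once; later calls return the stored value unchanged."""
    if GLOBAL_FORWARD_PAD_SEQLEN not in extra_info:
        longest = max(seq_lens)
        # Integer ceil-division, then scale back up to a multiple of round_to.
        extra_info[GLOBAL_FORWARD_PAD_SEQLEN] = -(-longest // round_to) * round_to
    return extra_info[GLOBAL_FORWARD_PAD_SEQLEN]


info: dict[str, Any] = {}
print(stamp_pad_seqlen(info, [100, 37, 129]))  # 192
print(stamp_pad_seqlen(info, [5000]))          # 192 (already stamped)
shard_info = dict(info)  # a DP shard's copy carries the same target
```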

read_from_dataplane(
meta: nemo_rl.data_plane.KVBatchMeta,
*,
select_fields: list[str],
pad_value_dict: Optional[dict[str, Any]] = None,
) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]

Fetch + materialize columns from the data plane (TQ).

read_columns pads to meta.extra_info[GLOBAL_FORWARD_PAD_SEQLEN], the same value workers pad to in their forward pass. Driver and workers therefore return columns with an identical sequence dimension, and the driver needs no knowledge of sequence_length_round.
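To see why a shared pad target yields identical shapes on the driver and the workers, here is a plain-Python sketch that right-pads per-token columns to one target_len, with per-column pad values in the spirit of pad_value_dict. The function pad_columns and its default pad value of 0 are assumptions for illustration, not the read_columns implementation.

```python
from typing import Any, Optional


def pad_columns(
    columns: dict[str, list[list[Any]]],
    target_len: int,
    pad_value_dict: Optional[dict[str, Any]] = None,
) -> dict[str, list[list[Any]]]:
    """Right-pad every per-token column to target_len (default pad value 0)."""
    pad_value_dict = pad_value_dict or {}
    out: dict[str, list[list[Any]]] = {}
    for name, rows in columns.items():
        pad = pad_value_dict.get(name, 0)
        for row in rows:
            if len(row) > target_len:
                raise ValueError(f"{name}: row of length {len(row)} exceeds {target_len}")
        out[name] = [row + [pad] * (target_len - len(row)) for row in rows]
    return out


cols = {"input_ids": [[5, 6, 7], [8]], "token_mask": [[1, 1, 1], [1]]}
padded = pad_columns(cols, target_len=4, pad_value_dict={"input_ids": 2})
print(padded["input_ids"])   # [[5, 6, 7, 2], [8, 2, 2, 2]]
print(padded["token_mask"])  # [[1, 1, 1, 0], [1, 0, 0, 0]]
```

Since every process pads to the same stamped value rather than to its local maximum, shapes agree across the driver and all DP ranks.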

write_to_dataplane(
meta: nemo_rl.data_plane.KVBatchMeta,
fields: dict[str, Any],
) -> None

Write driver-computed columns to the data plane (TQ).

_packing_args(
mb_tokens_key: str,
) -> tuple[Optional[dict[str, Any]], Optional[dict[str, Any]]]

Resolve (sequence_packing_args, dynamic_batching_args) for a given stage.

The stage is identified by mb_tokens_key ("logprob_mb_tokens" or "train_mb_tokens").
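A hedged sketch of the stage lookup, assuming a policy config with sequence_packing and dynamic_batching sections, each holding an enabled flag plus logprob_mb_tokens and train_mb_tokens budgets. The returned key max_tokens_per_microbatch is hypothetical, and this sketch treats the two modes as mutually exclusive.

```python
from typing import Any, Optional


def packing_args(
    policy_cfg: dict[str, Any], mb_tokens_key: str
) -> tuple[Optional[dict[str, Any]], Optional[dict[str, Any]]]:
    """Return (sequence_packing_args, dynamic_batching_args); at most one is non-None."""
    if mb_tokens_key not in ("logprob_mb_tokens", "train_mb_tokens"):
        raise ValueError(f"unknown stage key {mb_tokens_key!r}")
    sp = policy_cfg.get("sequence_packing", {})
    db = policy_cfg.get("dynamic_batching", {})
    if sp.get("enabled") and db.get("enabled"):
        raise ValueError("this sketch assumes packing and dynamic batching are exclusive")
    if sp.get("enabled"):
        # The stage key picks the token budget for this stage.
        return {"max_tokens_per_microbatch": sp[mb_tokens_key]}, None
    if db.get("enabled"):
        return None, {"max_tokens_per_microbatch": db[mb_tokens_key]}
    return None, None


cfg = {
    "sequence_packing": {"enabled": True, "train_mb_tokens": 8192, "logprob_mb_tokens": 16384},
    "dynamic_batching": {"enabled": False},
}
print(packing_args(cfg, "logprob_mb_tokens"))  # ({'max_tokens_per_microbatch': 16384}, None)
```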

_logprob_dispatch(
meta: nemo_rl.data_plane.KVBatchMeta,
*,
task_name: str,
worker_method: str,
timer_prefix: str,
timer: Optional[nemo_rl.utils.timer.Timer],
common_kwargs: dict[str, Any],
) -> None

Shared body of get_logprobs_from_meta / get_reference_policy_logprobs_from_meta.

Logprob workers need only LP_SEED_FIELDS, so the meta's field list is narrowed first; this keeps _fetch from pulling rollout-only payload (e.g. multimodal inputs). The same narrowed shape serves both prev_lp and ref_lp. Workers compute the per-token tensor and commit it to TQ via the leader-rank _write_back_result_field; the Ray return value is always None, so this dispatcher only waits for completion.
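The dispatch pattern (narrow the fields, fan out per DP shard, let each shard's leader write results back, and only wait on the calls) can be sketched with threads standing in for Ray workers. kv_store, leader_worker, logprob_dispatch, and the field names are hypothetical; a real deployment would run several model-parallel ranks per shard, with only the leader writing.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any

# Hypothetical contents; the real LP_SEED_FIELDS may differ.
LP_SEED_FIELDS = ["input_ids", "input_lengths"]

# Toy key-value store keyed by per-sample uid; "images" is rollout-only payload.
kv_store: dict[str, dict[str, Any]] = {
    "uid-0": {"input_ids": [1, 2, 3], "input_lengths": 3, "images": b"\x00"},
    "uid-1": {"input_ids": [4, 5], "input_lengths": 2, "images": b"\x00"},
}


def leader_worker(keys: list[str], fields: list[str], out_field: str) -> None:
    # Fetch only the narrowed fields, so "images" is never read.
    rows = {k: {f: kv_store[k][f] for f in fields} for k in keys}
    for k, row in rows.items():
        # Placeholder "logprobs": one value per token of the sample.
        kv_store[k][out_field] = [-0.5] * row["input_lengths"]
    # Results travel through the store, never through the return value.
    return None


def logprob_dispatch(keys: list[str], dp_size: int, out_field: str) -> None:
    per_rank = len(keys) // dp_size  # this sketch assumes keys divide evenly across ranks
    shards = [keys[r * per_rank : (r + 1) * per_rank] for r in range(dp_size)]
    with ThreadPoolExecutor(max_workers=dp_size) as pool:
        futures = [pool.submit(leader_worker, s, LP_SEED_FIELDS, out_field) for s in shards]
        for fut in futures:
            fut.result()  # block until done; re-raises any worker exception


logprob_dispatch(["uid-0", "uid-1"], dp_size=2, out_field="prev_logprobs")
logprob_dispatch(["uid-0", "uid-1"], dp_size=2, out_field="reference_policy_logprobs")
print(kv_store["uid-1"]["prev_logprobs"])  # [-0.5, -0.5]
```

The same narrowed field list serves both calls, mirroring how one dispatcher backs both the policy and reference logprob passes.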

get_logprobs_from_meta(
meta: nemo_rl.data_plane.KVBatchMeta,
micro_batch_size: Optional[int] = None,
timer: Optional[nemo_rl.utils.timer.Timer] = None,
) -> None
get_reference_policy_logprobs_from_meta(
meta: nemo_rl.data_plane.KVBatchMeta,
micro_batch_size: Optional[int] = None,
timer: Optional[nemo_rl.utils.timer.Timer] = None,
) -> None
train_from_meta(
meta: nemo_rl.data_plane.KVBatchMeta,
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
timer: Optional[nemo_rl.utils.timer.Timer] = None,
) -> dict[str, Any]

1-hop counterpart to :meth:train.

meta names the per-sample keys. By this point, the rollout actor's columns, the workers' logprob deltas, and the driver-side advantage delta have all landed under those same keys. Workers fetch the union via train_presharded, which calls self._fetch(meta). There is no partition drain here: the sync 1-hop trainer calls clear_samples once at the end of the step.

Parameters:
  • meta – Full-step KVBatchMeta (consumed by all DP ranks).

  • gbs – Global batch size; defaults to cfg["train_global_batch_size"].

  • mbs – Micro batch size; defaults to cfg["train_micro_batch_size"].

  • timer – Optional timer for nested policy_training/* measurements.

Returns:

Aggregated training-step output dict.
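A sketch of the gbs/mbs defaulting, using is None checks so that only omitted arguments fall back to cfg (an or-based fallback would also swallow an explicit 0). The helper resolve_batch_sizes is hypothetical, and the per-rank microbatch count it returns assumes fixed-size microbatching; with sequence packing or dynamic batching the count would instead depend on token budgets.

```python
from typing import Any, Optional


def resolve_batch_sizes(
    cfg: dict[str, Any], dp_size: int, gbs: Optional[int] = None, mbs: Optional[int] = None
) -> tuple[int, int, int]:
    """Fill gbs/mbs from cfg and return (gbs, mbs, microbatches_per_rank)."""
    gbs = cfg["train_global_batch_size"] if gbs is None else gbs
    mbs = cfg["train_micro_batch_size"] if mbs is None else mbs
    if gbs % dp_size != 0:
        raise ValueError(f"gbs={gbs} is not divisible by dp_size={dp_size}")
    local = gbs // dp_size  # samples each DP rank trains on per step
    if local % mbs != 0:
        raise ValueError(f"per-rank batch {local} is not divisible by mbs={mbs}")
    return gbs, mbs, local // mbs


cfg = {"train_global_batch_size": 512, "train_micro_batch_size": 4}
print(resolve_batch_sizes(cfg, dp_size=8))         # (512, 4, 16)
print(resolve_batch_sizes(cfg, dp_size=8, mbs=8))  # (512, 8, 8)
```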