nemo_rl.models.policy.tq_policy
TQ-mediated Policy: a meta-driven, 1-hop counterpart to `Policy`.
Exposes `train_from_meta`, `get_logprobs_from_meta`, and `get_reference_policy_logprobs_from_meta`. These return the same shapes as `Policy.train`, `Policy.get_logprobs`, and `Policy.get_reference_policy_logprobs`, but accept a `KVBatchMeta` instead of a `BatchedDataDict`.
The meta names per-sample TQ keys that are minted once at rollout (`nemo_rl.experience.sync_rollout_actor.SyncRolloutActor`). Each dispatch slices the key list per DP rank via `nemo_rl.data_plane.preshard.shard_meta_for_dp` (no re-fan-out, no key minting).
Workers fetch their slice from TQ via `self._fetch(meta)` and write deltas back via `self._write_back_result_field(...)`. See `nemo_rl/data_plane/README.md` for the full design.
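The per-rank slicing can be pictured with a minimal sketch. It makes two assumptions: the split is contiguous, and the key count must divide evenly by the DP size. It also uses a hypothetical `ToyKVBatchMeta` dataclass in place of the real `KVBatchMeta`. The actual `shard_meta_for_dp` may order or balance keys differently.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyKVBatchMeta:
    """Hypothetical stand-in for KVBatchMeta: sample keys, field names, extra info."""

    keys: list[str]
    fields: list[str]
    extra_info: dict[str, Any] = field(default_factory=dict)


def toy_shard_meta_for_dp(meta: ToyKVBatchMeta, dp_size: int) -> list[ToyKVBatchMeta]:
    """Slice meta.keys into dp_size contiguous shards without minting new keys."""
    if len(meta.keys) % dp_size != 0:
        raise ValueError(f"{len(meta.keys)} keys do not split evenly over dp_size={dp_size}")
    per_rank = len(meta.keys) // dp_size
    return [
        ToyKVBatchMeta(
            keys=meta.keys[rank * per_rank : (rank + 1) * per_rank],
            fields=list(meta.fields),
            # Shallow copy, so every shard carries the parent's extra_info entries.
            extra_info=dict(meta.extra_info),
        )
        for rank in range(dp_size)
    ]


meta = ToyKVBatchMeta(keys=[f"uid{i}" for i in range(8)], fields=["input_ids"])
shards = toy_shard_meta_for_dp(meta, dp_size=4)
print([s.keys for s in shards])
# [['uid0', 'uid1'], ['uid2', 'uid3'], ['uid4', 'uid5'], ['uid6', 'uid7']]
```

Every shard references keys that already exist in the parent meta. This mirrors the "no key minting" property described above.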
Module Contents
Classes
`TQPolicy`: TQ-mediated counterpart to `Policy`.
Functions
`_aggregate_train_results`
API
- `nemo_rl.models.policy.tq_policy._aggregate_train_results(results: list[dict[str, Any]])`
- class `nemo_rl.models.policy.tq_policy.TQPolicy(*args: Any, dp_cfg: dict[str, Any], tq_partition_id: str = 'train', **kwargs: Any)`
Bases: `nemo_rl.models.policy.lm_policy.Policy`
TQ-mediated counterpart to `Policy`.
The constructor accepts an additional `dp_cfg` argument (the `master_config["data_plane"]` dict). It bootstraps the controller on the driver and forwards `setup_data_plane(dp_cfg)` to every worker so they can attach as clients (`bootstrap=False`).
The partition lifecycle (`register_partition` / `clear_samples`) is the trainer’s responsibility. This class assumes that the partition named `self.tq_partition_id` (default `"train"`) is open with a schema covering `DP_TRAIN_FIELDS`. That schema is the bulk schema written by the rollout actor at its first put, plus the deltas written by the driver and the workers.
Initialization
- `shutdown() -> bool`
Close the TQ client before shutting down the worker group.
- `prepare_step(num_samples: int, group_size: Optional[int] = None)`
Register the per-step TQ partition.
Sync trainers call this at the start of each step. The static partition id `"train"` is cleared and reused across steps. The schema is the union of all consumer fields: producers write only the subset they have, and consumers fetch via `select_fields`.
- Parameters:
num_samples – Expected total samples this step.
group_size – GRPO group size for balanced sampling; `None` disables grouping.
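The partition contract described above can be illustrated with a minimal in-memory sketch:

- The schema is the union of consumer fields.
- Producers write subsets of it.
- Consumers select only the fields they need.
- Clearing keeps the static id and schema for reuse.

`ToyPartition` and the field names are hypothetical illustrations, not the TQ client API.

```python
from typing import Any


class ToyPartition:
    """Hypothetical in-memory stand-in for a TQ partition (not the TQ API)."""

    def __init__(self, partition_id: str, schema: set[str]) -> None:
        self.partition_id = partition_id
        self.schema = schema  # union of every consumer's fields
        self.rows: dict[str, dict[str, Any]] = {}

    def put(self, uid: str, columns: dict[str, Any]) -> None:
        unknown = set(columns) - self.schema
        if unknown:
            raise KeyError(f"fields not in schema: {sorted(unknown)}")
        # Producers write only the subset of fields they have; later writes add deltas.
        self.rows.setdefault(uid, {}).update(columns)

    def get(self, uids: list[str], select_fields: list[str]) -> list[dict[str, Any]]:
        # Consumers fetch only the columns they need (KeyError if one was never written).
        return [{f: self.rows[u][f] for f in select_fields} for u in uids]

    def clear(self) -> None:
        # Drops the rows but keeps the id and schema so the next step reuses them.
        self.rows.clear()


part = ToyPartition("train", {"input_ids", "rewards", "prev_logprobs", "advantages"})
part.put("u0", {"input_ids": [1, 2], "rewards": 1.0})  # rollout-side write
part.put("u0", {"prev_logprobs": [-0.1, -0.2]})  # worker-side delta
print(part.get(["u0"], select_fields=["input_ids", "prev_logprobs"]))
# [{'input_ids': [1, 2], 'prev_logprobs': [-0.1, -0.2]}]
part.clear()  # end of step: same id and schema, no rows
```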
- `prepare_val_partition(num_samples: int, *, partition_id: str = 'val')`
Register a per-batch val partition (single consumer, no GRPO grouping).
Sync val trainers call this at the start of each val batch. It is distinct from `prepare_step` because val has its own partition id and a single consumer task.
- `discard_samples(sample_ids: list[str], partition_id: str) -> None`
Drop a set of uids from TQ.
Used both for step-end teardown (via `finish_step`) and for mid-step filtering (e.g. dynamic sampling).
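For the mid-step filtering case, here is one way a trainer might pick the uids it hands to `discard_samples`. The sketch assumes dynamic sampling that drops GRPO groups whose rewards are all identical. It also assumes each group's members sit contiguously in the uid list. The helper `uids_to_discard` is hypothetical; only `discard_samples` comes from this class.

```python
def uids_to_discard(uids: list[str], rewards: list[float], group_size: int) -> list[str]:
    """Return the uids of every group whose rewards are all identical."""
    if len(uids) != len(rewards) or len(uids) % group_size != 0:
        raise ValueError("uids and rewards must align and split evenly into groups")
    dropped: list[str] = []
    for start in range(0, len(uids), group_size):
        group = rewards[start : start + group_size]
        # A zero-variance group yields all-zero mean-centred advantages, so no signal.
        if max(group) == min(group):
            dropped.extend(uids[start : start + group_size])
    return dropped


uids = ["a0", "a1", "b0", "b1"]
dropped_uids = uids_to_discard(uids, [1.0, 1.0, 0.0, 1.0], group_size=2)
print(dropped_uids)  # ['a0', 'a1']
# In a trainer, this list would then be passed on, e.g.:
# policy.discard_samples(dropped_uids, partition_id="train")
```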
- `finish_step(meta: nemo_rl.data_plane.KVBatchMeta) -> None`
Drop this step’s bulk data from TQ. Mirror of `prepare_step`.
- `_stamp_pad_seqlen(meta: nemo_rl.data_plane.KVBatchMeta) -> None`
Mint `GLOBAL_FORWARD_PAD_SEQLEN` onto `meta.extra_info` (idempotent).
This value is the cross-DP forward pad target. Preshard shards inherit it through `dict(meta.extra_info)` propagation.
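A minimal sketch of an idempotent pad stamp follows. It assumes the target is the longest sequence length in the batch, rounded up to a multiple of a rounding factor; this rounding rule is an assumption, not confirmed behaviour. `PAD_KEY` is a hypothetical stand-in for the real `GLOBAL_FORWARD_PAD_SEQLEN` constant.

```python
from typing import Any

PAD_KEY = "global_forward_pad_seqlen"  # hypothetical stand-in for GLOBAL_FORWARD_PAD_SEQLEN


def stamp_pad_seqlen(extra_info: dict[str, Any], seq_lens: list[int], round_to: int) -> int:
    """Write the cross-DP pad target once; repeated calls keep the first value."""
    if PAD_KEY not in extra_info:
        longest = max(seq_lens)
        # Round up to the next multiple of round_to (integer ceiling division).
        extra_info[PAD_KEY] = -(-longest // round_to) * round_to
    return extra_info[PAD_KEY]


info: dict[str, Any] = {}
print(stamp_pad_seqlen(info, [5, 17, 9], round_to=8))  # 24
print(stamp_pad_seqlen(info, [100], round_to=8))  # still 24: already stamped
shard_info = dict(info)  # a shard's copied extra_info carries the same target
print(shard_info[PAD_KEY])  # 24
```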
- `read_from_dataplane(meta: nemo_rl.data_plane.KVBatchMeta, *, select_fields: list[str], pad_value_dict: Optional[dict[str, Any]] = None)`
Fetch and materialize columns from the data plane (TQ).
`read_columns` pads to `meta.extra_info[GLOBAL_FORWARD_PAD_SEQLEN]`, the same value the workers pad to in their forward pass. The driver and the workers therefore return columns with one identical sequence dimension, and the driver needs no knowledge of `sequence_length_round`.
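The padding step can be sketched with plain Python lists. The sketch assumes per-token columns are lists of rows that get right-padded to the stamped target. It also assumes `pad_value_dict` maps a column name to its fill value, with 0 as the default (another assumption). `pad_columns` is a hypothetical helper, not `read_columns` itself.

```python
from typing import Any, Optional


def pad_columns(
    columns: dict[str, list[list[Any]]],
    pad_to: int,
    pad_value_dict: Optional[dict[str, Any]] = None,
) -> dict[str, list[list[Any]]]:
    """Right-pad every row of every per-token column to exactly pad_to entries."""
    pad_value_dict = pad_value_dict or {}
    padded: dict[str, list[list[Any]]] = {}
    for name, rows in columns.items():
        fill = pad_value_dict.get(name, 0)  # assumed default fill value
        out_rows = []
        for row in rows:
            if len(row) > pad_to:
                raise ValueError(f"{name}: length {len(row)} exceeds pad target {pad_to}")
            out_rows.append(list(row) + [fill] * (pad_to - len(row)))
        padded[name] = out_rows
    return padded


cols = {"input_ids": [[5, 6, 7], [8]], "token_mask": [[1, 1, 1], [1]]}
print(pad_columns(cols, pad_to=4, pad_value_dict={"input_ids": 0}))
# {'input_ids': [[5, 6, 7, 0], [8, 0, 0, 0]], 'token_mask': [[1, 1, 1, 0], [1, 0, 0, 0]]}
```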
- `write_to_dataplane(meta: nemo_rl.data_plane.KVBatchMeta, fields: dict[str, Any])`
Write driver-computed columns to the data plane (TQ).
- `_packing_args(mb_tokens_key: str)`
Resolve `(sequence_packing_args, dynamic_batching_args)` for a given stage.
The stage is identified by `mb_tokens_key` (`"logprob_mb_tokens"` or `"train_mb_tokens"`).
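Per-stage resolution can be sketched under a few assumptions. The policy config is assumed to have `sequence_packing` and `dynamic_batching` sub-dicts. Each sub-dict holds an `enabled` flag and per-stage token budgets keyed by `mb_tokens_key`. Packing is assumed to take precedence when both are enabled. The config layout and the returned dict keys (`max_tokens_per_microbatch`, etc.) are hypothetical.

```python
from typing import Any, Optional

STAGE_KEYS = ("logprob_mb_tokens", "train_mb_tokens")


def packing_args(
    policy_cfg: dict[str, Any], mb_tokens_key: str
) -> tuple[Optional[dict[str, Any]], Optional[dict[str, Any]]]:
    """Return (sequence_packing_args, dynamic_batching_args); at most one is not None."""
    if mb_tokens_key not in STAGE_KEYS:
        raise ValueError(f"unknown stage key {mb_tokens_key!r}")
    sp_cfg = policy_cfg.get("sequence_packing", {})
    db_cfg = policy_cfg.get("dynamic_batching", {})
    if sp_cfg.get("enabled", False):
        return {"max_tokens_per_microbatch": sp_cfg[mb_tokens_key]}, None
    if db_cfg.get("enabled", False):
        return None, {
            "max_tokens_per_microbatch": db_cfg[mb_tokens_key],
            "sequence_length_round": db_cfg["sequence_length_round"],
        }
    return None, None  # plain fixed-size micro-batching


cfg = {
    "sequence_packing": {"enabled": False},
    "dynamic_batching": {
        "enabled": True,
        "train_mb_tokens": 4096,
        "logprob_mb_tokens": 8192,
        "sequence_length_round": 64,
    },
}
print(packing_args(cfg, "logprob_mb_tokens"))
# (None, {'max_tokens_per_microbatch': 8192, 'sequence_length_round': 64})
```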
- `_logprob_dispatch(meta: nemo_rl.data_plane.KVBatchMeta, *, task_name: str, worker_method: str, timer_prefix: str, timer: Optional[nemo_rl.utils.timer.Timer], common_kwargs: dict[str, Any])`
Shared body of `get_logprobs_from_meta` and `get_reference_policy_logprobs_from_meta`.
Logprob workers need only `LP_SEED_FIELDS`. The meta’s field list is therefore narrowed so that `_fetch` does not pull rollout-only payload (e.g. multimodal data). The same shape is used for both `prev_lp` and `ref_lp`.
Workers compute the per-token tensor and commit it to TQ via the leader-rank `_write_back_result_field`. The Ray return value is always `None`, so this dispatcher only waits for completion.
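The narrow-then-wait pattern can be sketched as follows. `ThreadPoolExecutor` stands in for Ray remote calls. `TOY_LP_SEED_FIELDS` is a hypothetical field set, not the real `LP_SEED_FIELDS`. Each worker callable is assumed to commit its own output elsewhere and return `None`.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Optional

TOY_LP_SEED_FIELDS = ("input_ids", "input_lengths", "token_mask")  # hypothetical


def narrow_fields(fields: list[str], keep: tuple[str, ...] = TOY_LP_SEED_FIELDS) -> list[str]:
    """Drop fields logprob workers never read (e.g. multimodal payload), keeping order."""
    return [f for f in fields if f in keep]


def dispatch_and_wait(
    workers: list[Callable[[list[str]], Optional[Any]]], fields: list[str]
) -> None:
    """Run every worker on the narrowed fields and block until all have finished."""
    with ThreadPoolExecutor(max_workers=len(workers)) as pool:
        results = list(pool.map(lambda work: work(fields), workers))
    # Results travel through the data plane, so the direct returns should be None.
    if any(r is not None for r in results):
        raise RuntimeError("worker returned data instead of writing it back")


fields = narrow_fields(["input_ids", "pixel_values", "input_lengths", "token_mask"])
print(fields)  # ['input_ids', 'input_lengths', 'token_mask']
committed: list[tuple[str, ...]] = []
dispatch_and_wait([lambda f: committed.append(tuple(f))] * 2, fields)
print(len(committed))  # 2
```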
- `get_logprobs_from_meta(meta: nemo_rl.data_plane.KVBatchMeta, micro_batch_size: Optional[int] = None, timer: Optional[nemo_rl.utils.timer.Timer] = None)`
- `get_reference_policy_logprobs_from_meta(meta: nemo_rl.data_plane.KVBatchMeta, micro_batch_size: Optional[int] = None, timer: Optional[nemo_rl.utils.timer.Timer] = None)`
- `train_from_meta(meta: nemo_rl.data_plane.KVBatchMeta, loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction, eval_mode: bool = False, gbs: Optional[int] = None, mbs: Optional[int] = None, timer: Optional[nemo_rl.utils.timer.Timer] = None)`
1-hop counterpart to `train`.
`meta` names per-sample keys. By this point, three sets of columns have all landed under those same keys: those written by the rollout actor, the worker logprob deltas, and the driver-side advantage delta. Workers fetch the union via `train_presharded` → `self._fetch(meta)`. There is no partition drain here: the sync 1-hop trainer calls `clear_samples` once at the end of the step.
- Parameters:
meta – Full-step `KVBatchMeta` (consumed by all DP ranks).
gbs – Global batch size; defaults to `cfg["train_global_batch_size"]`.
mbs – Micro batch size; defaults to `cfg["train_micro_batch_size"]`.
timer – Optional timer for nested `policy_training/*` measurements.
- Returns:
Aggregated training-step output dict.
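The `gbs` and `mbs` defaults combine with DP sharding in a way that simple arithmetic can sketch. The config keys come from the parameter descriptions above. The split itself is an illustrative assumption: each DP rank receives `gbs / dp_size` samples and runs `(gbs / dp_size) / mbs` micro-batches per optimizer step. `resolve_batch_sizes` is a hypothetical helper.

```python
from typing import Any, Optional


def resolve_batch_sizes(
    cfg: dict[str, Any], dp_size: int, gbs: Optional[int] = None, mbs: Optional[int] = None
) -> tuple[int, int, int]:
    """Return (gbs, mbs, micro-batches per DP rank per optimizer step)."""
    gbs = cfg["train_global_batch_size"] if gbs is None else gbs
    mbs = cfg["train_micro_batch_size"] if mbs is None else mbs
    if gbs % dp_size != 0:
        raise ValueError(f"gbs={gbs} is not divisible by dp_size={dp_size}")
    per_rank = gbs // dp_size
    if per_rank % mbs != 0:
        raise ValueError(f"per-rank batch {per_rank} is not divisible by mbs={mbs}")
    return gbs, mbs, per_rank // mbs


cfg = {"train_global_batch_size": 64, "train_micro_batch_size": 4}
print(resolve_batch_sizes(cfg, dp_size=4))  # (64, 4, 4)
print(resolve_batch_sizes(cfg, dp_size=2, gbs=32, mbs=8))  # (32, 8, 2)
```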