nemo_rl.algorithms.xtoken_off_policy_distillation
Single-teacher cross-tokenizer off-policy distillation.
The training-loop layout mirrors run_distillation.py and
nemo_rl/algorithms/distillation.py, minus the on-policy parts (no environment, no
rollout, no generation). Each step:
1. Pull a collated batch (student and teacher token ids plus the alignment).
2. Run the teacher forward pass via ``Policy.get_topk_logits`` on the TEACHER token
ids, which gives the top-k teacher logits at teacher positions.
3. Pack the alignment payload and the teacher top-k into a student-side
``train_data`` dict.
4. Call ``student_policy.train(train_data, loss_fn)``; the student forward pass,
loss, backward pass, and optimizer step all run inside the dtensor v2
worker.
The collator and aligner do all of the CPU-side cross-tokenizer work, and the loss function does only the loss math, so this module is just plumbing.
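The four steps above can be sketched in plain Python to show the data flow. This is a minimal, hypothetical sketch: StubTeacher, StubStudent, off_policy_step, and the batch keys (student_input_ids, teacher_input_ids, alignment, teacher_topk) are illustrative names, not this module's API, and the real Policy.get_topk_logits and Policy.train signatures may differ. The point is only that teacher top-k is computed on teacher token ids and then shipped to the student together with the alignment.

```python
import random
from typing import Any, Callable

Batch = dict[str, Any]


class StubTeacher:
    """Hypothetical teacher stand-in; not the real Policy.get_topk_logits."""

    def __init__(self, vocab_size: int, k: int, seed: int = 0) -> None:
        self.vocab_size = vocab_size
        self.k = k
        self.rng = random.Random(seed)

    def get_topk_logits(self, teacher_ids: list[list[int]]) -> list[list[list[tuple[int, float]]]]:
        # For every TEACHER position, keep the k highest (token_id, logit) pairs.
        result = []
        for seq in teacher_ids:
            per_position = []
            for _ in seq:
                logits = [self.rng.gauss(0.0, 1.0) for _ in range(self.vocab_size)]
                ranked = sorted(enumerate(logits), key=lambda pair: pair[1], reverse=True)
                per_position.append(ranked[: self.k])
            result.append(per_position)
        return result


class StubStudent:
    """Hypothetical student stand-in; records train_data instead of training."""

    def __init__(self) -> None:
        self.received: list[Batch] = []

    def train(self, train_data: Batch, loss_fn: Callable[[Batch], float]) -> dict[str, float]:
        # In the real worker this is forward + loss + backward + optimizer step.
        self.received.append(train_data)
        return {"loss": loss_fn(train_data)}


def off_policy_step(batch: Batch, teacher: StubTeacher, student: StubStudent,
                    loss_fn: Callable[[Batch], float]) -> dict[str, float]:
    # Step 2: teacher forward on TEACHER token ids only.
    teacher_topk = teacher.get_topk_logits(batch["teacher_input_ids"])
    # Step 3: student-side payload = student ids + alignment + teacher top-k.
    train_data = {
        "input_ids": batch["student_input_ids"],
        "alignment": batch["alignment"],
        "teacher_topk": teacher_topk,
    }
    # Step 4: hand everything to the student.
    return student.train(train_data, loss_fn)


# Step 1 stand-in: one collated example with 3 student tokens and 2 teacher tokens.
batch = {
    "student_input_ids": [[11, 12, 13]],
    "teacher_input_ids": [[7, 8]],
    "alignment": [[(0, 0), (1, 1), (2, 1)]],  # hypothetical (student_pos, teacher_pos) pairs
}
student = StubStudent()
metrics = off_policy_step(batch, StubTeacher(vocab_size=10, k=3), student, loss_fn=lambda d: 0.0)
```

Note that the teacher token ids never reach the student payload; only the teacher's top-k and the alignment do, which matches the split of responsibilities described above.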
Module Contents
Classes
- OffPolicyDistillationConfig: Top-level distillation algo config.
Functions
- setup: Construct cluster, dataloaders, policies, and loss fn for the run.
- xtoken_off_policy_distillation_train: Off-policy CT distillation training loop.
- validate: Held-out KL/CE on a validation dataloader.
Data
- XTOKEN_NON_STUDENT_SEQ_KEYS
API
- nemo_rl.algorithms.xtoken_off_policy_distillation.XTOKEN_NON_STUDENT_SEQ_KEYS: frozenset[str]
Value: frozenset(…)
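The value of XTOKEN_NON_STUDENT_SEQ_KEYS is elided above. Judging only from its name, it appears to list batch keys that carry sequence data not laid out on student positions (for example, teacher-side tensors). The sketch below assumes that reading and uses a frozenset of that kind to separate such keys from student-side ones; EXAMPLE_NON_STUDENT_SEQ_KEYS, its key names, and split_student_side are all hypothetical.

```python
from typing import Any

# Hypothetical stand-in: the real constant's contents are elided on this page.
EXAMPLE_NON_STUDENT_SEQ_KEYS: frozenset[str] = frozenset(
    {"teacher_input_ids", "teacher_input_lengths"}
)


def split_student_side(
    batch: dict[str, Any], non_student_keys: frozenset[str]
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Partition a batch into student-side keys and keys listed in non_student_keys."""
    student_side = {k: v for k, v in batch.items() if k not in non_student_keys}
    other = {k: v for k, v in batch.items() if k in non_student_keys}
    return student_side, other


batch = {"input_ids": [[1, 2, 3]], "teacher_input_ids": [[9, 8]], "teacher_input_lengths": [2]}
student_side, other = split_student_side(batch, EXAMPLE_NON_STUDENT_SEQ_KEYS)
```

A frozenset fits this role because it is immutable (no code path can add a key by accident) and membership checks are constant time on average.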
- class nemo_rl.algorithms.xtoken_off_policy_distillation.OffPolicyDistillationConfig
Bases: typing.TypedDict
Top-level distillation algo config.
- num_prompts_per_step: int
Global batch size at the dataloader level.
- max_num_steps: int
Maximum number of training steps before stopping early.
- max_num_epochs: int
Maximum number of passes over the training dataset.
- seed: int
RNG seed.
- val_period: int
Validation cadence in steps. 0 disables validation.
- val_at_start: bool
Run validation before training begins.
- val_at_end: bool
Run validation on the final step.
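A TypedDict is a plain dict at runtime, so a config is just a dictionary with these keys. The sketch below mirrors the field list and adds a hypothetical should_validate helper showing one plausible reading of the validation flags: val_period > 0 triggers validation every val_period steps, val_period == 0 turns periodic validation off, and val_at_end additionally validates on the final step. Whether 0 also suppresses val_at_end in the real loop is not stated here, val_at_start would be checked once before the loop rather than by this helper, and the values in cfg are made up.

```python
from typing import TypedDict


class OffPolicyDistillationConfig(TypedDict):
    # Field names and types as listed above; the example values below are made up.
    num_prompts_per_step: int
    max_num_steps: int
    max_num_epochs: int
    seed: int
    val_period: int
    val_at_start: bool
    val_at_end: bool


def should_validate(cfg: OffPolicyDistillationConfig, step: int, is_final_step: bool) -> bool:
    """Hypothetical helper: True if validation should run after `step` (1-based)."""
    if cfg["val_at_end"] and is_final_step:
        return True
    # Assumption: val_period == 0 turns off periodic validation only.
    return cfg["val_period"] > 0 and step % cfg["val_period"] == 0


cfg = OffPolicyDistillationConfig(
    num_prompts_per_step=64,
    max_num_steps=1000,
    max_num_epochs=1,
    seed=42,
    val_period=100,
    val_at_start=False,
    val_at_end=True,
)
```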
- class nemo_rl.algorithms.xtoken_off_policy_distillation.OffPolicyDistillationSaveState
Bases: typing.TypedDict
- current_epoch: int
- current_step: int
- total_steps: int
- consumed_samples: int
- total_valid_tokens: int
- val_loss: NotRequired[float]
- nemo_rl.algorithms.xtoken_off_policy_distillation._default_off_policy_distillation_save_state() -> nemo_rl.algorithms.xtoken_off_policy_distillation.OffPolicyDistillationSaveState
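The body of _default_off_policy_distillation_save_state is not shown here. A minimal sketch, assuming a fresh run starts every counter at zero and carries no val_loss until the first validation. SaveStateSketch, default_save_state_sketch, and record_step are hypothetical, and so are the counter semantics in record_step (current_step counted within the epoch, total_steps across epochs). The sketch uses the total=False base-class idiom, which is the pre-Python-3.11 way to mark a key as NotRequired.

```python
from typing import TypedDict


class _SaveStateRequired(TypedDict):
    current_epoch: int
    current_step: int
    total_steps: int
    consumed_samples: int
    total_valid_tokens: int


class SaveStateSketch(_SaveStateRequired, total=False):
    # Equivalent to `val_loss: NotRequired[float]` on Python 3.11+.
    val_loss: float


def default_save_state_sketch() -> SaveStateSketch:
    # Assumption: every counter starts at zero; val_loss is absent until validation runs.
    return SaveStateSketch(
        current_epoch=0,
        current_step=0,
        total_steps=0,
        consumed_samples=0,
        total_valid_tokens=0,
    )


def record_step(state: SaveStateSketch, num_samples: int, valid_tokens: int) -> SaveStateSketch:
    """Hypothetical: return a copy of `state` advanced by one training step."""
    new_state = SaveStateSketch(**state)
    new_state["current_step"] += 1
    new_state["total_steps"] += 1
    new_state["consumed_samples"] += num_samples
    new_state["total_valid_tokens"] += valid_tokens
    return new_state
```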
- class nemo_rl.algorithms.xtoken_off_policy_distillation.MasterConfig
Bases: pydantic.BaseModel
- policy: nemo_rl.models.policy.PolicyConfig
- teacher: nemo_rl.models.policy.PolicyConfig
- data: nemo_rl.data.DataConfig
- logger: nemo_rl.utils.logger.LoggerConfig
- cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig
- checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig
- nemo_rl.algorithms.xtoken_off_policy_distillation.setup(
    master_config: nemo_rl.algorithms.xtoken_off_policy_distillation.MasterConfig,
    student_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    teacher_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
    val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
)
Construct cluster, dataloaders, policies, and loss fn for the run.
- nemo_rl.algorithms.xtoken_off_policy_distillation.xtoken_off_policy_distillation_train(
    student_policy: nemo_rl.models.policy.lm_policy.Policy,
    teacher_policy: nemo_rl.models.policy.lm_policy.Policy,
    dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
    val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
    loss_fn: nemo_rl.algorithms.loss.loss_functions.CrossTokenizerDistillationLossFn,
    logger: nemo_rl.utils.logger.Logger,
    checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
    off_policy_distillation_state: nemo_rl.algorithms.xtoken_off_policy_distillation.OffPolicyDistillationSaveState,
    master_config: nemo_rl.algorithms.xtoken_off_policy_distillation.MasterConfig,
)
Off-policy cross-tokenizer (CT) distillation training loop.
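A plain-Python sketch of the outer loop's stopping logic, assuming the loop ends at whichever limit is hit first (max_num_steps or max_num_epochs) and that the counters follow the save-state fields above. run_loop_sketch, fresh_state, the list standing in for the StatefulDataLoader, and the train_step callback are hypothetical. Validation, logging, and checkpointing are omitted, and unlike a StatefulDataLoader, a list restarts from its first batch on resume.

```python
from typing import Any, Callable


def run_loop_sketch(
    batches: list[Any],
    state: dict[str, int],
    max_num_steps: int,
    max_num_epochs: int,
    num_prompts_per_step: int,
    train_step: Callable[[Any], None],
) -> dict[str, int]:
    """Hypothetical outer loop: stop at max_num_steps or max_num_epochs, whichever comes first."""
    while state["current_epoch"] < max_num_epochs and state["total_steps"] < max_num_steps:
        for batch in batches:
            train_step(batch)  # stands for teacher top-k + student.train(...) in the real loop
            state["current_step"] += 1
            state["total_steps"] += 1
            state["consumed_samples"] += num_prompts_per_step
            if state["total_steps"] >= max_num_steps:
                return state
        # Finished a full pass: advance the epoch and reset the per-epoch step counter.
        state["current_epoch"] += 1
        state["current_step"] = 0
    return state


def fresh_state() -> dict[str, int]:
    return {"current_epoch": 0, "current_step": 0, "total_steps": 0, "consumed_samples": 0}
```

With three batches per epoch and max_num_steps=4, this loop stops one step into the second epoch; with max_num_steps=10 and max_num_epochs=2, it stops after six steps because the epoch limit is reached first.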
- nemo_rl.algorithms.xtoken_off_policy_distillation.validate(
    student_policy: nemo_rl.models.policy.lm_policy.Policy,
    teacher_policy: nemo_rl.models.policy.lm_policy.Policy,
    val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
    loss_fn: nemo_rl.algorithms.loss.loss_functions.CrossTokenizerDistillationLossFn,
    master_config: nemo_rl.algorithms.xtoken_off_policy_distillation.MasterConfig,
    timer: Optional[nemo_rl.utils.timer.Timer] = None,
)
Compute held-out KL/CE on a validation dataloader.
Reuses the same per-step path as training, but in eval mode, so no backward pass or optimizer step runs. Returns the mean of the train-style metrics.
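A minimal sketch of the reduction described above, assuming "mean" is an unweighted average of each metric key over validation batches (a token-weighted average would be an equally valid design). validate_sketch and eval_step are hypothetical; eval_step stands for the forward-plus-loss pass with no backward pass or optimizer step.

```python
from statistics import fmean
from typing import Any, Callable, Iterable


def validate_sketch(
    val_batches: Iterable[Any],
    eval_step: Callable[[Any], dict[str, float]],
) -> dict[str, float]:
    """Run eval_step on every batch and average each metric across batches."""
    per_batch = [eval_step(batch) for batch in val_batches]
    if not per_batch:
        return {}
    # Unweighted mean per metric key; keys are taken from the first batch.
    return {key: fmean(m[key] for m in per_batch) for key in per_batch[0]}


fake_metrics = iter([{"loss": 1.0, "kl": 0.5}, {"loss": 3.0, "kl": 1.5}])
result = validate_sketch(range(2), lambda _batch: next(fake_metrics))
```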