nemo_rl.algorithms.xtoken_off_policy_distillation

Single-teacher cross-tokenizer off-policy distillation.

The training-loop layout mirrors ``run_distillation.py`` / ``nemo_rl/algorithms/distillation.py``, minus the on-policy parts (no environment, no rollout, no generation). Each step:

1. Pull a collated batch (student & teacher token ids + alignment).
2. Run the teacher forward pass via ``Policy.get_topk_logits`` on the TEACHER
   token ids; this yields the top-k teacher logits at teacher positions.
3. Pack alignment payload + teacher topk into a student-side
   ``train_data`` dict.
4. ``student_policy.train(train_data, loss_fn)``: the student forward pass,
   loss, backward pass, and optimizer step all happen inside the dtensor v2
   worker.

The collator and aligner do all of the CPU-side cross-tokenizer work, and the loss function does only the loss math; this module is just the plumbing that connects them.
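The per-step flow above can be sketched in plain Python. This is a minimal, hedged illustration rather than the module's code: `StubTeacher`, `StubStudent`, `off_policy_step`, and the batch keys (`student_input_ids`, `teacher_input_ids`, `alignment`, `teacher_topk`) are hypothetical stand-ins. Only the method names `get_topk_logits` and `train` come from the description above, and their real signatures and return types may differ.

```python
from typing import Any, Callable

TopK = list[tuple[float, int]]  # (logit, token_id) pairs for one position


class StubTeacher:
    """Hypothetical teacher exposing a get_topk_logits-like method."""

    def __init__(self, vocab_size: int, k: int) -> None:
        self.vocab_size = vocab_size
        self.k = k

    def get_topk_logits(self, input_ids: list[list[int]]) -> list[list[TopK]]:
        out: list[list[TopK]] = []
        for seq in input_ids:
            rows: list[TopK] = []
            for tok in seq:
                # Fake logits that peak at the current token id.
                logits = [float(-abs(v - tok)) for v in range(self.vocab_size)]
                ranked = sorted(range(self.vocab_size), key=lambda v: logits[v], reverse=True)
                rows.append([(logits[v], v) for v in ranked[: self.k]])
            out.append(rows)
        return out


class StubStudent:
    """Hypothetical student whose train() records its input and evaluates loss_fn."""

    def __init__(self) -> None:
        self.received: list[dict[str, Any]] = []

    def train(self, train_data: dict[str, Any],
              loss_fn: Callable[[dict[str, Any]], float]) -> dict[str, float]:
        # The real worker would run forward + loss + backward + optimizer step here.
        self.received.append(train_data)
        return {"loss": loss_fn(train_data)}


def off_policy_step(batch: dict[str, Any], teacher: StubTeacher, student: StubStudent,
                    loss_fn: Callable[[dict[str, Any]], float]) -> dict[str, float]:
    # Step 2: teacher forward on TEACHER token ids -> top-k logits at teacher positions.
    teacher_topk = teacher.get_topk_logits(batch["teacher_input_ids"])
    # Step 3: pack the alignment payload and teacher top-k into a student-side dict.
    train_data = {
        "input_ids": batch["student_input_ids"],
        "alignment": batch["alignment"],
        "teacher_topk": teacher_topk,
    }
    # Step 4: hand everything to the student; no generation or rollout is involved.
    return student.train(train_data, loss_fn)


# Step 1 would normally come from the collator; here it is a hand-built batch.
batch = {
    "student_input_ids": [[1, 2]],
    "teacher_input_ids": [[3]],
    "alignment": [[(0, 0), (1, 0)]],  # hypothetical (student_pos, teacher_pos) pairs
}
student = StubStudent()
metrics = off_policy_step(batch, StubTeacher(vocab_size=5, k=2), student,
                          loss_fn=lambda d: float(len(d["teacher_topk"][0])))
print(metrics)  # {'loss': 1.0}
```

The point of the split is that the teacher only ever sees teacher token ids and the student only ever sees a prepared `train_data` dict; the alignment travels as opaque payload between them.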

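Module Contents

Classes

OffPolicyDistillationConfig

Top-level distillation algorithm config.

OffPolicyDistillationSaveState

MasterConfig

Functions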

_default_off_policy_distillation_save_state

setup

Construct the cluster, dataloaders, policies, and loss function for the run.

xtoken_off_policy_distillation_train

Off-policy cross-tokenizer (CT) distillation training loop.

validate

Held-out KL/CE on a validation dataloader.

Data

XTOKEN_NON_STUDENT_SEQ_KEYS

API

nemo_rl.algorithms.xtoken_off_policy_distillation.XTOKEN_NON_STUDENT_SEQ_KEYS: frozenset[str]

‘frozenset(…)’

class nemo_rl.algorithms.xtoken_off_policy_distillation.OffPolicyDistillationConfig

Bases: typing.TypedDict

Top-level distillation algorithm config.

num_prompts_per_step: int

Global batch size at the dataloader level.

max_num_steps: int

Maximum number of training steps before stopping early.

max_num_epochs: int

Maximum number of passes over the training dataset.

seed: int

RNG seed.

val_period: int

Validation cadence in steps. 0 disables validation.

val_at_start: bool

Run validation before training begins.

val_at_end: bool

Run validation on the final step.
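A config instance might look like the dictionary below. The field names and types come from the class above; the values are arbitrary placeholders. `should_validate` is a hypothetical helper encoding one reading of `val_period` and `val_at_end` (steps counted from 1, with the caller deciding which step is last), not the module's actual logic.

```python
from typing import TypedDict


class OffPolicyDistillationConfig(TypedDict):
    # Local mirror of the documented fields so the example runs standalone.
    num_prompts_per_step: int
    max_num_steps: int
    max_num_epochs: int
    seed: int
    val_period: int
    val_at_start: bool
    val_at_end: bool


def should_validate(step: int, is_last_step: bool, cfg: OffPolicyDistillationConfig) -> bool:
    """Hypothetical cadence check; assumes steps are counted from 1."""
    periodic = cfg["val_period"] > 0 and step % cfg["val_period"] == 0
    final = cfg["val_at_end"] and is_last_step
    return periodic or final


cfg: OffPolicyDistillationConfig = {
    "num_prompts_per_step": 64,  # placeholder values throughout
    "max_num_steps": 1000,
    "max_num_epochs": 1,
    "seed": 42,
    "val_period": 100,
    "val_at_start": False,
    "val_at_end": True,
}

print([s for s in range(1, 251) if should_validate(s, s == 250, cfg)])  # [100, 200, 250]
```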

class nemo_rl.algorithms.xtoken_off_policy_distillation.OffPolicyDistillationSaveState

Bases: typing.TypedDict

current_epoch: int

current_step: int

total_steps: int

consumed_samples: int

total_valid_tokens: int

val_loss: NotRequired[float]

nemo_rl.algorithms.xtoken_off_policy_distillation._default_off_policy_distillation_save_state() → nemo_rl.algorithms.xtoken_off_policy_distillation.OffPolicyDistillationSaveState
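The save state is a flat dictionary of counters. As a hypothetical sketch, a default constructor and one bookkeeping update could look like this. The zero initial values and the per-step update rule are assumptions inferred from the field names, not taken from the source, and `default_save_state` / `advance` are illustrative names. `NotRequired[float]` is emulated with a `total=False` subclass so the snippet runs on Pythons older than 3.11.

```python
from typing import TypedDict


class _SaveStateCounters(TypedDict):
    # Local mirror of the documented required fields.
    current_epoch: int
    current_step: int
    total_steps: int
    consumed_samples: int
    total_valid_tokens: int


class OffPolicyDistillationSaveState(_SaveStateCounters, total=False):
    val_loss: float  # optional key, standing in for NotRequired[float]


def default_save_state() -> OffPolicyDistillationSaveState:
    # Assumption: all counters start at zero and no val_loss is recorded yet.
    return OffPolicyDistillationSaveState(
        current_epoch=0, current_step=0, total_steps=0,
        consumed_samples=0, total_valid_tokens=0,
    )


def advance(state: OffPolicyDistillationSaveState, num_samples: int,
            num_valid_tokens: int) -> OffPolicyDistillationSaveState:
    # Hypothetical per-step bookkeeping; returns a new dict and leaves `state` untouched.
    new_state = state.copy()
    new_state["current_step"] += 1
    new_state["total_steps"] += 1
    new_state["consumed_samples"] += num_samples
    new_state["total_valid_tokens"] += num_valid_tokens
    return new_state


state = advance(default_save_state(), num_samples=64, num_valid_tokens=12_000)
print(state["total_steps"], state["consumed_samples"])  # 1 64
```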

class nemo_rl.algorithms.xtoken_off_policy_distillation.MasterConfig

Bases: pydantic.BaseModel

policy: nemo_rl.models.policy.PolicyConfig

teacher: nemo_rl.models.policy.PolicyConfig

loss_fn: nemo_rl.algorithms.loss.loss_functions.CrossTokenizerDistillationLossConfig

data: nemo_rl.data.DataConfig

distillation: nemo_rl.algorithms.xtoken_off_policy_distillation.OffPolicyDistillationConfig

logger: nemo_rl.utils.logger.LoggerConfig

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig

nemo_rl.algorithms.xtoken_off_policy_distillation.setup(
master_config: nemo_rl.algorithms.xtoken_off_policy_distillation.MasterConfig,
student_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
teacher_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
) → tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.models.policy.lm_policy.Policy, torchdata.stateful_dataloader.StatefulDataLoader, Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss.loss_functions.CrossTokenizerDistillationLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.xtoken_off_policy_distillation.OffPolicyDistillationSaveState, nemo_rl.algorithms.xtoken_off_policy_distillation.MasterConfig]

Construct the cluster, dataloaders, policies, and loss function for the run. The returned tuple also carries the logger, checkpoint manager, save state, and master config.

nemo_rl.algorithms.xtoken_off_policy_distillation.xtoken_off_policy_distillation_train(
student_policy: nemo_rl.models.policy.lm_policy.Policy,
teacher_policy: nemo_rl.models.policy.lm_policy.Policy,
dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
loss_fn: nemo_rl.algorithms.loss.loss_functions.CrossTokenizerDistillationLossFn,
logger: nemo_rl.utils.logger.Logger,
checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
off_policy_distillation_state: nemo_rl.algorithms.xtoken_off_policy_distillation.OffPolicyDistillationSaveState,
master_config: nemo_rl.algorithms.xtoken_off_policy_distillation.MasterConfig,
) → None

Off-policy cross-tokenizer (CT) distillation training loop.
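The outer loop structure can be sketched with plain iterables. This is a hypothetical outline, not the function's body. It assumes that training stops at whichever of `max_num_steps` or `max_num_epochs` is reached first, and that validation follows the `val_at_start` / `val_period` / `val_at_end` flags. `train_loop`, `run_step`, and `run_validation` are illustrative names standing in for the per-step path and `validate`, and checkpointing and logging are omitted.

```python
from typing import Any, Callable, Iterable


def train_loop(
    dataloader: Iterable[Any],
    cfg: dict[str, Any],
    run_step: Callable[[Any], Any],
    run_validation: Callable[[int], Any],
) -> list[int]:
    """Hypothetical outline of the outer loop; returns the steps at which validation ran."""
    validated: list[int] = []

    def validate_at(step: int) -> None:
        # Skip a second validation at the same step (e.g. periodic and final coincide).
        if not validated or validated[-1] != step:
            run_validation(step)
            validated.append(step)

    step = 0
    if cfg["val_at_start"]:
        validate_at(step)
    done = False
    for _epoch in range(cfg["max_num_epochs"]):
        for batch in dataloader:
            run_step(batch)  # stands in for the per-step teacher/student path
            step += 1
            if cfg["val_period"] > 0 and step % cfg["val_period"] == 0:
                validate_at(step)
            if step >= cfg["max_num_steps"]:
                done = True
                break
        if done:
            break
    if cfg["val_at_end"] and step > 0:
        validate_at(step)
    return validated


cfg = {"max_num_steps": 10, "max_num_epochs": 2, "val_period": 3,
       "val_at_start": True, "val_at_end": True}
seen: list[int] = []
print(train_loop(list(range(4)), cfg, run_step=seen.append, run_validation=lambda s: None))
# [0, 3, 6, 8]: the two-epoch budget (8 batches) runs out before max_num_steps
```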

nemo_rl.algorithms.xtoken_off_policy_distillation.validate(
student_policy: nemo_rl.models.policy.lm_policy.Policy,
teacher_policy: nemo_rl.models.policy.lm_policy.Policy,
val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
loss_fn: nemo_rl.algorithms.loss.loss_functions.CrossTokenizerDistillationLossFn,
master_config: nemo_rl.algorithms.xtoken_off_policy_distillation.MasterConfig,
timer: Optional[nemo_rl.utils.timer.Timer] = None,
) → tuple[dict[str, Any], dict[str, Any]]

Held-out KL/CE on a validation dataloader.

Reuses the same per-step path as training, but in eval mode, so no backward pass or optimizer step runs. Returns the mean of the train-style metrics.
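The docstring says `validate` returns mean train-style metrics. A hypothetical reduction over per-batch metric dicts might look like the sketch below. `mean_metrics` and the metric names `kl` / `ce` are illustrative, and the source does not say whether the real code weights batches equally (the default here) or by something like token count (the optional `weights` argument).

```python
from typing import Mapping, Optional, Sequence


def mean_metrics(
    per_batch: Sequence[Mapping[str, float]],
    weights: Optional[Sequence[float]] = None,
) -> dict[str, float]:
    """Average each metric over batches; equal weights unless `weights` is given."""
    if not per_batch:
        return {}
    if weights is None:
        weights = [1.0] * len(per_batch)
    if len(weights) != len(per_batch):
        raise ValueError("need one weight per batch")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    keys = per_batch[0].keys()
    return {k: sum(w * m[k] for w, m in zip(weights, per_batch)) / total for k in keys}


batches = [{"kl": 0.5, "ce": 2.0}, {"kl": 1.5, "ce": 4.0}]
print(mean_metrics(batches))                  # {'kl': 1.0, 'ce': 3.0}
print(mean_metrics(batches, weights=[3, 1]))  # {'kl': 0.75, 'ce': 2.5}
```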