nemo_rl.algorithms.distillation#
Module Contents#
Classes#
| Class | Description |
|---|---|
| DistillationConfig | |
| DistillationSaveState | |
| MasterConfig | Main configuration structure. |
Functions#
| Function | Description |
|---|---|
| check_vocab_equality | Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal. |
| setup | Main entry point for the distillation algorithm. |
| distillation_train | Run the distillation training algorithm. |
| validate | Run validation on the validation dataset. |
Data#
API#
- nemo_rl.algorithms.distillation.TokenizerType#
'TypeVar(…)'
- class nemo_rl.algorithms.distillation.DistillationConfig#
Bases:
typing.TypedDict
- num_prompts_per_step: int#
None
- num_generations_per_prompt: int#
None
- max_rollout_turns: int#
None
- max_num_steps: int#
None
- val_batch_size: int#
None
- val_period: int#
None
- val_at_start: bool#
None
- max_val_samples: int#
None
- topk_logits_k: int#
None
- seed: int#
None
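Put together, these fields form the distillation section of the run configuration. Below is a minimal sketch of a populated `DistillationConfig`; the values are illustrative placeholders, not recommended defaults, and the inline comments infer each field's meaning from its name.

```python
from nemo_rl.algorithms.distillation import DistillationConfig

# Illustrative values only; field meanings are inferred from the field names.
distillation_config: DistillationConfig = {
    "num_prompts_per_step": 32,       # prompts sampled per training step
    "num_generations_per_prompt": 4,  # rollouts generated per prompt
    "max_rollout_turns": 1,           # cap on multi-turn rollouts
    "max_num_steps": 1000,            # total training steps
    "val_batch_size": 32,
    "val_period": 50,                 # validate every 50 steps
    "val_at_start": True,             # run validation before step 1
    "max_val_samples": 256,
    "topk_logits_k": 64,              # presumably the number of teacher logits kept per token
    "seed": 42,
}
```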
- class nemo_rl.algorithms.distillation.DistillationSaveState#
Bases:
typing.TypedDict
- step: int#
None
- val_reward: NotRequired[float]#
None
- consumed_samples: int#
None
- total_valid_tokens: int#
None
- nemo_rl.algorithms.distillation._default_distillation_save_state() → nemo_rl.algorithms.distillation.DistillationSaveState#
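For reference, a fresh save state plausibly looks like the dict below (`val_reward` is `NotRequired`, so it can be omitted). The zero values are an assumption for illustration, not taken from the factory's source.

```python
from nemo_rl.algorithms.distillation import DistillationSaveState

# Assumed initial values; _default_distillation_save_state() may differ.
initial_state: DistillationSaveState = {
    "step": 0,
    "consumed_samples": 0,
    "total_valid_tokens": 0,
}
```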
- class nemo_rl.algorithms.distillation.MasterConfig#
Bases:
typing.TypedDict
Main configuration structure.
- policy: nemo_rl.models.policy.PolicyConfig#
None
- teacher: nemo_rl.models.policy.PolicyConfig#
None
- env: dict[str, Any]#
None
- data: nemo_rl.data.DataConfig#
None
- distillation: nemo_rl.algorithms.distillation.DistillationConfig#
None
- logger: nemo_rl.utils.logger.LoggerConfig#
None
- cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#
None
- checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#
None
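`MasterConfig` is the top-level container that ties the sub-configs together. A hedged sketch of assembling one, assuming each sub-config has already been built elsewhere; the helper function below is hypothetical, not part of the library.

```python
from typing import Any

from nemo_rl.algorithms.distillation import DistillationConfig, MasterConfig


def build_master_config(
    policy_cfg,                           # student PolicyConfig
    teacher_cfg,                          # teacher PolicyConfig
    env_cfg: dict[str, Any],              # per-task environment settings
    data_cfg,                             # DataConfig
    distillation_cfg: DistillationConfig, # e.g. the dict shown above
    logger_cfg,
    cluster_cfg,
    checkpointing_cfg,
) -> MasterConfig:
    """Hypothetical helper: assemble the top-level config from sub-configs."""
    return {
        "policy": policy_cfg,
        "teacher": teacher_cfg,
        "env": env_cfg,
        "data": data_cfg,
        "distillation": distillation_cfg,
        "logger": logger_cfg,
        "cluster": cluster_cfg,
        "checkpointing": checkpointing_cfg,
    }
```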
- nemo_rl.algorithms.distillation.check_vocab_equality(
- tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
- student_model_name: str,
- teacher_model_name: str,
)#
Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal.
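Identical vocabularies matter because logit-level distillation aligns the teacher and student distributions token-ID by token-ID. As a rough illustration of the idea only (not nemo_rl's actual implementation), such a check could compare Hugging Face tokenizer vocabularies directly:

```python
from transformers import AutoTokenizer


def vocabs_match(student_model_name: str, teacher_model_name: str) -> bool:
    """Hypothetical re-implementation of the vocab check, for illustration.

    If the two vocabularies differ, teacher logits cannot be mapped onto
    student token IDs and logit-level distillation is ill-defined.
    """
    student = AutoTokenizer.from_pretrained(student_model_name)
    teacher = AutoTokenizer.from_pretrained(teacher_model_name)
    return student.get_vocab() == teacher.get_vocab()
```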
- nemo_rl.algorithms.distillation.setup(
- master_config: nemo_rl.algorithms.distillation.MasterConfig,
- tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
- train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
- val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
)#
Main entry point for the distillation algorithm.
- Returns:
tuple of student_policy, teacher_policy, student_generation, train_dataloader, val_dataloader, loss_fn, logger, checkpointer, distillation_save_state, master_config
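Since `setup` returns a ten-element tuple, a driver script would typically unpack it in the documented order. A hypothetical snippet, assuming `master_config`, `tokenizer`, and the two datasets have already been built:

```python
from nemo_rl.algorithms.distillation import setup

# Unpack in the order documented above.
(
    student_policy,
    teacher_policy,
    student_generation,
    train_dataloader,
    val_dataloader,
    loss_fn,
    logger,
    checkpointer,
    distillation_save_state,
    master_config,
) = setup(master_config, tokenizer, train_dataset, val_dataset)
```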
- nemo_rl.algorithms.distillation.distillation_train(
- student_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
- teacher_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
- student_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
- dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
- val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
- tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
- loss_fn: nemo_rl.algorithms.loss_functions.DistillationLossFn,
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
- logger: nemo_rl.utils.logger.Logger,
- checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
- distillation_save_state: nemo_rl.algorithms.distillation.DistillationSaveState,
- master_config: nemo_rl.algorithms.distillation.MasterConfig,
)#
Run the distillation training algorithm.
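Continuing the hypothetical driver above, the outputs of `setup` are passed straight through to `distillation_train`, together with task-to-environment mappings, whose construction is deployment-specific and not shown here:

```python
from nemo_rl.algorithms.distillation import distillation_train

# task_to_env / val_task_to_env map task names to EnvironmentInterface
# instances; building them depends on the environments in use.
distillation_train(
    student_policy,
    teacher_policy,
    student_generation,
    train_dataloader,
    val_dataloader,
    tokenizer,
    loss_fn,
    task_to_env,
    val_task_to_env,
    logger,
    checkpointer,
    distillation_save_state,
    master_config,
)
```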
- nemo_rl.algorithms.distillation.validate(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
- tokenizer,
- val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
- step: int,
- master_config: nemo_rl.algorithms.distillation.MasterConfig,
)#
Run validation on the validation dataset.
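A hypothetical standalone call, e.g. for a validation pass before training begins (cf. the `val_at_start` flag in `DistillationConfig`); everything besides `step` comes from the `setup` outputs and environment mappings shown earlier:

```python
from nemo_rl.algorithms.distillation import validate

# Keyword names follow the documented signature.
validate(
    policy_generation=student_generation,
    val_dataloader=val_dataloader,
    tokenizer=tokenizer,
    val_task_to_env=val_task_to_env,
    step=0,
    master_config=master_config,
)
```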