nemo_rl.algorithms.distillation#

Module Contents#

Classes#

DistillationConfig

DistillationSaveState

MasterConfig

Main configuration structure.

Functions#

_default_distillation_save_state

check_vocab_equality

Check whether the vocabularies of the student tokenizer and the teacher tokenizer are equal.

setup

Main entry point for the distillation algorithm.

distillation_train

Run the distillation training algorithm.

validate

Run validation on the validation dataset.

Data#

API#

nemo_rl.algorithms.distillation.TokenizerType#

'TypeVar(…)'

class nemo_rl.algorithms.distillation.DistillationConfig#

Bases: typing.TypedDict

num_prompts_per_step: int#

None

num_generations_per_prompt: int#

None

max_rollout_turns: int#

None

max_num_steps: int#

None

val_batch_size: int#

None

val_period: int#

None

val_at_start: bool#

None

max_val_samples: int#

None

topk_logits_k: int#

None

seed: int#

None
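
A minimal sketch of a `DistillationConfig` dict with illustrative values; the comments paraphrase what the field names suggest, and the exact semantics live in the training code.

```python
from nemo_rl.algorithms.distillation import DistillationConfig

# Illustrative values only; tune these for a real run.
distillation_config: DistillationConfig = {
    "num_prompts_per_step": 32,       # prompts drawn from the dataset each step
    "num_generations_per_prompt": 4,  # rollouts generated per prompt
    "max_rollout_turns": 1,           # cap on multi-turn rollout length
    "max_num_steps": 1000,            # total number of training steps
    "val_batch_size": 32,
    "val_period": 50,                 # validate every 50 steps
    "val_at_start": True,             # run validation before training begins
    "max_val_samples": 256,
    "topk_logits_k": 64,              # how many top teacher logits to keep per token
    "seed": 42,
}
```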

class nemo_rl.algorithms.distillation.DistillationSaveState#

Bases: typing.TypedDict

step: int#

None

val_reward: NotRequired[float]#

None

consumed_samples: int#

None

total_valid_tokens: int#

None

nemo_rl.algorithms.distillation._default_distillation_save_state() → nemo_rl.algorithms.distillation.DistillationSaveState#
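
A small usage sketch: the required `DistillationSaveState` keys are always present after calling the default factory, while `val_reward` is `NotRequired` and may be absent until the first validation pass. The exact default values are defined in the source, not in this reference.

```python
from nemo_rl.algorithms.distillation import _default_distillation_save_state

save_state = _default_distillation_save_state()

# Required keys are always present.
print(save_state["step"], save_state["consumed_samples"], save_state["total_valid_tokens"])

# ``val_reward`` is NotRequired, so read it defensively.
print(save_state.get("val_reward", "no validation yet"))
```
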
class nemo_rl.algorithms.distillation.MasterConfig#

Bases: typing.TypedDict

Main configuration structure.

policy: nemo_rl.models.policy.PolicyConfig#

None

teacher: nemo_rl.models.policy.PolicyConfig#

None

loss_fn: nemo_rl.algorithms.loss_functions.DistillationLossConfig#

None

env: dict[str, Any]#

None

data: nemo_rl.data.DataConfig#

None

distillation: nemo_rl.algorithms.distillation.DistillationConfig#

None

logger: nemo_rl.utils.logger.LoggerConfig#

None

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#

None

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#

None
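
A skeleton of the `MasterConfig` structure; each nested sub-config (`PolicyConfig`, `DataConfig`, and so on) has its own schema documented elsewhere, and in practice the whole structure is typically loaded from a YAML file rather than written by hand. The empty dicts below are placeholders, and `distillation_config` refers to the earlier sketch.

```python
# Placeholder skeleton; each empty dict stands in for the nested schema
# named in the comment.
master_config = {
    "policy": {},                         # PolicyConfig for the student
    "teacher": {},                        # PolicyConfig for the teacher
    "loss_fn": {},                        # DistillationLossConfig
    "env": {},                            # per-task environment settings
    "data": {},                           # DataConfig
    "distillation": distillation_config,  # DistillationConfig from the sketch above
    "logger": {},                         # LoggerConfig
    "cluster": {},                        # ClusterConfig
    "checkpointing": {},                  # CheckpointingConfig
}
```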

nemo_rl.algorithms.distillation.check_vocab_equality(
tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
student_model_name: str,
teacher_model_name: str,
) → None#

Check whether the vocabularies of the student tokenizer and the teacher tokenizer are equal.
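
A usage sketch with hypothetical model names; distillation here assumes the student and teacher share a vocabulary, so this check is a cheap guard to run before setup.

```python
from transformers import AutoTokenizer

from nemo_rl.algorithms.distillation import check_vocab_equality

# Hypothetical student/teacher pair from the same model family.
student_name = "Qwen/Qwen2.5-1.5B-Instruct"
teacher_name = "Qwen/Qwen2.5-7B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(student_name)

# Returns None on success; expected to fail loudly if the vocabularies differ.
check_vocab_equality(tokenizer, student_name, teacher_name)
```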

nemo_rl.algorithms.distillation.setup(
master_config: nemo_rl.algorithms.distillation.MasterConfig,
tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
) → tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, Optional[nemo_rl.models.generation.interfaces.GenerationInterface], torchdata.stateful_dataloader.StatefulDataLoader, Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.DistillationLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.distillation.DistillationSaveState, nemo_rl.algorithms.distillation.MasterConfig]#

Main entry point for the distillation algorithm.

Returns:

A tuple of (student_policy, teacher_policy, student_generation, train_dataloader, val_dataloader, loss_fn, logger, checkpointer, distillation_save_state, master_config).
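
Since `setup` returns a ten-element tuple, unpacking it in the documented order keeps call sites readable. This sketch assumes `master_config`, `tokenizer`, and the datasets have already been constructed.

```python
(
    student_policy,
    teacher_policy,
    student_generation,
    train_dataloader,
    val_dataloader,
    loss_fn,
    logger,
    checkpointer,
    distillation_save_state,
    master_config,
) = setup(master_config, tokenizer, train_dataset, val_dataset)
```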

nemo_rl.algorithms.distillation.distillation_train(
student_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
teacher_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
student_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer: nemo_rl.algorithms.distillation.TokenizerType,
loss_fn: nemo_rl.algorithms.loss_functions.DistillationLossFn,
task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
logger: nemo_rl.utils.logger.Logger,
checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
distillation_save_state: nemo_rl.algorithms.distillation.DistillationSaveState,
master_config: nemo_rl.algorithms.distillation.MasterConfig,
) → None#

Run the distillation training algorithm.
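
A sketch of wiring the objects returned by `setup` into the training loop. The `task_to_env` and `val_task_to_env` mappings (task name to `EnvironmentInterface`) are not produced by `setup` and are assumed here to be built elsewhere.

```python
distillation_train(
    student_policy,
    teacher_policy,
    student_generation,
    train_dataloader,
    val_dataloader,
    tokenizer,
    loss_fn,
    task_to_env,      # dict[str, EnvironmentInterface], assumed to exist
    val_task_to_env,  # optional validation counterpart, assumed to exist
    logger,
    checkpointer,
    distillation_save_state,
    master_config,
)
```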

nemo_rl.algorithms.distillation.validate(
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer,
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
step: int,
master_config: nemo_rl.algorithms.distillation.MasterConfig,
) → tuple[dict[str, Any], dict[str, Any]]#

Run validation on the validation dataset.
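
A standalone-validation sketch. Which of the two returned dicts holds aggregated metrics is not specified in this reference, so treating the first as metrics is an assumption.

```python
# Assumes the objects returned by ``setup`` above are in scope, and that
# ``student_generation`` is non-None for this run.
val_metrics, val_data = validate(
    student_generation,   # GenerationInterface used for rollouts
    val_dataloader,
    tokenizer,
    val_task_to_env,      # assumed to exist, as in the training sketch
    step=distillation_save_state["step"],
    master_config=master_config,
)
print(val_metrics)
```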