nemo_rl.algorithms.dpo
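Module Contents
Classes
- DPOSaveState
- DPOConfig
- MasterConfig
Functions
- _default_dpo_save_state
- setup: Main entry point for running the DPO algorithm.
- add_ref_logprobs_to_data
- validate: Run validation on the validation dataset.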
API
- class nemo_rl.algorithms.dpo.DPOSaveState
  Bases: typing.TypedDict
  - epoch: int
  - step: int
  - total_steps: int
  - val_loss: float
  - consumed_samples: int
- nemo_rl.algorithms.dpo._default_dpo_save_state() -> nemo_rl.algorithms.dpo.DPOSaveState
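Because `DPOSaveState` is a `TypedDict`, a save state is an ordinary `dict` at runtime, which makes it easy to store next to a checkpoint. The sketch below mirrors the documented fields locally so it runs without NeMo RL installed. The zero values in `default_save_state` and the per-step bookkeeping in `advance` are hypothetical illustrations, not the values returned by `_default_dpo_save_state` or the logic NeMo RL uses.

```python
import json
from typing import TypedDict


class DPOSaveState(TypedDict):
    # Local mirror of the fields documented for nemo_rl.algorithms.dpo.DPOSaveState.
    epoch: int
    step: int
    total_steps: int
    val_loss: float
    consumed_samples: int


def default_save_state() -> DPOSaveState:
    # Hypothetical defaults: a run that has neither trained nor validated yet.
    return {
        "epoch": 0,
        "step": 0,
        "total_steps": 0,
        "val_loss": 0.0,
        "consumed_samples": 0,
    }


def advance(state: DPOSaveState, global_batch_size: int) -> DPOSaveState:
    # Hypothetical bookkeeping after one optimizer step; returns a new dict.
    return {
        **state,
        "step": state["step"] + 1,
        "total_steps": state["total_steps"] + 1,
        "consumed_samples": state["consumed_samples"] + global_batch_size,
    }


state = advance(default_save_state(), global_batch_size=32)
# TypedDicts do no runtime wrapping, so the state round-trips through JSON unchanged.
restored = json.loads(json.dumps(state))
print(restored["step"], restored["consumed_samples"])  # prints "1 32"
```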
- class nemo_rl.algorithms.dpo.DPOConfig
  Bases: typing.TypedDict
  - max_num_epochs: int
  - max_num_steps: int
  - val_period: int
  - val_batches: int
  - val_global_batch_size: int
  - val_micro_batch_size: int
  - val_at_start: bool
  - seed: int
  - reference_policy_kl_penalty: float
  - preference_average_log_probs: bool
  - sft_average_log_probs: bool
  - preference_loss_weight: float
  - sft_loss_weight: float
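The loss-related `DPOConfig` fields map naturally onto the standard DPO objective, -log σ(β[(log π(y_w) - log π_ref(y_w)) - (log π(y_l) - log π_ref(y_l))]), plus an optional SFT term on the chosen response. The pure-Python sketch below illustrates that reading under assumptions this page does not confirm: `reference_policy_kl_penalty` plays the role of β, the two `*_average_log_probs` flags switch between summed and per-token-averaged sequence log-probabilities, and the total is `preference_loss_weight * preference_loss + sft_loss_weight * sft_loss`. It scores a single preference pair given as lists of per-token log-probabilities and is not NeMo RL's loss function.

```python
import math
from typing import Sequence


def sequence_logprob(token_logprobs: Sequence[float], average: bool) -> float:
    # Sum the per-token log-probs, or take their mean when averaging is enabled.
    total = sum(token_logprobs)
    return total / len(token_logprobs) if average else total


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)) for both signs of x.
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen: Sequence[float],
    policy_rejected: Sequence[float],
    ref_chosen: Sequence[float],
    ref_rejected: Sequence[float],
    cfg: dict,
) -> float:
    avg_pref = cfg["preference_average_log_probs"]
    beta = cfg["reference_policy_kl_penalty"]  # assumed to act as the DPO beta
    # Implicit reward margin: how much more the policy prefers the chosen
    # response over the rejected one, relative to the reference policy.
    chosen_ratio = sequence_logprob(policy_chosen, avg_pref) - sequence_logprob(ref_chosen, avg_pref)
    rejected_ratio = sequence_logprob(policy_rejected, avg_pref) - sequence_logprob(ref_rejected, avg_pref)
    preference_loss = -log_sigmoid(beta * (chosen_ratio - rejected_ratio))
    # Auxiliary SFT term: negative log-likelihood of the chosen response.
    sft_loss = -sequence_logprob(policy_chosen, cfg["sft_average_log_probs"])
    return cfg["preference_loss_weight"] * preference_loss + cfg["sft_loss_weight"] * sft_loss


cfg = {
    "reference_policy_kl_penalty": 0.1,
    "preference_average_log_probs": False,
    "sft_average_log_probs": False,
    "preference_loss_weight": 1.0,
    "sft_loss_weight": 0.0,
}
same = [-1.0, -2.0]
# When policy and reference agree, the margin is 0 and the loss is log(2).
print(round(dpo_loss(same, same, same, same, cfg), 4))  # prints 0.6931
```

Setting `sft_loss_weight` above zero in this sketch adds a likelihood term that keeps pushing up the chosen response even when the preference margin is already large.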
- class nemo_rl.algorithms.dpo.MasterConfig
  Bases: typing.TypedDict
  - policy: nemo_rl.models.policy.PolicyConfig
  - data: nemo_rl.data.DataConfig
  - dpo: nemo_rl.algorithms.dpo.DPOConfig
  - logger: nemo_rl.utils.logger.LoggerConfig
  - cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig
  - checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig
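`MasterConfig` nests one config dict per component, with the `dpo` key holding a `DPOConfig`. The sketch below shows only that shape: the `dpo` section lists every documented `DPOConfig` field with illustrative values (assumptions, not NeMo RL defaults or recommendations), while the other sections are empty placeholders because their schemas are documented elsewhere. The helper `missing_dpo_keys` is hypothetical; since `TypedDict` performs no runtime validation, it compares the declared annotations against the keys actually present.

```python
from typing import TypedDict


class DPOConfig(TypedDict):
    # Local mirror of the fields documented for nemo_rl.algorithms.dpo.DPOConfig.
    max_num_epochs: int
    max_num_steps: int
    val_period: int
    val_batches: int
    val_global_batch_size: int
    val_micro_batch_size: int
    val_at_start: bool
    seed: int
    reference_policy_kl_penalty: float
    preference_average_log_probs: bool
    sft_average_log_probs: bool
    preference_loss_weight: float
    sft_loss_weight: float


# Illustrative values only; not NeMo RL defaults.
master_config = {
    "policy": {},  # PolicyConfig placeholder
    "data": {},  # DataConfig placeholder
    "dpo": {
        "max_num_epochs": 1,
        "max_num_steps": 100,
        "val_period": 25,
        "val_batches": 8,
        "val_global_batch_size": 32,
        "val_micro_batch_size": 1,
        "val_at_start": True,
        "seed": 42,
        "reference_policy_kl_penalty": 0.05,
        "preference_average_log_probs": False,
        "sft_average_log_probs": False,
        "preference_loss_weight": 1.0,
        "sft_loss_weight": 0.0,
    },
    "logger": {},  # LoggerConfig placeholder
    "cluster": {},  # ClusterConfig placeholder
    "checkpointing": {},  # CheckpointingConfig placeholder
}


def missing_dpo_keys(config: dict) -> list[str]:
    # Report DPOConfig fields that the "dpo" section does not define.
    return sorted(set(DPOConfig.__annotations__) - set(config.get("dpo", {})))


print(missing_dpo_keys(master_config))  # prints []
```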
- nemo_rl.algorithms.dpo.setup(
      master_config: nemo_rl.algorithms.dpo.MasterConfig,
      tokenizer: transformers.AutoTokenizer,
      train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
      val_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
  )
  Main entry point for running the DPO algorithm.
  Returns:
    Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger
- nemo_rl.algorithms.dpo.add_ref_logprobs_to_data(
      dataloader,
      policy,
      master_config,
      is_val=False,
  )
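No docstring is given for `add_ref_logprobs_to_data`. Judging only from its name and parameters, it plausibly wraps a dataloader and attaches reference-policy log-probabilities to each batch before the DPO loss consumes them. The sketch below illustrates that wrapping pattern as a generator. `StubReferencePolicy`, its `reference_logprobs` method, the `"reference_policy_logprobs"` batch key, and the use of `is_val` to choose a micro-batch size are all hypothetical, not NeMo RL's API.

```python
from typing import Iterable, Iterator


class StubReferencePolicy:
    # Hypothetical stand-in for a frozen reference model.
    def reference_logprobs(self, batch: dict, micro_batch_size: int) -> list[list[float]]:
        # Pretend every token has log-probability -1.0; micro_batch_size is unused here.
        return [[-1.0] * len(seq) for seq in batch["input_ids"]]


def add_ref_logprobs_sketch(
    dataloader: Iterable[dict],
    policy: StubReferencePolicy,
    master_config: dict,
    is_val: bool = False,
) -> Iterator[dict]:
    # Hypothetical: use the DPO validation micro-batch size during validation,
    # and a placeholder size of 1 otherwise.
    mbs = master_config["dpo"]["val_micro_batch_size"] if is_val else 1
    for batch in dataloader:
        # Attach reference log-probs lazily, one batch at a time (mutates the batch).
        batch["reference_policy_logprobs"] = policy.reference_logprobs(batch, mbs)
        yield batch


batches = [{"input_ids": [[1, 2, 3], [4, 5]]}]
config = {"dpo": {"val_micro_batch_size": 2}}
out = list(add_ref_logprobs_sketch(batches, StubReferencePolicy(), config, is_val=True))
print(out[0]["reference_policy_logprobs"])  # prints [[-1.0, -1.0, -1.0], [-1.0, -1.0]]
```

A generator keeps reference forward passes interleaved with consumption, so log-probabilities are never precomputed or held in memory for the whole dataset at once.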
- nemo_rl.algorithms.dpo.validate(
      policy: nemo_rl.models.interfaces.PolicyInterface,
      val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
      tokenizer,
      loss_fn,
      step: int,
      master_config: nemo_rl.algorithms.dpo.MasterConfig,
      val_batches: int,
      val_batch_size: int,
      val_mbs: int,
  )
  Run validation on the validation dataset.
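The `val_batches` parameter suggests that `validate` evaluates a capped number of batches rather than the full validation set. The sketch below shows only that capped, averaged loop using `itertools.islice`. The `loss_fn(batch) -> float` call shape and the returned metrics dict are assumptions; the policy, tokenizer, and batch-size arguments of the real function are omitted. A mean loss like this is the kind of value a field such as `DPOSaveState.val_loss` could record.

```python
from itertools import islice
from typing import Callable, Iterable


def validate_sketch(
    val_dataloader: Iterable[dict],
    loss_fn: Callable[[dict], float],
    step: int,
    val_batches: int,
) -> dict:
    # Evaluate at most `val_batches` batches and report their mean loss.
    losses = [loss_fn(batch) for batch in islice(val_dataloader, val_batches)]
    if not losses:
        raise ValueError("validation dataloader produced no batches")
    return {"step": step, "val_loss": sum(losses) / len(losses), "num_batches": len(losses)}


batches = [{"loss": x} for x in (1.0, 0.5, 0.25, 0.125)]
metrics = validate_sketch(batches, lambda b: b["loss"], step=10, val_batches=2)
print(metrics)  # prints {'step': 10, 'val_loss': 0.75, 'num_batches': 2}
```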