nemo_rl.algorithms.dpo#

Module Contents#

Classes#

Functions#

_default_dpo_save_state

setup

Main entry point for running DPO algorithm.

add_ref_logprobs_to_data

validate

Run validation on the validation dataset.

dpo_train

API#

class nemo_rl.algorithms.dpo.DPOSaveState[source]#

Bases: typing.TypedDict

epoch: int#

None

step: int#

None

total_steps: int#

None

val_loss: float#

None

consumed_samples: int#

None

nemo_rl.algorithms.dpo._default_dpo_save_state() → nemo_rl.algorithms.dpo.DPOSaveState[source]#
class nemo_rl.algorithms.dpo.DPOConfig[source]#

Bases: typing.TypedDict

max_num_epochs: int#

None

max_num_steps: int#

None

val_period: int#

None

val_batches: int#

None

val_global_batch_size: int#

None

val_micro_batch_size: int#

None

val_at_start: bool#

None

seed: int#

None

reference_policy_kl_penalty: float#

None

preference_average_log_probs: bool#

None

sft_average_log_probs: bool#

None

preference_loss_weight: float#

None

sft_loss_weight: float#

None

class nemo_rl.algorithms.dpo.MasterConfig[source]#

Bases: typing.TypedDict

policy: nemo_rl.models.policy.PolicyConfig#

None

data: nemo_rl.data.DataConfig#

None

dpo: nemo_rl.algorithms.dpo.DPOConfig#

None

logger: nemo_rl.utils.logger.LoggerConfig#

None

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#

None

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#

None

nemo_rl.algorithms.dpo.setup(
master_config: nemo_rl.algorithms.dpo.MasterConfig,
tokenizer: transformers.AutoTokenizer,
train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
val_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
) → Tuple[nemo_rl.models.policy.hf_policy.HfPolicy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, torchdata.stateful_dataloader.StatefulDataLoader, nemo_rl.algorithms.loss_functions.DPOLossFn, nemo_rl.algorithms.dpo.MasterConfig, nemo_rl.utils.logger.Logger, nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.algorithms.dpo.DPOSaveState][source]#

Main entry point for running DPO algorithm.

Returns:

Tuple of policy, cluster, train_dataloader, val_dataloader, loss_fn, master_config, logger, task_spec, dpo_save_state

nemo_rl.algorithms.dpo.add_ref_logprobs_to_data(
dataloader,
policy,
master_config,
is_val=False,
)[source]#
nemo_rl.algorithms.dpo.validate(
policy: nemo_rl.models.interfaces.PolicyInterface,
val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
tokenizer,
loss_fn,
step: int,
master_config: nemo_rl.algorithms.dpo.MasterConfig,
val_batches: int,
val_batch_size: int,
val_mbs: int,
)[source]#

Run validation on the validation dataset.

nemo_rl.algorithms.dpo.dpo_train(
policy,
train_dataloader,
val_dataloader,
tokenizer,
loss_fn,
master_config,
logger,
checkpointer,
dpo_save_state,
)[source]#