nemo_rl.algorithms.dpo#

Module Contents#

Classes#

DPOSaveState

DPOConfig

MasterConfig

DPOValMetrics

Functions#

_default_dpo_save_state

setup

Main entry point for setting up the DPO algorithm.

add_ref_logprobs_to_data

validate

validate_one_dataset

Run validation on one validation dataset.

dpo_train
API#

class nemo_rl.algorithms.dpo.DPOSaveState#

Bases: typing.TypedDict

epoch: int#

step: int#

total_steps: int#

consumed_samples: int#

total_valid_tokens: int#

nemo_rl.algorithms.dpo._default_dpo_save_state() -> nemo_rl.algorithms.dpo.DPOSaveState#
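
A minimal sketch of constructing this state by hand, using only the keys listed above; a fresh state with zeroed counters is presumably what _default_dpo_save_state() returns:

    from nemo_rl.algorithms.dpo import DPOSaveState

    # Illustrative only: a brand-new save state with all counters at zero.
    save_state: DPOSaveState = {
        "epoch": 0,
        "step": 0,
        "total_steps": 0,
        "consumed_samples": 0,
        "total_valid_tokens": 0,
    }
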
class nemo_rl.algorithms.dpo.DPOConfig#

Bases: typing.TypedDict

max_num_epochs: int#

max_num_steps: int#

val_period: int#

val_batches: int#

val_global_batch_size: int#

val_micro_batch_size: int#

val_at_start: bool#

seed: int#

reference_policy_kl_penalty: float#

preference_average_log_probs: bool#

sft_average_log_probs: bool#

preference_loss_weight: float#

sft_loss_weight: float#
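
An illustrative configuration with these keys; every value below is a placeholder chosen for the example, not a recommended or default setting:

    from nemo_rl.algorithms.dpo import DPOConfig

    # Illustrative values only -- not defaults.
    dpo_config: DPOConfig = {
        "max_num_epochs": 1,
        "max_num_steps": 1000,
        "val_period": 100,              # validate every 100 steps
        "val_batches": 8,
        "val_global_batch_size": 32,
        "val_micro_batch_size": 4,
        "val_at_start": True,
        "seed": 42,
        "reference_policy_kl_penalty": 0.05,
        "preference_average_log_probs": False,
        "sft_average_log_probs": False,
        "preference_loss_weight": 1.0,
        "sft_loss_weight": 0.0,         # > 0 mixes in an auxiliary SFT loss
    }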

class nemo_rl.algorithms.dpo.MasterConfig#

Bases: typing.TypedDict

policy: nemo_rl.models.policy.PolicyConfig#

data: nemo_rl.data.DataConfig#

dpo: nemo_rl.algorithms.dpo.DPOConfig#

logger: nemo_rl.utils.logger.LoggerConfig#

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#
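
The master config simply nests one config per component; a schematic sketch, where each sub-config variable is a hypothetical name assumed to be built elsewhere:

    from nemo_rl.algorithms.dpo import MasterConfig

    # Schematic only: each sub-config is its own TypedDict, assumed
    # to have been constructed elsewhere (e.g. parsed from YAML).
    master_config: MasterConfig = {
        "policy": policy_config,                  # PolicyConfig
        "data": data_config,                      # DataConfig
        "dpo": dpo_config,                        # DPOConfig, as sketched above
        "logger": logger_config,                  # LoggerConfig
        "cluster": cluster_config,                # ClusterConfig
        "checkpointing": checkpointing_config,    # CheckpointingConfig
    }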

class nemo_rl.algorithms.dpo.DPOValMetrics#

Bases: typing.TypedDict

loss: float#

sft_loss: float#

preference_loss: float#

accuracy: float#

rewards_chosen_mean: float#

rewards_rejected_mean: float#

num_valid_samples: float#

global_valid_seqs: float#

global_valid_toks: float#
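
A small sketch of consuming these metrics after a validation pass; the reward margin derived below is an example computation, not a field of the type:

    from nemo_rl.algorithms.dpo import DPOValMetrics

    def summarize(metrics: DPOValMetrics) -> str:
        # Reward margin: how far apart chosen vs. rejected responses score.
        margin = metrics["rewards_chosen_mean"] - metrics["rewards_rejected_mean"]
        return (
            f"loss={metrics['loss']:.4f} "
            f"accuracy={metrics['accuracy']:.3f} "
            f"reward_margin={margin:.4f}"
        )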

nemo_rl.algorithms.dpo.setup(
master_config: nemo_rl.algorithms.dpo.MasterConfig,
tokenizer: transformers.AutoTokenizer,
train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
val_dataset: dict[str, nemo_rl.data.datasets.AllTaskProcessedDataset],
) -> tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, dict[str, torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.DPOLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.dpo.DPOSaveState, nemo_rl.algorithms.dpo.MasterConfig]#

Main entry point for setting up the DPO algorithm.

Returns:

Tuple of (policy, cluster, train_dataloader, val_dataloaders, loss_fn, logger, checkpointer, dpo_save_state, master_config), matching the return annotation above.
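
A usage sketch, unpacking the nine-element tuple in the order of the return annotation; construction of the arguments is elided:

    from nemo_rl.algorithms.dpo import setup

    (
        policy,
        cluster,
        train_dataloader,
        val_dataloaders,
        loss_fn,
        logger,
        checkpointer,
        dpo_save_state,
        master_config,
    ) = setup(master_config, tokenizer, train_dataset, val_dataset)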

nemo_rl.algorithms.dpo.add_ref_logprobs_to_data(
dataloader,
policy,
master_config,
is_val=False,
)#
nemo_rl.algorithms.dpo.validate(
policy: nemo_rl.models.policy.interfaces.PolicyInterface,
val_dataloader: dict[str, torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer,
loss_fn,
step: int,
master_config: nemo_rl.algorithms.dpo.MasterConfig,
val_batches: int,
val_batch_size: int,
val_mbs: int,
logger: nemo_rl.utils.logger.Logger,
)#
nemo_rl.algorithms.dpo.validate_one_dataset(
policy: nemo_rl.models.policy.interfaces.PolicyInterface,
val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
loss_fn,
step: int,
master_config: nemo_rl.algorithms.dpo.MasterConfig,
val_batches: int,
val_batch_size: int,
val_mbs: int,
dataset_name: str,
)#

Run validation on one validation dataset.

nemo_rl.algorithms.dpo.dpo_train(
policy,
train_dataloader,
val_dataloader,
tokenizer,
loss_fn,
master_config,
logger,
checkpointer,
dpo_save_state: nemo_rl.algorithms.dpo.DPOSaveState,
) -> None#
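
Putting the pieces together, a schematic end-to-end flow under the assumptions above; in particular, it assumes add_ref_logprobs_to_data returns a dataloader-like object with reference-policy log-probabilities attached (DPO needs per-token log-probs under the frozen reference policy to compute its loss):

    from nemo_rl.algorithms.dpo import (
        add_ref_logprobs_to_data,
        dpo_train,
        setup,
    )

    # Schematic sketch -- argument construction elided, names assumed.
    (policy, cluster, train_dataloader, val_dataloaders, loss_fn,
     logger, checkpointer, dpo_save_state, master_config) = setup(
        master_config, tokenizer, train_dataset, val_datasets
    )

    # Assumption: this helper precomputes reference log-probs and
    # returns the augmented dataloaders.
    train_dataloader = add_ref_logprobs_to_data(
        train_dataloader, policy, master_config
    )
    val_dataloaders = {
        name: add_ref_logprobs_to_data(dl, policy, master_config, is_val=True)
        for name, dl in val_dataloaders.items()
    }

    dpo_train(
        policy,
        train_dataloader,
        val_dataloaders,
        tokenizer,
        loss_fn,
        master_config,
        logger,
        checkpointer,
        dpo_save_state,
    )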