nemo_rl.algorithms.rm

Module Contents

Classes

Functions

_default_rm_save_state

setup

Set up the components needed to run the RM algorithm.

validate

Run validation on the validation dataset.

rm_train

API

class nemo_rl.algorithms.rm.RMSaveState

Bases: typing.TypedDict

epoch: int#

None

step: int#

None

total_steps: int#

None

val_loss: float#

None

consumed_samples: int#

None

nemo_rl.algorithms.rm._default_rm_save_state() → nemo_rl.algorithms.rm.RMSaveState
class nemo_rl.algorithms.rm.RMConfig

Bases: typing.TypedDict

max_num_steps: int#

None

max_num_epochs: int#

None

val_period: int#

None

val_batches: int#

None

val_global_batch_size: int#

None

val_micro_batch_size: int#

None

val_at_start: bool#

None

seed: int#

None

class nemo_rl.algorithms.rm.MasterConfig

Bases: typing.TypedDict

policy: nemo_rl.models.policy.PolicyConfig#

None

data: nemo_rl.data.DataConfig#

None

rm: nemo_rl.algorithms.rm.RMConfig#

None

logger: nemo_rl.utils.logger.LoggerConfig#

None

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#

None

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#

None

class nemo_rl.algorithms.rm.RMValMetrics

Bases: typing.TypedDict

val_loss: float#

None

accuracy: float#

None

rewards_chosen_mean: float#

None

rewards_rejected_mean: float#

None

num_valid_samples: float#

None

nemo_rl.algorithms.rm.setup(
master_config: nemo_rl.algorithms.rm.MasterConfig,
tokenizer: transformers.AutoTokenizer,
train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
val_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
) → tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, torchdata.stateful_dataloader.StatefulDataLoader, nemo_rl.algorithms.loss_functions.PreferenceLoss, nemo_rl.algorithms.rm.MasterConfig, nemo_rl.utils.logger.Logger, nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.algorithms.rm.RMSaveState]

Set up the components needed to run the RM algorithm.

Returns:

Tuple of (policy, cluster, train_dataloader, val_dataloader, loss_fn, master_config, logger, rm_task_spec, rm_save_state).

nemo_rl.algorithms.rm.validate(
policy: nemo_rl.models.policy.interfaces.PolicyInterface,
val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
tokenizer,
loss_fn,
step: int,
master_config: nemo_rl.algorithms.rm.MasterConfig,
rm_task_spec: nemo_rl.data.interfaces.TaskDataSpec,
val_batches: int,
val_batch_size: int,
val_mbs: int,
)

Run validation on the validation dataset.

nemo_rl.algorithms.rm.rm_train(
policy,
train_dataloader,
val_dataloader,
tokenizer,
loss_fn,
master_config,
logger,
rm_task_spec,
checkpointer,
rm_save_state,
)