nemo_rl.algorithms.sft#

Module Contents#

Classes#

SFTSaveState

SFTConfig

MasterConfig

Functions#

_default_sft_save_state

setup

Set up the components needed to run the SFT algorithm.

validate

Run validation on the validation dataset.

sft_train

API#

class nemo_rl.algorithms.sft.SFTSaveState[source]#

Bases: typing.TypedDict

epoch: int#

None

step: int#

None

total_steps: int#

None

val_loss: float#

None

consumed_samples: int#

None

nemo_rl.algorithms.sft._default_sft_save_state() → nemo_rl.algorithms.sft.SFTSaveState[source]#

class nemo_rl.algorithms.sft.SFTConfig[source]#

Bases: typing.TypedDict

max_num_steps: int#

None

max_num_epochs: int#

None

val_period: int#

None

val_batches: int#

None

val_global_batch_size: int#

None

val_micro_batch_size: int#

None

val_at_start: bool#

None

seed: int#

None

class nemo_rl.algorithms.sft.MasterConfig[source]#

Bases: typing.TypedDict

policy: nemo_rl.models.policy.PolicyConfig#

None

data: nemo_rl.data.DataConfig#

None

sft: nemo_rl.algorithms.sft.SFTConfig#

None

logger: nemo_rl.utils.logger.LoggerConfig#

None

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#

None

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#

None

nemo_rl.algorithms.sft.setup(
master_config: nemo_rl.algorithms.sft.MasterConfig,
tokenizer: transformers.AutoTokenizer,
train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
val_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
) → Tuple[nemo_rl.models.policy.hf_policy.HfPolicy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, torchdata.stateful_dataloader.StatefulDataLoader, nemo_rl.algorithms.loss_functions.NLLLoss, nemo_rl.algorithms.sft.MasterConfig, nemo_rl.utils.logger.Logger, nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.algorithms.sft.SFTSaveState][source]#

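Set up the components needed to run the SFT algorithm: the policy, the Ray virtual cluster, the training and validation dataloaders, the NLL loss function, the logger, the task spec, and the SFT save state.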

Returns:

Tuple of (policy, cluster, train_dataloader, val_dataloader, loss_fn, master_config, logger, sft_task_spec, sft_save_state), in that order.

nemo_rl.algorithms.sft.validate(
policy: nemo_rl.models.interfaces.PolicyInterface,
val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
tokenizer,
loss_fn,
step: int,
master_config: nemo_rl.algorithms.sft.MasterConfig,
sft_task_spec: nemo_rl.data.interfaces.TaskDataSpec,
val_batches: int,
val_batch_size: int,
val_mbs: int,
)[source]#

Run validation on the validation dataset.

nemo_rl.algorithms.sft.sft_train(
policy,
train_dataloader,
val_dataloader,
tokenizer,
loss_fn,
master_config,
logger,
sft_task_spec,
checkpointer,
sft_save_state,
)[source]#