nemo_rl.algorithms.sft

Module Contents

Classes
| SFTSaveState | |
| SFTConfig | |
| MasterConfig | |

Functions
| setup | Main entry point for running the SFT algorithm. |
| validate | Run validation on the validation dataset. |
API
- class nemo_rl.algorithms.sft.SFTSaveState [source]
  Bases: typing.TypedDict
  - epoch: int
    None
  - step: int
    None
  - total_steps: int
    None
  - val_loss: float
    None
  - consumed_samples: int
    None
- nemo_rl.algorithms.sft._default_sft_save_state() → nemo_rl.algorithms.sft.SFTSaveState [source]
- class nemo_rl.algorithms.sft.SFTConfig [source]
  Bases: typing.TypedDict
  - max_num_steps: int
    None
  - max_num_epochs: int
    None
  - val_period: int
    None
  - val_batches: int
    None
  - val_global_batch_size: int
    None
  - val_micro_batch_size: int
    None
  - val_at_start: bool
    None
  - seed: int
    None
- class nemo_rl.algorithms.sft.MasterConfig [source]
  Bases: typing.TypedDict
  - policy: nemo_rl.models.policy.PolicyConfig
    None
  - data: nemo_rl.data.DataConfig
    None
  - sft: nemo_rl.algorithms.sft.SFTConfig
    None
  - logger: nemo_rl.utils.logger.LoggerConfig
    None
  - cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig
    None
  - checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig
    None
- nemo_rl.algorithms.sft.setup(
      master_config: nemo_rl.algorithms.sft.MasterConfig,
      tokenizer: transformers.AutoTokenizer,
      train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
      val_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
  )
  Main entry point for running the SFT algorithm.
  - Returns:
    Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger
- nemo_rl.algorithms.sft.validate(
      policy: nemo_rl.models.interfaces.PolicyInterface,
      val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
      tokenizer,
      loss_fn,
      step: int,
      master_config: nemo_rl.algorithms.sft.MasterConfig,
      sft_task_spec: nemo_rl.data.interfaces.TaskDataSpec,
      val_batches: int,
      val_batch_size: int,
      val_mbs: int,
  )
  Run validation on the validation dataset.