nat.plugins.customizer.dpo.trainer#
NeMo Customizer Trainer for DPO finetuning.
This module provides a Trainer implementation that orchestrates data collection via trajectory builders and submits training jobs to NeMo Customizer.
Attributes#
Classes#
NeMoCustomizerTrainer | Trainer for NeMo Customizer DPO/SFT finetuning.
Module Contents#
- logger#
- class NeMoCustomizerTrainer(trainer_config: nat.plugins.customizer.dpo.config.NeMoCustomizerTrainerConfig, **kwargs)#
Bases: nat.finetuning.interfaces.finetuning_runner.Trainer
Trainer for NeMo Customizer DPO/SFT finetuning.
Unlike epoch-based trainers, this trainer:
1. Runs the trajectory builder multiple times (num_runs) to collect data
2. Aggregates all trajectories into a single dataset
3. Submits the dataset to NeMo Customizer for training
4. Monitors the training job until completion
The actual training epochs are handled by NeMo Customizer via hyperparameters.
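The four steps above can be sketched as a small asyncio program. This is a minimal illustration only: `SketchTrainer`, `collect_once`, `submit`, and `wait_until_done` are hypothetical stand-ins invented for this example, not part of the `nat` package or the NeMo Customizer API, and strings stand in for real trajectories.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class SketchTrainer:
    """Hypothetical stand-in; NOT the real NeMoCustomizerTrainer."""

    num_runs: int
    collected: list[list[str]] = field(default_factory=list)

    async def collect_once(self, run_index: int) -> list[str]:
        # Placeholder for one trajectory-builder run (assumed behavior).
        await asyncio.sleep(0)
        return [f"run{run_index}-traj{i}" for i in range(2)]

    async def submit(self, dataset: list[str]) -> str:
        # Placeholder for uploading the dataset and creating a training job.
        await asyncio.sleep(0)
        return f"job-with-{len(dataset)}-examples"

    async def wait_until_done(self, job_id: str) -> str:
        # Placeholder for polling the job until it reaches a terminal state.
        await asyncio.sleep(0)
        return "completed"

    async def run(self) -> tuple[str, str]:
        # 1. Run data collection num_runs times, accumulating each batch.
        for run_index in range(self.num_runs):
            self.collected.append(await self.collect_once(run_index))
        # 2. Aggregate every run into a single flat dataset.
        dataset = [traj for batch in self.collected for traj in batch]
        # 3. Submit the dataset once, then 4. monitor until completion.
        job_id = await self.submit(dataset)
        return job_id, await self.wait_until_done(job_id)


job_id, status = asyncio.run(SketchTrainer(num_runs=3).run())
```

The point of the shape is that submission happens exactly once, after all collection runs finish, which is why the per-run method has nothing to return.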
Initialize the NeMo Customizer Trainer.
- Args:
trainer_config: Configuration for the trainer
- trainer_config: nat.plugins.customizer.dpo.config.NeMoCustomizerTrainerConfig#
- _job_ref: nat.data_models.finetuning.TrainingJobRef | None = None#
- _all_trajectories: list[list[nat.data_models.finetuning.Trajectory]] = []#
- async initialize(run_config: nat.data_models.finetuning.FinetuneConfig)#
Initialize the trainer and its components.
Note: Curriculum learning is not supported for DPO training.
- async run_epoch(epoch, run_id) -> nat.data_models.finetuning.TrainingJobRef | None#
Run a single data collection run.
For NeMo Customizer, this collects trajectories without submitting to training. The actual submission happens in run().
- Args:
epoch: The current run number (0-indexed)
run_id: Unique identifier for this training run
- Returns:
None (trajectories are accumulated, not submitted per-run)
- async run(num_epochs: int)#
Run the complete DPO data collection and training workflow.
- Args:
num_epochs: Ignored for NeMo Customizer (uses trainer_config.num_runs)
- Returns:
list[TrainingJobStatus]: Status of the training job
- _deduplicate_trajectories(collection: nat.data_models.finetuning.TrajectoryCollection)#
Remove duplicate DPO pairs based on prompt+responses.
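Deduplication on prompt plus responses could look roughly like the sketch below. It assumes each DPO pair can be reduced to a (prompt, chosen, rejected) triple; `PreferencePair` and `deduplicate_pairs` are hypothetical names for this example, and the real `Trajectory` and `TrajectoryCollection` models may be structured differently.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PreferencePair:
    """Hypothetical DPO record; the real Trajectory model may differ."""

    prompt: str
    chosen: str
    rejected: str


def deduplicate_pairs(pairs: list[PreferencePair]) -> list[PreferencePair]:
    """Keep the first occurrence of each (prompt, chosen, rejected) triple."""
    seen: set[tuple[str, str, str]] = set()
    unique: list[PreferencePair] = []
    for pair in pairs:
        key = (pair.prompt, pair.chosen, pair.rejected)
        if key not in seen:
            seen.add(key)
            unique.append(pair)  # order of first appearance is preserved
    return unique


pairs = [
    PreferencePair("2+2?", "4", "5"),
    PreferencePair("2+2?", "4", "5"),   # exact duplicate, dropped
    PreferencePair("2+2?", "4", "22"),  # same prompt, different response, kept
]
unique_pairs = deduplicate_pairs(pairs)
```

Keying on the full triple rather than the prompt alone keeps distinct preference signals for the same prompt, which matters when several runs revisit the same inputs.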
- _sample_trajectories(collection: nat.data_models.finetuning.TrajectoryCollection, max_pairs: int)#
Sample trajectories to limit dataset size.
- _log_final_metrics(final_status: nat.data_models.finetuning.TrainingJobStatus)#
Log final training metrics.