nat.finetuning.interfaces.finetuning_runner#
Attributes#
Classes#
| Trainer | Abstract interface for running finetuning workflows. |
Module Contents#
- logger#
- class Trainer(trainer_config: nat.data_models.finetuning.TrainerConfig, **kwargs)#
Bases: abc.ABC

Abstract interface for running finetuning workflows.
The Trainer orchestrates the entire finetuning process by:

1. Running evaluations to generate trajectories via TrajectoryBuilder
2. Submitting trajectories for training via TrainerAdapter
3. Managing multiple epochs of training
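The shape of that loop, as a minimal sketch of a hypothetical subclass (not the shipped implementation). It assumes TrainingJobStatus lives in nat.data_models.finetuning alongside TrainingJobRef, and the `_job_status` helper and the way a run_id is chosen are made up for illustration:

```python
# Minimal sketch of the orchestration loop. Assumes the subclass also
# implements the other abstract hooks; run id generation and job-status
# polling are backend-specific, so `_job_status` is a hypothetical helper.
from nat.data_models.finetuning import TrainingJobStatus
from nat.finetuning.interfaces.finetuning_runner import Trainer


class EpochLoopTrainer(Trainer):

    async def run(self, num_epochs: int) -> list[TrainingJobStatus]:
        run_id = "example-run"  # assumption: run id generation is backend-specific
        statuses: list[TrainingJobStatus] = []
        for epoch in range(num_epochs):
            # 1.-2. run_epoch builds trajectories via the TrajectoryBuilder and
            #       submits them for training via the TrainerAdapter.
            job_ref = await self.run_epoch(epoch, run_id)
            # 3. Track progress across epochs.
            metrics = await self.get_metrics(run_id)
            self.log_progress(epoch, metrics)
            statuses.append(self._job_status(job_ref))  # hypothetical helper
        return statuses
```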
Initialize the Trainer.
- Args:
  - trainer_config: Configuration for the trainer backend
  - run_config: Configuration for the training run
  - backend: Backend identifier
  - curriculum_config: Optional curriculum learning configuration
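Assuming run_config, backend, and curriculum_config are forwarded through **kwargs, constructing a concrete subclass (for example the EpochLoopTrainer sketch above, once its remaining abstract methods are implemented) might look like this; the config values are placeholders:

```python
# Illustrative construction only; the extra keyword arguments documented above
# are assumed to travel through **kwargs, and the config values are placeholders.
from nat.data_models.finetuning import FinetuneConfig, TrainerConfig

run_config = FinetuneConfig(...)         # settings for this finetuning run
trainer = EpochLoopTrainer(
    trainer_config=TrainerConfig(...),   # backend-specific trainer settings
    run_config=run_config,
    backend="example-backend",           # assumed backend identifier
    curriculum_config=None,              # optional curriculum learning configuration
)
```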
- trainer_config#
- run_config: nat.data_models.finetuning.FinetuneConfig = None#
- curriculum_config = None#
- trajectory_builder: nat.finetuning.interfaces.trajectory_builder.TrajectoryBuilder = None#
- trainer_adapter: nat.finetuning.interfaces.trainer_adapter.TrainerAdapter = None#
- _curriculum_state = None#
- async bind_components(trajectory_builder: nat.finetuning.interfaces.trajectory_builder.TrajectoryBuilder, trainer_adapter: nat.finetuning.interfaces.trainer_adapter.TrainerAdapter)#
Bind the TrajectoryBuilder and TrainerAdapter components.
- Args:
  - trajectory_builder: Instance of TrajectoryBuilder
  - trainer_adapter: Instance of TrainerAdapter
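For instance (the builder and adapter constructors below stand in for whatever concrete implementations the backend provides):

```python
# Placeholder construction of the two components; the concrete classes are
# backend-specific and not part of this interface. Run inside an async context.
builder = MyTrajectoryBuilder(...)   # hypothetical TrajectoryBuilder implementation
adapter = MyTrainerAdapter(...)      # hypothetical TrainerAdapter implementation
await trainer.bind_components(trajectory_builder=builder, trainer_adapter=adapter)
```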
- async initialize(run_config: nat.data_models.finetuning.FinetuneConfig)#
Initialize the runner and its components.
This should:

- Initialize the TrajectoryBuilder
- Initialize the TrainerAdapter
- Verify connectivity to backend services
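A sketch of an override that performs those three steps, building on the EpochLoopTrainer sketch above; the `start()` and `verify_connection()` calls on the components are invented names, not part of the documented interfaces:

```python
# Hypothetical initialize override; `start()` and `verify_connection()` are
# placeholders for the real component setup calls.
from nat.data_models.finetuning import FinetuneConfig


class ExampleTrainer(EpochLoopTrainer):

    async def initialize(self, run_config: FinetuneConfig) -> None:
        await super().initialize(run_config)              # base bookkeeping (assumed to store run_config)
        await self.trajectory_builder.start(run_config)   # initialize the TrajectoryBuilder
        await self.trainer_adapter.start(run_config)      # initialize the TrainerAdapter
        await self.trainer_adapter.verify_connection()    # verify backend connectivity
```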
- abstractmethod run_epoch(epoch: int, run_id: str) → nat.data_models.finetuning.TrainingJobRef#
- Async:
Run a single epoch of training.
- Args:
  - epoch: The current epoch number (0-indexed)
  - run_id: Unique identifier for this training run
- Returns:
TrainingJobRef: Reference to the submitted training job
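One possible shape for the method body, shown without its enclosing subclass; `build()` and `submit()` are placeholder names for the real TrajectoryBuilder and TrainerAdapter calls:

```python
# Possible run_epoch body (would live on a Trainer subclass); `build()` and
# `submit()` are placeholders for the actual component methods.
from nat.data_models.finetuning import TrainingJobRef


async def run_epoch(self, epoch: int, run_id: str) -> TrainingJobRef:
    # Run evaluations to generate trajectories for this epoch.
    trajectories = await self.trajectory_builder.build(epoch=epoch, run_id=run_id)  # placeholder
    # Optionally narrow the trajectory groups with curriculum learning.
    if self.curriculum_config is not None:
        trajectories = self.apply_curriculum_learning(trajectories, epoch)
    # Submit to the training backend and return a reference to the job.
    return await self.trainer_adapter.submit(trajectories, run_id=run_id)  # placeholder
```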
- abstractmethod run(num_epochs: int) → list[TrainingJobStatus]#
- Async:
Run the complete finetuning workflow for the specified number of epochs.
- Args:
num_epochs: Number of epochs to train
- Returns:
list[TrainingJobStatus]: Status of all training jobs
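Putting the pieces together, a driver might look like this, reusing the trainer, builder, adapter, and run_config objects from the earlier sketches:

```python
# End-to-end driver, reusing the objects defined in the sketches above.
import asyncio


async def main() -> None:
    await trainer.bind_components(trajectory_builder=builder, trainer_adapter=adapter)
    await trainer.initialize(run_config)
    statuses = await trainer.run(num_epochs=3)
    for status in statuses:
        print(status)


asyncio.run(main())
```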
- abstractmethod get_metrics(run_id: str) → dict[str, Any]#
- Async:
Get training metrics for a specific run.
- Args:
run_id: The run identifier
- Returns:
dict: Metrics from the training run
- abstractmethod log_progress(epoch: int, metrics: dict[str, Any], output_dir: str | None = None) → None#
Log training progress for monitoring.
- Args:
  - epoch: Current epoch number
  - metrics: Dictionary of metrics to log
  - output_dir: Optional output directory override
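For example, an implementation might append one JSON line per epoch; the default directory and file name here are assumptions, not toolkit behaviour:

```python
# Illustrative log_progress implementation: one JSON line per epoch.
# The default directory and file name are assumptions.
import json
from pathlib import Path
from typing import Any


def log_progress(self, epoch: int, metrics: dict[str, Any], output_dir: str | None = None) -> None:
    out_dir = Path(output_dir) if output_dir else Path("finetune_logs")
    out_dir.mkdir(parents=True, exist_ok=True)
    with (out_dir / "progress.jsonl").open("a", encoding="utf-8") as fp:
        fp.write(json.dumps({"epoch": epoch, **metrics}) + "\n")
```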
- async run_validation_evaluation(epoch: int, run_id: str) → dict[str, Any]#
Run evaluation on validation dataset to collect rewards.
This method creates a temporary TrainerRunConfig with the validation dataset and runs evaluation to collect rewards without training.
- Args:
  - epoch: Current epoch number
  - run_id: Unique identifier for this training run
  - validation_dataset: Path to the validation dataset
- Returns:
dict: Validation metrics including average reward
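Inside a subclass's epoch loop this might be used to track the best validation reward so far; the "average_reward" key is an assumption about the returned dictionary:

```python
# Inside a run() / run_epoch() implementation; "average_reward" is an assumed
# key name for the reward statistic in the returned dict.
val_metrics = await self.run_validation_evaluation(epoch=epoch, run_id=run_id)
best_reward = max(best_reward, val_metrics.get("average_reward", float("-inf")))
self.log_progress(epoch, val_metrics)
```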
- _calculate_validation_metrics(eval_output: nat.eval.config.EvaluationRunOutput) → dict[str, Any]#
Calculate validation metrics from evaluation output.
- Args:
eval_output: Output from evaluation run
- Returns:
dict: Calculated metrics
- abstractmethod apply_curriculum_learning(trajectory_collection: nat.data_models.finetuning.TrajectoryCollection, epoch: int)#
Apply curriculum learning to filter trajectory groups based on difficulty.
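A heavily hedged sketch of one possible policy, assuming the method returns a filtered TrajectoryCollection and that the collection is a pydantic-style model exposing reward-scored groups; the `groups` and `average_reward` attributes and the epoch schedule are invented for illustration:

```python
# One possible curriculum policy: start with easier (higher-reward) trajectory
# groups and admit harder ones as epochs advance. The `groups` and
# `average_reward` attributes, and the model_copy call, are assumptions about
# TrajectoryCollection, not documented behaviour.
from nat.data_models.finetuning import TrajectoryCollection


def apply_curriculum_learning(
    self,
    trajectory_collection: TrajectoryCollection,
    epoch: int,
) -> TrajectoryCollection:
    threshold = max(0.0, 0.8 - 0.1 * epoch)  # example schedule: lower the bar each epoch
    kept = [g for g in trajectory_collection.groups if g.average_reward >= threshold]
    return trajectory_collection.model_copy(update={"groups": kept})
```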