nat.finetuning#
Submodules#
Classes#
Trainer | Abstract interface for running finetuning workflows.
TrainerAdapter | Adapter to send Trajectories to a remote training cluster for weight updates.
TrajectoryBuilder | Abstract interface for building trajectories from episode items.
Package Contents#
- class Trainer(trainer_config: nat.data_models.finetuning.TrainerConfig, **kwargs)
Bases:
abc.ABC
Abstract interface for running finetuning workflows.
The Trainer orchestrates the entire finetuning process by:
1. Running evaluations to generate trajectories via TrajectoryBuilder
2. Submitting trajectories for training via TrainerAdapter
3. Managing multiple epochs of training
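The three steps above can be sketched as a plain async loop. This is an illustration only, not the library's implementation: `StubJobRef`, `StubTrajectoryBuilder`, `StubTrainerAdapter`, and `run_epochs` are hypothetical stand-ins, and the argument lists of the stub `finalize` and `submit` methods are assumptions, since the reference elides those signatures.

```python
import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StubJobRef:
    """Hypothetical stand-in for TrainingJobRef."""
    run_id: str
    epoch: int


class StubTrajectoryBuilder:
    """Hypothetical stand-in for a concrete TrajectoryBuilder."""

    async def start_run(self, run_id: str, meta: Optional[dict] = None) -> None:
        self.active_run = run_id

    async def finalize(self, run_id: str, meta: Optional[dict] = None) -> List[str]:
        # A real builder would return a TrajectoryCollection; strings keep the sketch small.
        return [f"{run_id}-trajectory-{i}" for i in range(3)]


class StubTrainerAdapter:
    """Hypothetical stand-in for a concrete TrainerAdapter."""

    def __init__(self) -> None:
        self.submitted: List[List[str]] = []

    async def submit(self, trajectories: List[str], run_id: str, epoch: int) -> StubJobRef:
        # Assumed argument list; records the batch instead of contacting a backend.
        self.submitted.append(trajectories)
        return StubJobRef(run_id=run_id, epoch=epoch)


async def run_epochs(builder, adapter, run_id: str, num_epochs: int) -> List[StubJobRef]:
    """Generate trajectories, submit them, and repeat for each epoch."""
    refs = []
    for epoch in range(num_epochs):  # epochs are 0-indexed
        meta = {"epoch": epoch}
        await builder.start_run(run_id, meta)
        trajectories = await builder.finalize(run_id, meta)
        refs.append(await adapter.submit(trajectories, run_id, epoch))
    return refs


if __name__ == "__main__":
    job_refs = asyncio.run(
        run_epochs(StubTrajectoryBuilder(), StubTrainerAdapter(), "demo-run", 2)
    )
    print(job_refs)
```

A concrete Trainer would presumably place logic of this shape inside its own `run` and `run_epoch` methods and wait for each job to finish before starting the next epoch.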
Initialize the Trainer.
- Args:
trainer_config: Configuration for the trainer backend
run_config: Configuration for the training run
backend: Backend identifier
curriculum_config: Optional curriculum learning configuration
- trainer_config#
- run_config: nat.data_models.finetuning.FinetuneConfig = None#
- curriculum_config = None#
- trajectory_builder: nat.finetuning.interfaces.trajectory_builder.TrajectoryBuilder = None#
- trainer_adapter: nat.finetuning.interfaces.trainer_adapter.TrainerAdapter = None#
- _curriculum_state = None#
- async bind_components(trajectory_builder: nat.finetuning.interfaces.trajectory_builder.TrajectoryBuilder, trainer_adapter: nat.finetuning.interfaces.trainer_adapter.TrainerAdapter)
Bind the TrajectoryBuilder and TrainerAdapter components.
- Args:
trajectory_builder: Instance of TrajectoryBuilder
trainer_adapter: Instance of TrainerAdapter
- async initialize(run_config: nat.data_models.finetuning.FinetuneConfig)
Initialize the runner and its components.
This should:
- Initialize the TrajectoryBuilder
- Initialize the TrainerAdapter
- Verify connectivity to backend services
- abstractmethod run_epoch( ) nat.data_models.finetuning.TrainingJobRef#
- Async:
Run a single epoch of training.
- Args:
epoch: The current epoch number (0-indexed)
run_id: Unique identifier for this training run
- Returns:
TrainingJobRef: Reference to the submitted training job
- abstractmethod run(num_epochs: int)
- Async:
Run the complete finetuning workflow for the specified number of epochs.
- Args:
num_epochs: Number of epochs to train
- Returns:
list[TrainingJobStatus]: Status of all training jobs
- abstractmethod get_metrics(run_id: str) dict[str, Any]#
- Async:
Get training metrics for a specific run.
- Args:
run_id: The run identifier
- Returns:
dict: Metrics from the training run
- abstractmethod log_progress( ) None#
Log training progress for monitoring.
- Args:
epoch: Current epoch number
metrics: Dictionary of metrics to log
output_dir: Optional output directory override
- async run_validation_evaluation(epoch: int, run_id: str) dict[str, Any]#
Run evaluation on validation dataset to collect rewards.
This method creates a temporary TrainerRunConfig with the validation dataset and runs evaluation to collect rewards without training.
- Args:
epoch: Current epoch number
run_id: Unique identifier for this training run
validation_dataset: Path to the validation dataset
- Returns:
dict: Validation metrics including average reward
- _calculate_validation_metrics(eval_output: nat.eval.config.EvaluationRunOutput)
Calculate validation metrics from evaluation output.
- Args:
eval_output: Output from evaluation run
- Returns:
dict: Calculated metrics
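The reference does not list which metrics are calculated, only that the validation metrics include an average reward. A minimal sketch of such a summary follows, assuming the rewards have already been extracted from the evaluation output as a list of floats. The function name `summarize_rewards` and the metric keys are hypothetical.

```python
from statistics import mean
from typing import Dict, List


def summarize_rewards(rewards: List[float]) -> Dict[str, float]:
    """Summarize per-example rewards; the key names are illustrative only."""
    if not rewards:
        # Avoid statistics.StatisticsError on an empty validation set.
        return {"avg_reward": 0.0, "min_reward": 0.0, "max_reward": 0.0, "num_examples": 0.0}
    return {
        "avg_reward": mean(rewards),
        "min_reward": min(rewards),
        "max_reward": max(rewards),
        "num_examples": float(len(rewards)),
    }
```

Returning zeros for an empty input is a design choice of this sketch; raising an error would be an equally reasonable way to flag a misconfigured validation dataset.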
- abstractmethod apply_curriculum_learning(trajectory_collection: nat.data_models.finetuning.TrajectoryCollection, epoch: int)
Apply curriculum learning to filter trajectory groups based on difficulty.
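The reference does not say how difficulty is measured or how the filter changes across epochs, so the following is only one possible interpretation. It assumes each trajectory group is represented by its list of rewards, treats a higher mean reward as "easier", and admits a growing fraction of groups as `epoch` increases. The function name and the `start_fraction` and `step` parameters are hypothetical.

```python
import math
from statistics import mean
from typing import Dict, List


def filter_groups_by_difficulty(
    groups: Dict[str, List[float]],
    epoch: int,
    start_fraction: float = 0.5,
    step: float = 0.25,
) -> Dict[str, List[float]]:
    """Keep the easiest groups first and admit harder ones in later epochs."""
    # Fraction of groups admitted this epoch, capped at 1.0.
    fraction = min(1.0, start_fraction + step * epoch)
    # Assumption: a higher mean reward means an easier group. Groups must be non-empty.
    ranked = sorted(groups, key=lambda name: mean(groups[name]), reverse=True)
    keep = max(1, math.ceil(fraction * len(ranked)))
    return {name: groups[name] for name in ranked[:keep]}
```

With four groups and the default parameters, this keeps two groups at epoch 0, three at epoch 1, and all four from epoch 2 onward.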
- class TrainerAdapter(adapter_config: nat.data_models.finetuning.TrainerAdapterConfig)
Bases:
abc.ABC
Adapter to send Trajectories to a remote training cluster for weight updates.
- adapter_config#
- run_config: nat.data_models.finetuning.FinetuneConfig = None#
- async initialize(run_config: nat.data_models.finetuning.FinetuneConfig)
Asynchronously initialize any resources needed for the trainer adapter.
- abstractmethod is_healthy() bool#
- Async:
Check the health of the remote training backend.
- Returns:
bool: True if the backend is healthy, False otherwise.
- abstractmethod submit(trajectories: nat.data_models.finetuning.TrajectoryCollection)
- Async:
Submit trajectories to remote training backend.
- Args:
trajectories (TrajectoryCollection): The collection of trajectories to submit.
- Returns:
TrainingJobRef: Reference to the submitted training job.
- abstractmethod status( ) nat.data_models.finetuning.TrainingJobStatus#
- Async:
Get the status of a submitted training job.
- Args:
ref (TrainingJobRef): Reference to the training job.
- Returns:
TrainingJobStatus: The current status of the training job.
- abstractmethod wait_until_complete(ref: nat.data_models.finetuning.TrainingJobRef, poll_interval: float = 10.0)
- Async:
Wait until the training job is complete.
- Args:
ref (TrainingJobRef): Reference to the training job.
poll_interval (float): Time in seconds between status checks.
- Returns:
TrainingJobStatus: The final status of the training job.
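Because `wait_until_complete` is abstract, each adapter supplies its own waiting logic. One common way to build it on top of a status check is a polling loop, sketched below with `asyncio`. The helper `poll_until_complete`, the `get_status` callable, and the state names in `TERMINAL_STATES` are all hypothetical; the reference does not describe the fields of TrainingJobStatus.

```python
import asyncio
from typing import Any, Awaitable, Callable

# Hypothetical state names; a real backend defines its own.
TERMINAL_STATES = {"completed", "failed", "canceled"}


async def poll_until_complete(
    get_status: Callable[[Any], Awaitable[str]],
    ref: Any,
    poll_interval: float = 10.0,
) -> str:
    """Call get_status(ref) repeatedly until it reports a terminal state."""
    while True:
        state = await get_status(ref)
        if state in TERMINAL_STATES:
            return state
        # Yield to the event loop between checks instead of blocking a thread.
        await asyncio.sleep(poll_interval)
```

The loop has no upper bound on waiting time. A caller that needs one could wrap the call in `asyncio.wait_for` with a timeout.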
- class TrajectoryBuilder(trajectory_builder_config: nat.data_models.finetuning.TrajectoryBuilderConfig)
Bases:
abc.ABC
Abstract interface for building trajectories from episode items.
- trajectory_builder_config#
- run_config: nat.data_models.finetuning.FinetuneConfig = None#
- async initialize(run_config: nat.data_models.finetuning.FinetuneConfig)
Asynchronously initialize any resources needed for the trajectory builder.
- async run_eval() nat.eval.config.EvaluationRunOutput#
Run NAT Evaluation to generate episode items for trajectory building.
- Returns:
EvaluationRunOutput: The output of the evaluation run.
- abstractmethod start_run(run_id: str, meta: dict | None = None) None#
- Async:
Initialize any resources needed for the trajectory builder.
- Args:
run_id (str): The unique identifier for the training run.
meta (dict): Metadata associated with the training run.
- abstractmethod finalize( ) nat.data_models.finetuning.TrajectoryCollection#
- Async:
Finalize the trajectory building process and return the constructed trajectories.
- Args:
run_id (str): The unique identifier for the training run.
meta (dict): Metadata associated with the training run.
- Returns:
TrajectoryCollection: The collection of constructed trajectories.
- async compute_reward(output_item: nat.eval.evaluator.evaluator_model.EvalOutputItem, meta: dict | None = None)
Compute reward for a given EvalOutputItem.
- Args:
output_item (EvalOutputItem): The evaluation output item.
meta (dict): Metadata associated with the training run.
- Returns:
float: The computed reward.
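How the reward is derived from the output item is not stated here. As a rough sketch, a subclass might map the evaluator's score to a float and fall back to zero when the score is missing or not numeric. `StubEvalOutputItem` is a hypothetical stand-in, its `score` field is an assumption about what EvalOutputItem carries, and `score_to_reward` is not a library function.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class StubEvalOutputItem:
    """Hypothetical stand-in for EvalOutputItem; the `score` field is assumed."""
    id: str
    score: Any


async def score_to_reward(output_item: StubEvalOutputItem, meta: Optional[dict] = None) -> float:
    """Convert an evaluator score to a float reward, defaulting to 0.0."""
    try:
        return float(output_item.score)
    except (TypeError, ValueError):
        # Missing (None) or non-numeric scores give no reward in this sketch.
        return 0.0


if __name__ == "__main__":
    print(asyncio.run(score_to_reward(StubEvalOutputItem(id="example", score=0.75))))
```

Treating an unusable score as zero reward keeps training running but can hide evaluator failures, so logging such cases would be worth considering in a real override.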