nat.plugins.openpipe.trainer_adapter
Attributes
logger
Classes
ARTTrainerAdapter | Adapter for the ART Trainer backend.
Module Contents
- logger
- class ARTTrainerAdapter(adapter_config: nat.plugins.openpipe.config.ARTTrainerAdapterConfig)
Bases: nat.finetuning.interfaces.trainer_adapter.TrainerAdapter
Adapter for the ART Trainer backend.
- adapter_config: nat.plugins.openpipe.config.ARTTrainerAdapterConfig
- remote_backend: art.Backend
- _model_internal_config: art.dev.InternalModelConfig
- model: art.TrainableModel
- _training_jobs: dict[str, asyncio.Task[None]]
- property training_jobs: dict[str, asyncio.Task[None]]
- async initialize(run_config: nat.data_models.finetuning.FinetuneConfig)
Asynchronously initialize any resources the trainer adapter needs.
- async is_healthy() -> bool
Check the health of the remote training backend.
- Returns:
bool: True if the backend is healthy, False otherwise.
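The reference does not say how the health check is performed. The sketch below shows one common shape for an async health probe, assuming the check wraps some backend call (here a hypothetical `probe` coroutine function) and reports any exception or timeout as unhealthy. The `timeout` parameter is also an assumption, not part of the documented signature.

```python
import asyncio
from typing import Awaitable, Callable


async def is_healthy(
    probe: Callable[[], Awaitable[None]], timeout: float = 5.0
) -> bool:
    """Return True if `probe` completes within `timeout` seconds.

    `probe` is a hypothetical stand-in for whatever call the real adapter
    makes against the remote ART backend.
    """
    try:
        await asyncio.wait_for(probe(), timeout=timeout)
    except Exception:
        # Connection errors and timeouts alike are reported as unhealthy.
        return False
    return True


# Hypothetical probes used only to exercise the sketch.
async def ok_probe() -> None:
    await asyncio.sleep(0)


async def failing_probe() -> None:
    raise ConnectionError("backend unreachable")


async def slow_probe() -> None:
    await asyncio.sleep(1)


if __name__ == "__main__":
    print(asyncio.run(is_healthy(ok_probe)))                  # True
    print(asyncio.run(is_healthy(failing_probe)))             # False
    print(asyncio.run(is_healthy(slow_probe, timeout=0.01)))  # False
```

Returning False instead of raising lets a caller branch on a plain bool before submitting work.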
- async _validate_episode_order(traj: nat.data_models.finetuning.Trajectory)
Check every EpisodeItem in traj.episode and verify that:
  - every EpisodeItem.role is EpisodeItemRole.USER, SYSTEM, or ASSISTANT;
  - the first EpisodeItem.role is SYSTEM or USER;
  - the last EpisodeItem.role is ASSISTANT;
  - no two consecutive EpisodeItem.role values are the same, except that consecutive SYSTEM items are allowed.
- Args:
traj: Trajectory to validate
- Raises:
ValueError: If any of the above conditions are not met.
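A minimal, synchronous sketch of the four ordering rules, assuming roles can be compared as plain lowercase strings. `Item`, `validate_episode_order`, and `is_valid` are hypothetical names, not the module's API, and rejecting an empty episode is an extra assumption that the documented rules do not mention.

```python
from dataclasses import dataclass

ALLOWED_ROLES = {"system", "user", "assistant"}


@dataclass
class Item:
    # Hypothetical stand-in for EpisodeItem; only the role matters here.
    role: str


def validate_episode_order(episode: list[Item]) -> None:
    """Raise ValueError if `episode` breaks any of the ordering rules."""
    if not episode:
        # Assumption: an empty episode is rejected rather than accepted.
        raise ValueError("episode is empty")
    for item in episode:
        if item.role not in ALLOWED_ROLES:
            raise ValueError(f"unsupported role: {item.role!r}")
    if episode[0].role not in ("system", "user"):
        raise ValueError("first item must be SYSTEM or USER")
    if episode[-1].role != "assistant":
        raise ValueError("last item must be ASSISTANT")
    for prev, cur in zip(episode, episode[1:]):
        # Back-to-back SYSTEM items are allowed; any other repeat is not.
        if prev.role == cur.role and cur.role != "system":
            raise ValueError(f"two consecutive {cur.role!r} items")


def is_valid(episode: list[Item]) -> bool:
    # Convenience wrapper that turns the ValueError into a bool.
    try:
        validate_episode_order(episode)
    except ValueError:
        return False
    return True


if __name__ == "__main__":
    roles = ["system", "system", "user", "assistant", "user", "assistant"]
    print(is_valid([Item(r) for r in roles]))                         # True
    print(is_valid([Item("user"), Item("user"), Item("assistant")]))  # False
    print(is_valid([Item("assistant")]))                              # False
```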
- async _construct_trajectory_groups(trajectory_lists: list[list[nat.data_models.finetuning.Trajectory]])
Convert a list of lists of NAT Trajectory objects into a list of ART TrajectoryGroup objects.
- Args:
trajectory_lists: List of lists of NAT Trajectory (each inner list contains the trajectories for one example).
- Returns:
List of ART TrajectoryGroup.
- Raises:
ValueError: If any trajectory is invalid.
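The conversion keeps one group per example: each inner list becomes one group, and every trajectory is validated along the way. The sketch below models that shape with hypothetical `Trajectory` and `TrajectoryGroup` dataclasses and a caller-supplied `validate` function; it does not use the real `art.TrajectoryGroup` constructor. Grouping by example presumably matters because ART's GRPO-style training compares rewards among trajectories in the same group.

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class Trajectory:
    # Hypothetical stand-in for a NAT Trajectory.
    reward: float
    messages: list[str]


@dataclass
class TrajectoryGroup:
    # Hypothetical stand-in for art.TrajectoryGroup.
    trajectories: list[Trajectory] = field(default_factory=list)


def construct_groups(
    trajectory_lists: list[list[Trajectory]],
    validate: Callable[[Trajectory], None],
) -> list[TrajectoryGroup]:
    """Turn one inner list per example into one group per example."""
    groups: list[TrajectoryGroup] = []
    for trajectories in trajectory_lists:
        for traj in trajectories:
            validate(traj)  # expected to raise ValueError if invalid
        groups.append(TrajectoryGroup(trajectories=list(trajectories)))
    return groups


def require_messages(traj: Trajectory) -> None:
    # Hypothetical validator: reject trajectories with no messages.
    if not traj.messages:
        raise ValueError("trajectory has no messages")


if __name__ == "__main__":
    batch = [
        [Trajectory(1.0, ["q1", "a1"]), Trajectory(0.0, ["q1", "a1-alt"])],
        [Trajectory(0.5, ["q2", "a2"])],
    ]
    groups = construct_groups(batch, require_messages)
    print([len(g.trajectories) for g in groups])  # [2, 1]
```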
- async submit(trajectories: nat.data_models.finetuning.TrajectoryCollection)
Submit trajectories to the ART backend for training.
- Args:
trajectories: TrajectoryCollection containing a list of lists of NAT Trajectory.
- Returns:
TrainingJobRef: Reference to the submitted training job.
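The `_training_jobs: dict[str, asyncio.Task[None]]` attribute suggests that submitted jobs run as background asyncio tasks keyed by a job identifier. The sketch below shows that pattern under this assumption; `JobRef`, its `run_id` field, `JobTracker`, and `fake_training` are hypothetical and are not the real `TrainingJobRef` or adapter.

```python
import asyncio
import uuid
from dataclasses import dataclass
from typing import Any, Coroutine


@dataclass(frozen=True)
class JobRef:
    # Hypothetical stand-in for TrainingJobRef; the real fields may differ.
    run_id: str


class JobTracker:
    def __init__(self) -> None:
        # Mirrors the documented _training_jobs attribute.
        self._training_jobs: dict[str, asyncio.Task[None]] = {}

    @property
    def training_jobs(self) -> dict[str, asyncio.Task[None]]:
        return self._training_jobs

    async def submit(self, training: Coroutine[Any, Any, None]) -> JobRef:
        # Schedule the work in the background and return a reference at once;
        # the caller can poll or await the task later.
        run_id = uuid.uuid4().hex
        self._training_jobs[run_id] = asyncio.create_task(training)
        return JobRef(run_id=run_id)


async def fake_training(steps: int) -> None:
    # Hypothetical training loop that just yields control a few times.
    for _ in range(steps):
        await asyncio.sleep(0)


async def demo_submit() -> bool:
    tracker = JobTracker()
    ref = await tracker.submit(fake_training(3))
    await tracker.training_jobs[ref.run_id]  # wait for the background task
    return tracker.training_jobs[ref.run_id].done()


if __name__ == "__main__":
    print(asyncio.run(demo_submit()))  # True
```

Returning a reference immediately, rather than awaiting training inside `submit`, is what makes separate `status` and `wait_until_complete` calls useful.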
- async status(ref: nat.data_models.finetuning.TrainingJobRef) -> nat.data_models.finetuning.TrainingJobStatus
Get the status of a submitted training job.
- Args:
ref (TrainingJobRef): Reference to the training job.
- Returns:
TrainingJobStatus: The current status of the training job.
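If jobs are tracked as asyncio tasks, as the `_training_jobs` attribute suggests, a status can be derived from the task's state. The `Status` values and the `status_of` helper below are hypothetical placeholders, not the real `TrainingJobStatus`. The check order matters: `Task.exception()` raises `CancelledError` on a cancelled task, so cancellation is tested first.

```python
import asyncio
from enum import Enum


class Status(str, Enum):
    # Hypothetical status values; the real TrainingJobStatus may differ.
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELED = "canceled"


def status_of(task: asyncio.Task) -> Status:
    if not task.done():
        return Status.RUNNING
    if task.cancelled():  # must precede task.exception()
        return Status.CANCELED
    if task.exception() is not None:
        return Status.FAILED
    return Status.COMPLETED


async def demo_status() -> list[Status]:
    async def ok() -> None:
        await asyncio.sleep(0)

    async def boom() -> None:
        raise RuntimeError("training crashed")

    done_task = asyncio.create_task(ok())
    failed_task = asyncio.create_task(boom())
    long_task = asyncio.create_task(asyncio.sleep(60))
    await asyncio.wait([done_task, failed_task])
    running = status_of(long_task)  # still sleeping at this point
    long_task.cancel()
    try:
        await long_task
    except asyncio.CancelledError:
        pass
    return [status_of(done_task), status_of(failed_task), running, status_of(long_task)]


if __name__ == "__main__":
    print([s.value for s in asyncio.run(demo_status())])
    # ['completed', 'failed', 'running', 'canceled']
```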
- async wait_until_complete(ref: nat.data_models.finetuning.TrainingJobRef, poll_interval: float = 10.0)
Wait until the training job is complete.
- Args:
ref (TrainingJobRef): Reference to the training job.
poll_interval (float): Time in seconds between status checks.
- Returns:
TrainingJobStatus: The final status of the training job.
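One plausible shape for this method is a loop that checks the status, returns once it is terminal, and otherwise sleeps for `poll_interval` seconds. In the sketch below, `get_status` is a hypothetical zero-argument coroutine function standing in for `status(ref)`, the terminal status names are assumptions, and `fake_status_source` exists only to drive the example.

```python
import asyncio
from typing import Awaitable, Callable

# Hypothetical terminal states; the real TrainingJobStatus may differ.
TERMINAL = {"completed", "failed", "canceled"}


async def wait_until_complete(
    get_status: Callable[[], Awaitable[str]],
    poll_interval: float = 10.0,
) -> str:
    """Poll `get_status` until it reports a terminal state, then return it."""
    while True:
        status = await get_status()
        if status in TERMINAL:
            return status
        # Sleeping between checks keeps the polling load on the backend bounded.
        await asyncio.sleep(poll_interval)


def fake_status_source(sequence: list[str]) -> Callable[[], Awaitable[str]]:
    # Hypothetical helper that replays a fixed sequence of statuses.
    remaining = iter(sequence)

    async def get_status() -> str:
        return next(remaining)

    return get_status


if __name__ == "__main__":
    source = fake_status_source(["running", "running", "completed"])
    print(asyncio.run(wait_until_complete(source, poll_interval=0.01)))  # completed
```

Checking before the first sleep means an already finished job returns immediately; a production version would likely also want an overall deadline so a stuck job cannot block the caller forever.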