nat.plugins.openpipe.trainer_adapter#

Attributes#

logger

Classes#

ARTTrainerAdapter

Adapter for the ART Trainer backend.

Module Contents#

logger#
class ARTTrainerAdapter(
adapter_config: nat.plugins.openpipe.config.ARTTrainerAdapterConfig,
)#

Bases: nat.finetuning.interfaces.trainer_adapter.TrainerAdapter

Adapter for the ART Trainer backend.

adapter_config: nat.plugins.openpipe.config.ARTTrainerAdapterConfig#
remote_backend: art.Backend#
_model_internal_config: art.dev.InternalModelConfig#
model: art.TrainableModel#
_training_jobs: dict[str, asyncio.Task[None]]#
property training_jobs: dict[str, asyncio.Task[None]]#
async initialize(
run_config: nat.data_models.finetuning.FinetuneConfig,
) None#

Asynchronously initialize any resources needed for the trainer adapter.

async is_healthy() bool#

Check the health of the remote training backend.

Returns:

bool: True if the backend is healthy, False otherwise.

async _validate_episode_order(traj: nat.data_models.finetuning.Trajectory)#

Checks every EpisodeItem in traj.episode and validates that:

  • Every EpisodeItem.role is EpisodeItemRole.USER, SYSTEM, or ASSISTANT

  • The first EpisodeItem.role is SYSTEM or USER

  • The last EpisodeItem.role is ASSISTANT

  • No two consecutive EpisodeItem entries have the same role, except for SYSTEM

Args:

traj: Trajectory to validate

Raises:

ValueError: If any of the above conditions are not met.
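The four ordering rules above can be sketched as a standalone check. This is an illustrative sketch only, not the adapter's implementation: `Role` and `Item` are hypothetical stand-ins for `EpisodeItemRole` and `EpisodeItem`, the function is synchronous and takes a plain list rather than a `Trajectory`, and rejecting an empty episode is an assumption the documentation does not state.

```python
from dataclasses import dataclass
from enum import Enum


class Role(Enum):
    # Hypothetical stand-in for EpisodeItemRole. TOOL is included only to
    # show a role that the first rule rejects.
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


@dataclass
class Item:
    # Hypothetical stand-in for EpisodeItem.
    role: Role
    content: str = ""


ALLOWED = {Role.SYSTEM, Role.USER, Role.ASSISTANT}


def validate_episode_order(episode: list[Item]) -> None:
    """Raise ValueError if the episode breaks any of the documented rules."""
    if not episode:
        # Assumption: an empty episode is treated as invalid.
        raise ValueError("episode is empty")

    # Rule 1: every role is USER, SYSTEM, or ASSISTANT.
    for i, item in enumerate(episode):
        if item.role not in ALLOWED:
            raise ValueError(f"item {i}: unsupported role {item.role}")

    # Rule 2: the first role is SYSTEM or USER.
    if episode[0].role not in (Role.SYSTEM, Role.USER):
        raise ValueError("first item must be SYSTEM or USER")

    # Rule 3: the last role is ASSISTANT.
    if episode[-1].role is not Role.ASSISTANT:
        raise ValueError("last item must be ASSISTANT")

    # Rule 4: no two consecutive roles are the same, except SYSTEM.
    for i in range(1, len(episode)):
        prev, cur = episode[i - 1].role, episode[i].role
        if prev is cur and cur is not Role.SYSTEM:
            raise ValueError(f"items {i - 1} and {i}: repeated role {cur}")
```

Under these assumptions, a `SYSTEM, SYSTEM, USER, ASSISTANT` episode passes, while `USER, USER, ASSISTANT` fails the fourth rule.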

async _construct_trajectory_groups(
trajectory_lists: list[list[nat.data_models.finetuning.Trajectory]],
) list[art.TrajectoryGroup]#

Convert list of lists of NAT Trajectory to list of ART TrajectoryGroup.

Args:

trajectory_lists: List of lists of NAT Trajectory (each inner list contains trajectories for one example).

Returns:

List of ART TrajectoryGroup.

Raises:

ValueError: If any trajectory is invalid.

async submit(
trajectories: nat.data_models.finetuning.TrajectoryCollection,
) nat.data_models.finetuning.TrainingJobRef#

Submit trajectories to ART backend for training.

Args:

trajectories: TrajectoryCollection with list of lists of NAT Trajectory.

Returns:

TrainingJobRef: Reference to the submitted training job.

async status(
ref: nat.data_models.finetuning.TrainingJobRef,
) nat.data_models.finetuning.TrainingJobStatus#

Get the status of a submitted training job.

Args:

ref (TrainingJobRef): Reference to the training job.

Returns:

TrainingJobStatus: The current status of the training job.
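Because the adapter tracks jobs as `dict[str, asyncio.Task[None]]`, a status lookup can in principle be derived from the state of the task. The sketch below shows that idea with standard `asyncio.Task` methods only. The helper name `describe_task` and the string labels are hypothetical; the actual fields and values of `TrainingJobStatus` are not shown in this page, and the adapter may compute status differently.

```python
import asyncio


def describe_task(task: asyncio.Task) -> str:
    """Map an asyncio.Task to a coarse, hypothetical status label."""
    if not task.done():
        return "running"
    if task.cancelled():
        return "cancelled"
    # exception() is safe to call here: the task is done and not cancelled.
    if task.exception() is not None:
        return "failed"
    return "completed"


async def demo() -> list[str]:
    async def ok() -> None:
        await asyncio.sleep(0)

    async def boom() -> None:
        raise RuntimeError("training failed")

    good = asyncio.create_task(ok())
    bad = asyncio.create_task(boom())
    before = describe_task(good)  # task has not run yet
    await asyncio.gather(good, bad, return_exceptions=True)
    return [before, describe_task(good), describe_task(bad)]
```

Running `asyncio.run(demo())` exercises the running, completed, and failed branches in turn.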

async wait_until_complete(
ref: nat.data_models.finetuning.TrainingJobRef,
poll_interval: float = 10.0,
) nat.data_models.finetuning.TrainingJobStatus#

Wait until the training job is complete.

Args:

ref (TrainingJobRef): Reference to the training job.

poll_interval (float): Time in seconds between status checks.

Returns:

TrainingJobStatus: The final status of the training job.
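The `poll_interval` parameter suggests a polling loop around the status check. A minimal sketch of that pattern follows, assuming a status coroutine that returns a string label. The names `wait_until_terminal`, `get_status`, and the `TERMINAL` labels are hypothetical and are not part of the adapter's API.

```python
import asyncio
from typing import Awaitable, Callable

# Hypothetical terminal labels; the real TrainingJobStatus values may differ.
TERMINAL = {"completed", "failed", "cancelled"}


async def wait_until_terminal(
    get_status: Callable[[], Awaitable[str]],
    poll_interval: float = 10.0,
) -> str:
    """Poll get_status until it reports a terminal state, then return it."""
    while True:
        state = await get_status()
        if state in TERMINAL:
            return state
        # Yield to the event loop between checks instead of busy-waiting.
        await asyncio.sleep(poll_interval)


async def demo() -> str:
    # Fake status source: two "running" reports, then "completed".
    states = iter(["running", "running", "completed"])

    async def fake_status() -> str:
        return next(states)

    return await wait_until_terminal(fake_status, poll_interval=0.0)
```

A shorter `poll_interval` returns sooner after the job finishes at the cost of more status requests to the backend.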

log_progress(
ref: nat.data_models.finetuning.TrainingJobRef,
metrics: dict[str, Any],
output_dir: str | None = None,
) None#

Log training adapter progress.

Args:

ref: Training job reference.

metrics: Dictionary of metrics to log.

output_dir: Optional output directory override.