nat.plugins.openpipe.remote_backend

Attributes

Classes

RemoteBackend

HTTP client that satisfies the `art.Backend` protocol by talking to an `art run` server.

Functions

_dump_model(→ dict)

Module Contents

logger
_DEFAULT_TIMEOUT
class RemoteBackend(base_url: str)

HTTP client that satisfies the `art.Backend` protocol by talking to an `art run` server.

base_url
_client
_latest_step: dict[str, int]
_model_inference_name(
model: art.backend.AnyModel,
step: int | None = None,
) → str
async close() → None
async register(model: art.backend.AnyModel) → None
async _get_step(model: art.backend.AnyTrainableModel) → int
async _delete_checkpoint_files(
model: art.backend.AnyTrainableModel,
steps_to_keep: list[int],
) → None
async _prepare_backend_for_training(
model: art.backend.AnyTrainableModel,
config: art.dev.OpenAIServerConfig | None,
) → tuple[str, str]
async _train_model(
model: art.backend.AnyTrainableModel,
trajectory_groups: list[art.trajectories.TrajectoryGroup],
config: art.types.TrainConfig,
dev_config: art.dev.TrainConfig,
verbose: bool = False,
) → collections.abc.AsyncIterator[dict[str, float]]
async train(model: art.backend.AnyTrainableModel, trajectory_groups: collections.abc.Iterable[art.trajectories.TrajectoryGroup], **kwargs: Any) → art.types.TrainResult
abstractmethod _train_sft(
model: art.backend.AnyTrainableModel,
trajectories: collections.abc.Iterable[art.trajectories.Trajectory],
config: art.types.TrainSFTConfig,
dev_config: art.dev.TrainSFTConfig,
verbose: bool = False,
) → collections.abc.AsyncIterator[dict[str, float]]
_dump_model(model: art.backend.AnyModel) → dict