Source code for air_sdk.endpoints.checkpoints

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: MIT
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime
from typing import Any

from air_sdk.air_model import (
    AirModel,
    BaseEndpointAPI,
    PrimaryKey,
)
from air_sdk.endpoints import mixins
from air_sdk.utils import validate_payload_types


[docs] @dataclass(eq=False) class Checkpoint(AirModel): id: str = field(repr=False) name: str favorite: bool created: datetime = field(repr=False) modified: datetime = field(repr=False) run: str = field(repr=False) state: str = field(repr=False)
[docs] @classmethod def get_model_api(cls) -> type[CheckpointEndpointAPI]: return CheckpointEndpointAPI
@property def model_api(self) -> CheckpointEndpointAPI: return self.get_model_api()(self.__api__)
[docs] def update(self, **kwargs: Any) -> None: self.model_api.update(checkpoint=self, **kwargs)
[docs] class CheckpointEndpointAPI( mixins.ListApiMixin[Checkpoint], mixins.GetApiMixin[Checkpoint], mixins.PatchApiMixin[Checkpoint], mixins.DeleteApiMixin, BaseEndpointAPI[Checkpoint], ): API_PATH = 'simulations/runs/checkpoints' model = Checkpoint
[docs] @validate_payload_types def update(self, *, checkpoint: Checkpoint | PrimaryKey, **kwargs: Any) -> Checkpoint: checkpoint_id = ( checkpoint.id if isinstance(checkpoint, Checkpoint) else checkpoint ) result = self.patch(checkpoint_id, **kwargs) if isinstance(checkpoint, Checkpoint): checkpoint.__refresh__(refreshed_obj=result) return result