# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, Type, Union
from invoke.context import Context
from nemo_run.config import RUNDIR_NAME
from nemo_run.core.execution.base import (
    Executor,
    ExecutorMacros,
)
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager
_SKYPILOT_AVAILABLE: bool = False
try:
    import sky
    import sky.task as skyt
    from sky import backends
    from sky.utils import status_lib
    _SKYPILOT_AVAILABLE = True
except ImportError:
    ...
logger = logging.getLogger(__name__)
[docs]
@dataclass(kw_only=True)
class SkypilotExecutor(Executor):
    """
    Dataclass to configure a Skypilot Executor.
    Some familiarity with `Skypilot <https://skypilot.readthedocs.io/en/latest/docs/index.html>`_ is necessary.
    In order to use this executor, you need to install NeMo Run
    with either `skypilot` (for only Kubernetes) or `skypilot-all` (for all clouds) optional features.
    Example:
    .. code-block:: python
        skypilot = SkypilotExecutor(
            gpus="A10G",
            gpus_per_node=devices,
            container_image="nvcr.io/nvidia/nemo:dev",
            infra="k8s/my-context",
            network_tier="best",
            cluster_name="nemo_tester",
            file_mounts={
                "nemo_run.whl": "nemo_run.whl",
                "/workspace/code": "/local/path/to/code",
            },
            storage_mounts={
                "/workspace/outputs": {
                    "name": "my-training-outputs",
                    "store": "gcs",  # or "s3", "azure", etc.
                    "mode": "MOUNT",
                    "persistent": True,
                },
                "/workspace/checkpoints": {
                    "name": "model-checkpoints",
                    "store": "s3",
                    "mode": "MOUNT",
                    "persistent": True,
                }
            },
            setup=\"\"\"
        conda deactivate
        nvidia-smi
        ls -al ./
        pip install nemo_run.whl
        cd /opt/NeMo && git pull origin main && pip install .
            \"\"\",
        )
    """
    HEAD_NODE_IP_VAR = "head_node_ip"
    NPROC_PER_NODE_VAR = "SKYPILOT_NUM_GPUS_PER_NODE"
    NUM_NODES_VAR = "num_nodes"
    NODE_RANK_VAR = "SKYPILOT_NODE_RANK"
    HET_GROUP_HOST_VAR = "het_group_host"
    container_image: Optional[str] = None
    cloud: Optional[Union[str, list[str]]] = None
    region: Optional[Union[str, list[str]]] = None
    zone: Optional[Union[str, list[str]]] = None
    gpus: Optional[Union[str, list[str]]] = None
    gpus_per_node: Optional[int] = None
    cpus: Optional[Union[int | float, list[int | float]]] = None
    memory: Optional[Union[int | float, list[int | float]]] = None
    instance_type: Optional[Union[str, list[str]]] = None
    num_nodes: int = 1
    use_spot: Optional[Union[bool, list[bool]]] = None
    disk_size: Optional[Union[int, list[int]]] = None
    disk_tier: Optional[Union[str, list[str]]] = None
    ports: Optional[tuple[str]] = None
    file_mounts: Optional[dict[str, str]] = None
    storage_mounts: Optional[dict[str, dict[str, Any]]] = None  # Can be str or dict configs
    cluster_name: Optional[str] = None
    setup: Optional[str] = None
    autodown: bool = False
    idle_minutes_to_autostop: Optional[int] = None
    torchrun_nproc_per_node: Optional[int] = None
    cluster_config_overrides: Optional[dict[str, Any]] = None
    infra: Optional[str] = None
    network_tier: Optional[str] = None
    retry_until_up: bool = False
    packager: Packager = field(default_factory=lambda: GitArchivePackager())  # type: ignore  # noqa: F821
    def __post_init__(self):
        assert _SKYPILOT_AVAILABLE, (
            'Skypilot is not installed. Please install it using `pip install "nemo_run[skypilot]"`.'
        )
        assert isinstance(self.packager, GitArchivePackager), (
            "Only GitArchivePackager is currently supported for SkypilotExecutor."
        )
        if self.infra is not None:
            assert self.cloud is None, "Cannot specify both `infra` and `cloud` parameters."
            assert self.region is None, "Cannot specify both `infra` and `region` parameters."
            assert self.zone is None, "Cannot specify both `infra` and `zone` parameters."
            logger.info(
                "`cloud` is deprecated and will be removed in a future version. Use `infra` instead."
            )
    @classmethod
    def parse_app(cls: Type["SkypilotExecutor"], app_id: str) -> tuple[str, str, int]:
        app = app_id.split("___")
        _, cluster, task, job_id = app[0], app[1], app[2], app[3]
        assert cluster and task and job_id, f"Invalid app id for Skypilot: {app_id}"
        return cluster, task, int(job_id)
    def to_resources(self) -> Union[set["sky.Resources"], set["sky.Resources"]]:
        from sky.resources import Resources
        resources_cfg = {}
        accelerators = None
        if self.gpus:
            if not self.gpus_per_node:
                self.gpus_per_node = 1
            else:
                assert isinstance(self.gpus_per_node, int)
            gpus = [self.gpus] if isinstance(self.gpus, str) else self.gpus
            accelerators = {}
            for gpu in gpus:
                accelerators[gpu] = self.gpus_per_node
            resources_cfg["accelerators"] = accelerators
        if self.container_image:
            resources_cfg["image_id"] = self.container_image
        any_of = []
        def parse_attr(attr: str):
            if getattr(self, attr, None) is not None:
                value = getattr(self, attr)
                if isinstance(value, list):
                    for i, val in enumerate(value):
                        if len(any_of) < i + 1:
                            any_of.append({})
                        if isinstance(val, str) and val.lower() == "none":
                            any_of[i][attr] = None
                        else:
                            any_of[i][attr] = val
                else:
                    if isinstance(value, str) and value.lower() == "none":
                        resources_cfg[attr] = None
                    else:
                        resources_cfg[attr] = value
        # any_of = False
        attrs = [
            "cloud",
            "region",
            "zone",
            "cpus",
            "memory",
            "instance_type",
            "use_spot",
            "infra",
            "network_tier",
            "image_id",
            "disk_size",
            "disk_tier",
            "ports",
        ]
        for attr in attrs:
            parse_attr(attr)
        resources_cfg["any_of"] = any_of
        if self.cluster_config_overrides:
            resources_cfg["_cluster_config_overrides"] = self.cluster_config_overrides
        resources = Resources.from_yaml_config(resources_cfg)
        return resources  # type: ignore
    @classmethod
    def status(
        cls: Type["SkypilotExecutor"], app_id: str
    ) -> tuple[Optional["status_lib.ClusterStatus"], Optional[dict]]:
        import sky.core as sky_core
        import sky.exceptions as sky_exceptions
        from sky.utils import status_lib
        cluster, _, job_id = cls.parse_app(app_id)
        try:
            cluster_details = sky_core.status(cluster)[0]
            cluster_status: status_lib.ClusterStatus = cluster_details["status"]
        except Exception:
            return None, None
        try:
            job_queue = sky_core.queue(cluster, all_users=True)
            job_details = next(filter(lambda job: job["job_id"] == job_id, job_queue))
        except sky_exceptions.ClusterNotUpError:
            return cluster_status, None
        return cluster_status, job_details
    @classmethod
    def cancel(cls: Type["SkypilotExecutor"], app_id: str):
        from sky.core import cancel
        cluster_name, _, job_id = cls.parse_app(app_id=app_id)
        _, job_details = cls.status(app_id=app_id)
        if not job_details:
            return
        cancel(cluster_name=cluster_name, job_ids=[job_id])
    @classmethod
    def logs(cls: Type["SkypilotExecutor"], app_id: str, fallback_path: Optional[str]):
        import sky.core as sky_core
        from sky.skylet import job_lib
        cluster, _, job_id = cls.parse_app(app_id)
        _, job_details = cls.status(app_id)
        is_terminal = False
        if job_details and job_lib.JobStatus.is_terminal(job_details["status"]):
            is_terminal = True
        elif not job_details:
            is_terminal = True
        if fallback_path and is_terminal:
            log_path = os.path.expanduser(os.path.join(fallback_path, "run.log"))
            if os.path.isfile(log_path):
                with open(os.path.expanduser(os.path.join(fallback_path, "run.log"))) as f:
                    for line in f:
                        print(line, end="", flush=True)
                return
        sky_core.tail_logs(cluster, job_id)
    @property
    def workdir(self) -> str:
        return os.path.join(f"{self.job_dir}", "workdir")
    def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
        filenames = []
        basepath = os.path.join(self.job_dir, "configs")
        for name, cfg in cfgs:
            filename = os.path.join(basepath, name)
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "w") as f:
                f.write(cfg)
            filenames.append(
                os.path.join(
                    "/",
                    RUNDIR_NAME,
                    "configs",
                    name,
                )
            )
        return filenames
[docs]
    def assign(
        self,
        exp_id: str,
        exp_dir: str,
        task_id: str,
        task_dir: str,
    ):
        self.job_name = task_id
        self.experiment_dir = exp_dir
        self.job_dir = os.path.join(exp_dir, task_dir)
        self.experiment_id = exp_id 
    def package(self, packager: Packager, job_name: str):
        assert self.experiment_id, "Executor not assigned to an experiment."
        if isinstance(packager, GitArchivePackager):
            output = subprocess.run(
                ["git", "rev-parse", "--show-toplevel"],
                check=True,
                stdout=subprocess.PIPE,
            )
            path = output.stdout.splitlines()[0].decode()
            base_path = Path(path).absolute()
        else:
            base_path = Path(os.getcwd()).absolute()
        local_pkg = packager.package(base_path, self.job_dir, job_name)
        local_code_extraction_path = os.path.join(self.job_dir, "code")
        ctx = Context()
        ctx.run(f"mkdir -p {local_code_extraction_path}")
        if self.get_launcher().nsys_profile:
            remote_nsys_extraction_path = os.path.join(
                self.job_dir, self.get_launcher().nsys_folder
            )
            ctx.run(f"mkdir -p {remote_nsys_extraction_path}")
        if local_pkg:
            ctx.run(
                f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True
            )
[docs]
    def nnodes(self) -> int:
        return self.num_nodes 
[docs]
    def nproc_per_node(self) -> int:
        if self.torchrun_nproc_per_node:
            return self.torchrun_nproc_per_node
        return self.gpus_per_node or 1 
[docs]
    def macro_values(self) -> Optional[ExecutorMacros]:
        return ExecutorMacros(
            head_node_ip_var=self.HEAD_NODE_IP_VAR,
            nproc_per_node_var=self.NPROC_PER_NODE_VAR,
            num_nodes_var=self.NUM_NODES_VAR,
            node_rank_var=self.NODE_RANK_VAR,
            het_group_host_var=self.HET_GROUP_HOST_VAR,
        ) 
    def to_task(
        self,
        name: str,
        cmd: Optional[list[str]] = None,
        env_vars: Optional[dict[str, str]] = None,
    ) -> "skyt.Task":
        from sky.task import Task
        run_cmd = None
        if cmd:
            run_cmd = f"""
conda deactivate
num_nodes=`echo "$SKYPILOT_NODE_IPS" | wc -l`
echo "num_nodes=$num_nodes"
head_node_ip=`echo "$SKYPILOT_NODE_IPS" | head -n1`
echo "head_node_ip=$head_node_ip"
cd /nemo_run/code
{" ".join(cmd)}
"""
        task = Task(
            name=name,
            setup=self.setup if self.setup else "",
            run=run_cmd,
            envs=self.env_vars,
            num_nodes=self.num_nodes,
        )
        # Handle regular file mounts
        file_mounts = self.file_mounts or {}
        file_mounts["/nemo_run"] = self.job_dir
        task.set_file_mounts(file_mounts)
        # Handle storage mounts separately
        if self.storage_mounts:
            from sky.data import Storage
            storage_objects = {}
            for mount_path, config in self.storage_mounts.items():
                # Create Storage object from config dict
                storage_obj = Storage.from_yaml_config(config)
                storage_objects[mount_path] = storage_obj
            task.set_storage_mounts(storage_objects)
        task.set_resources(self.to_resources())
        if env_vars:
            task.update_envs(env_vars)
        return task
    def launch(
        self,
        task: "skyt.Task",
        cluster_name: Optional[str] = None,
        num_nodes: Optional[int] = None,
        dryrun: bool = False,
    ) -> tuple[Optional[int], Optional["backends.ResourceHandle"]]:
        from sky import backends, launch, stream_and_get
        # Backward compatibility for SkyPilot 0.10.3+
        # dump_yaml_str moved from sky.utils.common_utils to yaml_utils
        try:
            from sky.utils import yaml_utils
        except ImportError:
            from sky.utils import common_utils as yaml_utils
        task_yml = os.path.join(self.job_dir, "skypilot_task.yml")
        with open(task_yml, "w+") as f:
            f.write(yaml_utils.dump_yaml_str(task.to_yaml_config()))
        backend = backends.CloudVmRayBackend()
        if num_nodes:
            task.num_nodes = num_nodes
        cluster_name = cluster_name or self.cluster_name or self.experiment_id
        job_id, handle = stream_and_get(
            launch(
                task,
                dryrun=dryrun,
                cluster_name=cluster_name,
                backend=backend,
                idle_minutes_to_autostop=self.idle_minutes_to_autostop,
                down=self.autodown,
                fast=True,
                retry_until_up=self.retry_until_up,
                # clone_disk_from=clone_disk_from,
            )
        )
        return job_id, handle
    def cleanup(self, handle: str):
        import sky.core as sky_core
        _, _, path_str = handle.partition("://")
        path = path_str.split("/")
        app_id = path[1]
        cluster, _, job_id = self.parse_app(app_id)
        sky_core.download_logs(cluster, job_ids=[job_id])