# Source code for nemo_automodel.components.launcher.skypilot.utils
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import os
from nemo_automodel.components.launcher.skypilot.config import SkyPilotConfig
# Fixed remote path where the job config YAML is uploaded.
# Referenced by submit_skypilot_job() as the file-mount destination.
REMOTE_CONFIG_PATH: str = "/tmp/automodel_job_config.yaml"
# Default setup command: install the package from the synced workdir.
# ~/sky_workdir is the directory SkyPilot syncs the task workdir into.
_DEFAULT_SETUP: str = "cd ~/sky_workdir && pip install -e . --quiet"
_CLOUD_CLASSES = {
"aws": "AWS",
"gcp": "GCP",
"azure": "Azure",
"lambda": "Lambda",
"kubernetes": "Kubernetes",
}
def _get_cloud(cloud_name: str):
"""Return a sky cloud object for the given cloud name string."""
import sky
cls_name = _CLOUD_CLASSES[cloud_name.lower()]
return getattr(sky, cls_name)()
def submit_skypilot_job(config: SkyPilotConfig, job_dir: str) -> int:
    """
    Launch a training job on a cloud VM via SkyPilot.

    The local job config written to *job_dir*/job_config.yaml is uploaded to
    REMOTE_CONFIG_PATH on the remote VM. The code in the current working
    directory is synced to ~/sky_workdir via SkyPilot's workdir mechanism.

    Args:
        config: Populated SkyPilotConfig (including the training command).
        job_dir: Local directory holding the job artifacts.

    Returns:
        0 on successful submission.

    Raises:
        ImportError: If SkyPilot is not installed.
        FileNotFoundError: If *job_dir*/job_config.yaml does not exist.
    """
    try:
        import sky
    except ImportError as exc:
        raise ImportError(
            "SkyPilot is not installed. "
            "Install it with: pip install skypilot[<cloud>] "
            "(e.g. skypilot[gcp], skypilot[aws])"
        ) from exc

    local_config_path = os.path.join(job_dir, "job_config.yaml")
    # Fail fast before provisioning any cloud resources: without this check a
    # missing config would only surface later, inside sky's file-mount step.
    if not os.path.isfile(local_config_path):
        raise FileNotFoundError(
            f"Job config not found at {local_config_path}; it must be written "
            "to job_dir before submission."
        )

    # Build the environment variable dict for the remote task. Copy so we
    # never mutate the dict stored on the config object.
    envs: dict[str, str] = dict(config.env_vars)
    if config.hf_token:
        envs["HF_TOKEN"] = config.hf_token
    if config.wandb_key:
        envs["WANDB_API_KEY"] = config.wandb_key
    # NOTE(review): assumes config.hf_home is a valid string default — a None
    # value would be forwarded as-is to the remote task; confirm upstream.
    envs.setdefault("HF_HOME", config.hf_home)

    setup_cmd = config.setup if config.setup else _DEFAULT_SETUP

    task = sky.Task(
        name=config.job_name or "automodel_job",
        setup=setup_cmd,
        run=config.command,
        envs=envs,
        num_nodes=config.num_nodes,
    )
    # Sync the current working directory to ~/sky_workdir on the remote VM.
    task.workdir = "."
    task.set_file_mounts({REMOTE_CONFIG_PATH: local_config_path})
    task.set_resources(
        sky.Resources(
            cloud=_get_cloud(config.cloud),
            region=config.region,
            zone=config.zone,
            accelerators=config.accelerators,
            use_spot=config.use_spot,
            disk_size=config.disk_size,
            instance_type=config.instance_type,
        )
    )

    cluster_name = config.job_name or "automodel-cluster"
    # Lazy %-style args: formatting is skipped entirely when INFO is disabled.
    logging.info(
        "Submitting SkyPilot job '%s' on %s (%s, spot=%s, nodes=%s)",
        cluster_name,
        config.cloud,
        config.accelerators,
        config.use_spot,
        config.num_nodes,
    )
    # Detach after submission; callers stream logs separately if desired.
    sky.launch(
        task,
        cluster_name=cluster_name,
        detach_run=True,
        stream_logs=False,
    )
    return 0