Source code for nemo_retriever.utils.hf_cache

from __future__ import annotations

import os
from pathlib import Path
from typing import Iterable, Optional

ENV_HF_CACHE_BASE_DIR = "NEMO_RETRIEVER_HF_CACHE_DIR"

HF_RUNTIME_ENV_KEYS: tuple[str, ...] = (
    "HF_TOKEN",
    "HUGGING_FACE_HUB_TOKEN",
    "HF_HOME",
    "HF_HUB_CACHE",
    "TRANSFORMERS_CACHE",
    ENV_HF_CACHE_BASE_DIR,
    "HF_ENDPOINT",
    "HF_HUB_DISABLE_IMPLICIT_TOKEN",
    "HF_HUB_ENABLE_HF_TRANSFER",
    "HF_HUB_ETAG_TIMEOUT",
    "HF_HUB_DOWNLOAD_TIMEOUT",
    "HF_HUB_DISABLE_TELEMETRY",
    "HTTP_PROXY",
    "HTTPS_PROXY",
    "NO_PROXY",
    "http_proxy",
    "https_proxy",
    "no_proxy",
    "REQUESTS_CA_BUNDLE",
    "CURL_CA_BUNDLE",
    "SSL_CERT_FILE",
)


[docs] def resolve_hf_cache_dir(explicit_hf_cache_dir: Optional[str] = None) -> str: """Resolve Hugging Face cache dir from explicit arg, env, then default.""" candidate = explicit_hf_cache_dir or os.getenv(ENV_HF_CACHE_BASE_DIR) if candidate: return str(Path(candidate).expanduser()) return str(Path.home() / ".cache" / "huggingface")
[docs] def configure_global_hf_cache_base(explicit_hf_cache_dir: Optional[str] = None) -> str: """Apply resolved HF cache base to standard Hugging Face env vars.""" cache_base = resolve_hf_cache_dir(explicit_hf_cache_dir) os.environ.setdefault("HF_HOME", cache_base) os.environ.setdefault("HF_HUB_CACHE", str(Path(cache_base) / "hub")) os.environ.setdefault("TRANSFORMERS_CACHE", str(Path(cache_base) / "transformers")) return cache_base
[docs] def collect_hf_runtime_env( *, default_hf_hub_offline: str = "0", extra_keys: Iterable[str] = (), ) -> dict[str, str]: """Collect HF-related environment variables to forward to Ray workers. Parameters ---------- default_hf_hub_offline: Value to emit for ``HF_HUB_OFFLINE`` when it is not set in the parent process environment. The default keeps online Hub checks enabled. extra_keys: Additional environment variable names to forward if they are set. Duplicates of built-in keys are ignored after their first occurrence. Returns ------- dict[str, str] Environment variables for Ray ``runtime_env["env_vars"]``. Explicitly blank environment values are preserved. """ env_vars: dict[str, str] = {} for key in (*HF_RUNTIME_ENV_KEYS, *tuple(extra_keys)): if key in env_vars: continue value = os.environ.get(key) if value is not None: env_vars[key] = value # HF_HUB_OFFLINE is emitted explicitly so every Ray worker gets a default; # passing it through extra_keys is intentionally overridden here. env_vars["HF_HUB_OFFLINE"] = os.environ.get("HF_HUB_OFFLINE", default_hf_hub_offline) return env_vars