Source code for nemo_gym.config_types

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple, Union

import rich
from omegaconf import DictConfig, OmegaConf
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    TypeAdapter,
    ValidationError,
    model_validator,
)
from pydantic_core import PydanticUndefined
from rich.markdown import Markdown
from rich.text import Text


########################################
# Base CLI configs
########################################


class BaseNeMoGymCLIConfig(BaseModel):
    # Base model for CLI command configs. When the incoming data carries a truthy `h` or `help` key, a help
    # screen built from the class docstring and the field metadata is printed and the process exits.
    #
    # NOTE: this class is documented with `#` comments on purpose: `pre_process` prints `cls.__doc__` as the
    # command description, so a docstring placed here would be displayed as help text.

    @model_validator(mode="before")
    @classmethod
    def pre_process(cls, data):
        """Print the help screen and exit when ``h`` or ``help`` is set in ``data``.

        Args:
            data: Raw, not-yet-validated input for the model.

        Returns:
            ``data`` unchanged when no help flag is present.

        Raises:
            SystemExit: After the help text has been printed.
        """
        # Before-validators can receive arbitrary input; only mapping-like payloads can carry a help flag.
        if not hasattr(data, "get") or not (data.get("h") or data.get("help")):
            return data

        rich.print(f"Displaying help for [bold]{cls.__name__}[/bold]\n")

        # We use __doc__ directly here since inspect.getdoc will inherit the doc from parent classes.
        class_doc = cls.__doc__
        if class_doc:
            rich.print("[bold]Description[/bold]\n-----------")
            # Render the docstring exactly once, as Markdown. Printing the raw text through rich.print as well
            # would duplicate it and would interpret any `[...]` inside the docstring as console markup.
            rich.print(Markdown(class_doc.strip()))
            rich.print()

        fields = cls.model_fields.items()
        if fields:
            rich.print("[bold]Parameters[/bold]\n----------")

            prefixes: List[Text] = []
            suffixes: List[str] = []
            for field_name, field in fields:
                description_str = field.description if field.description else ""

                # Not sure if there is a better way to get this annotation_str, e.g. using typing.get_args or
                # typing.get_origin
                annotation_str = (
                    field.annotation.__name__ if isinstance(field.annotation, type) else str(field.annotation)
                )
                annotation_str = annotation_str.replace("typing.", "")

                # Add default value information if available
                if field.default is not PydanticUndefined and field.default is not None:
                    default_str = f" [default: {field.default}]"
                elif field.default_factory is not None:
                    default_str = " [default: <factory>]"
                elif field.default is PydanticUndefined and field.is_required():
                    default_str = " [required]"
                else:
                    default_str = ""
                if default_str:
                    description_str = description_str + default_str if description_str else default_str.strip()

                # Build the styled prefix from (text, style) parts instead of parsing a markup string:
                # annotations such as `Optional[str]` contain `[str]`, which markup parsing treats as a style
                # tag and silently drops from the output.
                prefixes.append(
                    Text.assemble("- ", (field_name, "blue"), " ", (f"({annotation_str})", "yellow"))
                )
                suffixes.append(description_str)

            # Pad every prefix to a common width so the descriptions line up in one column.
            max_prefix_length = max(map(len, prefixes))
            ljust_length = max_prefix_length + 3
            for prefix, suffix in zip(prefixes, suffixes):
                prefix.align("left", ljust_length)
                rich.print(prefix + suffix)
        else:
            print("There are no arguments to this CLI command!")

        # Exit after help is printed. Raise SystemExit directly: the `exit()` builtin is injected by the
        # `site` module for interactive use and is not guaranteed to exist in every runtime.
        raise SystemExit
########################################
# Server references
#
# We enable servers to reference other servers. The way they do so is through these schemas below.
########################################
class ModelServerRef(BaseModel):
    # Reference to a server of type "responses_api_models".

    # Fixed tag that tells this reference apart from the other server reference schemas.
    type: Literal["responses_api_models"]
    # Name of the referenced server.
    name: str
class ResourcesServerRef(BaseModel):
    # Reference to a server of type "resources_servers".

    # Fixed tag that tells this reference apart from the other server reference schemas.
    type: Literal["resources_servers"]
    # Name of the referenced server.
    name: str
class AgentServerRef(BaseModel):
    # Reference to a server of type "responses_api_agents".

    # Fixed tag that tells this reference apart from the other server reference schemas.
    type: Literal["responses_api_agents"]
    # Name of the referenced server.
    name: str
# A reference to any one of the three server kinds (model, resources, agent).
ServerRef = Union[ModelServerRef, ResourcesServerRef, AgentServerRef]
# Shared validator for ServerRef payloads; built once at import time and used by `is_server_ref`.
ServerRefTypeAdapter = TypeAdapter(ServerRef)
def is_server_ref(config_dict: DictConfig) -> Optional[ServerRef]:
    """Parse ``config_dict`` as a server reference.

    Args:
        config_dict: Candidate payload to validate against the ``ServerRef`` schemas.

    Returns:
        The validated server reference, or ``None`` when validation fails.
    """
    try:
        server_ref = ServerRefTypeAdapter.validate_python(config_dict)
    except ValidationError:
        # Not a server reference; callers treat None as "no match".
        server_ref = None
    return server_ref
########################################
# Dataset configs for handling and upload/download
########################################
class UploadJsonlDatasetGitlabConfig(BaseNeMoGymCLIConfig):
    r"""
    Upload a local jsonl dataset artifact to Gitlab.

    Examples:

    ```bash
    ng_upload_dataset_to_gitlab \
        +dataset_name=example_multi_step \
        +version=0.0.1 \
        +input_jsonl_fpath=data/train.jsonl
    ```
    """

    # The docstring above is a raw string on purpose: it is shown as the command's help text, and in a
    # regular string each trailing backslash + newline is a line continuation that would be dropped,
    # collapsing the bash example onto a single line.

    dataset_name: str = Field(description="The dataset name.")
    version: str = Field(description="The version of this dataset. Must be in the format `x.x.x`.")
    input_jsonl_fpath: str = Field(description="Path to the jsonl file to upload.")
class JsonlDatasetGitlabIdentifer(BaseModel):
    # Identifies a jsonl dataset artifact on GitLab by dataset name, version and artifact file path.

    dataset_name: str
    version: str
    artifact_fpath: str
class DownloadJsonlDatasetGitlabConfig(JsonlDatasetGitlabIdentifer, BaseNeMoGymCLIConfig):
    r"""
    Download a JSONL dataset from GitLab Model Registry.

    Examples:

    ```bash
    ng_download_dataset_from_gitlab \
        +dataset_name=example_multi_step \
        +version=0.0.1 \
        +artifact_fpath=train.jsonl \
        +output_fpath=data/train.jsonl
    ```
    """

    # The docstring above is a raw string on purpose: it is shown as the command's help text, and in a
    # regular string each trailing backslash + newline is a line continuation that would be dropped,
    # collapsing the bash example onto a single line.

    # The identifier fields are re-declared here so that they carry descriptions in the help output.
    dataset_name: str = Field(description="The dataset name.")
    version: str = Field(description="The version of this dataset. Must be in the format `x.x.x`.")
    artifact_fpath: str = Field(description="The filepath to the artifact to download.")
    output_fpath: str = Field(description="Where to save the downloaded dataset.")
class DeleteJsonlDatasetGitlabConfig(BaseNeMoGymCLIConfig):
    """
    Delete a dataset from GitLab Model Registry (prompts for confirmation).

    Examples:

    ```bash
    ng_delete_dataset_from_gitlab +dataset_name=old_dataset
    ```
    """

    # Single required argument of the delete command.
    dataset_name: str = Field(
        description="Name of the dataset to delete from GitLab.",
    )
class JsonlDatasetHuggingFaceIdentifer(BaseModel):
    # Identifies a dataset on HuggingFace by repo id and, optionally, a specific file inside that repo.

    repo_id: str = Field(description="The repo id.")

    # Optional; defaults to None.
    artifact_fpath: Optional[str] = Field(
        description="Path to specific file in HuggingFace repo (e.g., 'train.jsonl'). If omitted, load_dataset will be used with split.",
        default=None,
    )
class BaseUploadJsonlDatasetHuggingFaceConfig(BaseNeMoGymCLIConfig):
    r"""
    Upload a JSONL dataset to HuggingFace Hub with automatic naming based on domain and resource server.

    Examples:

    ```bash
    resource_config_path="resources_servers/example_multi_step/configs/example_multi_step.yaml"
    ng_upload_dataset_to_hf \
        +dataset_name=my_dataset \
        +input_jsonl_fpath=data/train.jsonl \
        +resource_config_path=${resource_config_path}
    ```
    """

    # The docstring above is a raw string on purpose: it is shown as the command's help text, and in a
    # regular string each trailing backslash + newline is a line continuation that would be dropped,
    # collapsing the bash example onto a single line.

    # HuggingFace account / collection settings (all required).
    hf_token: str = Field(description="HuggingFace API token for authentication.")
    hf_organization: str = Field(description="HuggingFace organization name where dataset will be uploaded.")
    hf_collection_name: str = Field(description="HuggingFace collection name for organizing datasets.")
    hf_collection_slug: str = Field(description="Alphanumeric collection slug found at the end of collection URI.")

    # What to upload and how to name it.
    dataset_name: Optional[str] = Field(
        default=None, description="Name of the dataset (will be combined with domain and resource server name)."
    )
    input_jsonl_fpath: str = Field(description="Path to the local jsonl file to upload.")
    resource_config_path: str = Field(
        description="Path to resource server config file (used to extract domain for naming convention)."
    )
    # The description previously advertised 'NeMo-Gym' as the default, contradicting the actual default value.
    hf_dataset_prefix: str = Field(
        default="Nemotron-RL", description="Prefix prepended to dataset name (default: 'Nemotron-RL')."
    )
    split: Literal["train", "validation", "test"] = Field(
        default="train",
        description="Dataset split type (e.g., 'train', 'validation', 'test'). Format validation only applies to 'train' splits.",
    )

    # Pull request / commit options.
    create_pr: bool = Field(
        default=False,
        description="Create a pull request instead of pushing directly. Required for repos where you do not have write access.",
    )
    revision: Optional[str] = Field(
        default=None,
        description="Git revision (branch name) to upload to. Use the same revision for multiple files to upload to the same PR. If not provided with create_pr=True, a new branch/PR will be created automatically.",
    )
    commit_message: Optional[str] = Field(
        default=None, description="Custom commit message. If not provided, HuggingFace auto-generates one."
    )
    commit_description: Optional[str] = Field(
        default=None, description="Optional commit description with additional context."
    )
class UploadJsonlDatasetHuggingFaceConfig(BaseUploadJsonlDatasetHuggingFaceConfig):
    r"""
    Upload a JSONL dataset to HuggingFace Hub and automatically delete from GitLab after successful upload.

    This command always deletes the dataset from GitLab after uploading to HuggingFace.
    Use `ng_upload_dataset_to_hf` if you want optional deletion control.

    Examples:

    ```bash
    resource_config_path="resources_servers/example_multi_step/configs/example_multi_step.yaml"
    ng_gitlab_to_hf_dataset \
        +dataset_name=my_dataset \
        +input_jsonl_fpath=data/train.jsonl \
        +resource_config_path=${resource_config_path}
    ```
    """

    # The docstring above is a raw string on purpose: it is shown as the command's help text, and in a
    # regular string each trailing backslash + newline is a line continuation that would be dropped,
    # collapsing the bash example onto a single line.

    # Input keys that are rejected outright by `check_forbidden_fields`.
    forbidden_fields: ClassVar[Set[str]] = {"delete_from_gitlab"}

    @model_validator(mode="before")
    @classmethod
    def check_forbidden_fields(cls, data):
        """Reject raw input that contains any key listed in ``forbidden_fields``.

        Args:
            data: Raw, not-yet-validated input for the model.

        Returns:
            ``data`` unchanged when no forbidden key is present.

        Raises:
            ValueError: If at least one forbidden key is present in ``data``.
        """
        # Anything exposing `keys()` (plain dicts included) is inspected; other inputs pass through as-is.
        if hasattr(data, "keys"):
            forbidden = cls.forbidden_fields.intersection(data.keys())
            if forbidden:
                raise ValueError(f"Forbidden fields present: {forbidden}")
        return data
class UploadJsonlDatasetHuggingFaceMaybeDeleteConfig(BaseUploadJsonlDatasetHuggingFaceConfig):
    r"""
    Upload a JSONL dataset to HuggingFace Hub with optional GitLab deletion after successful upload.

    Examples:

    ```bash
    resource_config_path="resources_servers/example_multi_step/configs/example_multi_step.yaml"
    ng_upload_dataset_to_hf \
        +dataset_name=my_dataset \
        +input_jsonl_fpath=data/train.jsonl \
        +resource_config_path=${resource_config_path} \
        +delete_from_gitlab=true
    ```
    """

    # The docstring above is a raw string on purpose: it is shown as the command's help text, and in a
    # regular string each trailing backslash + newline is a line continuation that would be dropped,
    # collapsing the bash example onto a single line.

    delete_from_gitlab: Optional[bool] = Field(
        default=False, description="Delete the dataset from GitLab after successful upload to HuggingFace."
    )
class DownloadJsonlDatasetHuggingFaceConfig(JsonlDatasetHuggingFaceIdentifer, BaseNeMoGymCLIConfig):
    r"""
    Download a JSONL dataset from HuggingFace Hub to local filesystem.

    Examples:

    ```bash
    ng_download_dataset_from_hf \
        +repo_id=NVIDIA/NeMo-Gym-Math-example_multi_step-v1 \
        +artifact_fpath=train.jsonl \
        +output_fpath=data/train.jsonl
    ```
    """

    # The docstring above is a raw string on purpose: it is shown as the command's help text, and in a
    # regular string each trailing backslash + newline is a line continuation that would be dropped,
    # collapsing the bash example onto a single line.

    output_dirpath: Optional[str] = Field(
        default=None,
        description="Directory to save the downloaded dataset. Files will be named {split}.jsonl. If split is omitted, all available splits are downloaded.",
    )
    output_fpath: Optional[str] = Field(
        default=None,
        description="Exact local file path where the downloaded dataset will be saved. Requires `artifact_fpath` or `split`. Overrides output_dirpath.",
    )
    hf_token: Optional[str] = Field(default=None, description="HuggingFace API token for authentication.")
    split: Optional[Literal["train", "validation", "test"]] = Field(
        default=None, description="Dataset split to download. Omit to download all available splits."
    )

    @model_validator(mode="after")
    def check_output_path(self) -> "DownloadJsonlDatasetHuggingFaceConfig":
        """Validate the combination of output and source selection options.

        Returns:
            The validated instance.

        Raises:
            ValueError: If neither or both of ``output_dirpath`` / ``output_fpath`` are set, if both
                ``artifact_fpath`` and ``split`` are set, or if ``output_fpath`` is given without either
                ``split`` or ``artifact_fpath``.
        """
        # Exactly one output destination must be chosen.
        if not self.output_dirpath and not self.output_fpath:
            raise ValueError("Either output_dirpath or output_fpath must be provided")
        if self.output_dirpath and self.output_fpath:
            raise ValueError("Cannot specify both output_dirpath and output_fpath")

        # A raw file and a structured split are mutually exclusive sources.
        if self.artifact_fpath and self.split:
            raise ValueError(
                "Cannot specify both artifact_fpath and split. Use artifact_fpath for targeting a raw file, or split for structured datasets."
            )

        # Prevent output_fpath without split when not using artifact_fpath
        if self.output_fpath and not self.split and not self.artifact_fpath:
            raise ValueError(
                "When using output_fpath without artifact_fpath, split must be specified. Use output_dirpath to download all splits."
            )

        return self
# Allowed values for `DatasetConfig.type`.
DatasetType = Union[Literal["train"], Literal["validation"], Literal["example"]]
class DatasetConfig(BaseModel):
    # Describes one dataset: its name, type, jsonl path, repeat count, optional remote identifiers and license.

    name: str
    type: DatasetType
    jsonl_fpath: str

    # Must be at least 1 (enforced by `ge=1`).
    num_repeats: int = Field(default=1, ge=1)
    gitlab_identifier: Optional[JsonlDatasetGitlabIdentifer] = None
    huggingface_identifier: Optional[JsonlDatasetHuggingFaceIdentifer] = None
    # Optional in general, but mandatory for "train" and "validation" datasets (see the validator below).
    license: Optional[
        Union[
            Literal["Apache 2.0"],
            Literal["MIT"],
            Literal["Creative Commons Attribution 4.0 International"],
            Literal["Creative Commons Attribution-ShareAlike 4.0 International"],
            Literal["NVIDIA Internal Use Only, Do Not Distribute"],
            Literal["TBD"],
        ]
    ] = None

    @model_validator(mode="after")
    def check_train_validation_sets(self) -> "DatasetConfig":
        """Require a license for datasets of type "train" or "validation".

        Returns:
            The validated instance.

        Raises:
            ValueError: If the dataset type is "train" or "validation" and no license is set.
        """
        # Raise explicitly instead of using `assert`: assert statements are stripped when Python runs with
        # -O, which would silently disable this check.
        if self.type in ("train", "validation") and self.license is None:
            raise ValueError(f"A license is required for {self.name}")
        return self
########################################
# Base server config classes
########################################
class Domain(str, Enum):
    """Closed set of domain labels; each member is also a plain ``str`` equal to its value."""

    MATH = "math"
    CODING = "coding"
    AGENT = "agent"
    KNOWLEDGE = "knowledge"
    INSTRUCTION_FOLLOWING = "instruction_following"
    LONG_CONTEXT = "long_context"
    SAFETY = "safety"
    GAMES = "games"
    TRANSLATION = "translation"
    E2E = "e2e"
    OTHER = "other"
class BaseServerConfig(BaseModel):
    # Minimal server settings: a host and a port, both required.

    host: str
    port: int
class BaseRunServerConfig(BaseServerConfig):
    # Server settings extended with an entrypoint and an optional domain.

    entrypoint: str
    # Only required for resource servers
    domain: Optional[Domain] = None
class BaseRunServerInstanceConfig(BaseRunServerConfig):
    # Run-server settings for one concrete server instance.

    # This name is unique at runtime.
    name: str
########################################
# Server type and server instance configs
########################################
class BaseRunServerTypeConfig(BaseRunServerConfig):
    # Run-server settings as written in a server type config: host and port become optional, extra keys
    # are accepted, and an optional list of datasets may be attached.

    model_config = ConfigDict(extra="allow")

    # Overrides of the inherited required fields; both default to None here.
    host: Optional[str] = None
    port: Optional[int] = None

    datasets: Optional[List[DatasetConfig]] = None
class BaseServerTypeConfig(BaseModel):
    # Common base of the server type configs. Subclasses set SERVER_TYPE to one of the three literals below.

    # Declared without a value here; it is a class-level constant, not a model field.
    SERVER_TYPE: ClassVar[
        Union[
            Literal["responses_api_models"],
            Literal["resources_servers"],
            Literal["responses_api_agents"],
        ]
    ]
class ResponsesAPIModelServerTypeConfig(BaseServerTypeConfig):
    # Server type config for "responses_api_models"; extra keys are accepted.

    SERVER_TYPE: ClassVar[Literal["responses_api_models"]] = "responses_api_models"

    model_config = ConfigDict(extra="allow")

    # Must contain exactly one entry (min_length=1, max_length=1).
    responses_api_models: Dict[str, BaseRunServerTypeConfig] = Field(min_length=1, max_length=1)
class ResourcesServerTypeConfig(BaseServerTypeConfig):
    # Server type config for "resources_servers"; extra keys are accepted.

    SERVER_TYPE: ClassVar[Literal["resources_servers"]] = "resources_servers"

    model_config = ConfigDict(extra="allow")

    # Must contain exactly one entry (min_length=1, max_length=1).
    resources_servers: Dict[str, BaseRunServerTypeConfig] = Field(min_length=1, max_length=1)
class ResponsesAPIAgentServerTypeConfig(BaseServerTypeConfig):
    # Server type config for "responses_api_agents"; extra keys are accepted.

    SERVER_TYPE: ClassVar[Literal["responses_api_agents"]] = "responses_api_agents"

    model_config = ConfigDict(extra="allow")

    # Must contain exactly one entry (min_length=1, max_length=1).
    responses_api_agents: Dict[str, BaseRunServerTypeConfig] = Field(min_length=1, max_length=1)
# Any one of the three server type configs.
ServerTypeConfig = Union[
    ResponsesAPIModelServerTypeConfig,
    ResourcesServerTypeConfig,
    ResponsesAPIAgentServerTypeConfig,
]
class BaseServerInstanceConfig(BaseServerTypeConfig):
    # A server type config bound to a concrete instance name, together with the DictConfig it was built from.
    # Concrete subclasses combine this class with one of the *ServerTypeConfig classes, which supply
    # SERVER_TYPE and the matching single-entry mapping field.

    # DictConfig is not a pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    # Kept on the instance but excluded from serialization.
    server_type_config_dict: DictConfig = Field(exclude=True)

    @model_validator(mode="after")
    def validate_domain_for_resource_server(self) -> "BaseServerInstanceConfig":
        """Require a domain on resource servers and remove it from all other server types.

        Returns:
            The validated instance.

        Raises:
            ValueError: If this is a resource server and its inner config has no domain.
        """
        config = self.get_inner_run_server_config()

        if self.SERVER_TYPE == "resources_servers":
            # Raise explicitly instead of using `assert`: assert statements are stripped when Python runs
            # with -O, which would silently disable this check.
            if config.domain is None:
                raise ValueError("A domain is required for resource servers.")
        else:
            # Remove domain field from Model and Agent servers.
            if hasattr(config, "domain"):
                del config.domain

        return self

    def get_server_ref(self) -> ServerRef:
        """Return a server reference built from this instance's ``SERVER_TYPE`` and ``name``."""
        return is_server_ref({"type": self.SERVER_TYPE, "name": self.name})

    def get_inner_run_server_config_dict(self) -> DictConfig:
        """Return the raw config stored under ``SERVER_TYPE`` -> first server type name."""
        # The mapping field named by SERVER_TYPE holds exactly one entry; its key selects the raw sub-config.
        server_type_name = list(getattr(self, self.SERVER_TYPE))[0]
        return self.server_type_config_dict[self.SERVER_TYPE][server_type_name]

    def get_inner_run_server_config(self) -> BaseRunServerTypeConfig:
        """Return the first validated run-server config of the mapping field named by ``SERVER_TYPE``."""
        return list(getattr(self, self.SERVER_TYPE).values())[0]

    @property
    def datasets(self) -> Optional[List[DatasetConfig]]:
        """Datasets declared on the inner run-server config, or ``None`` when there are none."""
        return self.get_inner_run_server_config().datasets
class ResponsesAPIModelServerInstanceConfig(ResponsesAPIModelServerTypeConfig, BaseServerInstanceConfig):
    # Instance config for "responses_api_models" servers; everything is inherited from the two bases.
    pass
class ResourcesServerInstanceConfig(ResourcesServerTypeConfig, BaseServerInstanceConfig):
    # Instance config for "resources_servers" servers; everything is inherited from the two bases.
    pass
class ResponsesAPIAgentServerInstanceConfig(ResponsesAPIAgentServerTypeConfig, BaseServerInstanceConfig):
    # Instance config for "responses_api_agents" servers; everything is inherited from the two bases.
    pass
# Any one of the three server instance configs.
ServerInstanceConfig = Union[
    ResponsesAPIModelServerInstanceConfig,
    ResourcesServerInstanceConfig,
    ResponsesAPIAgentServerInstanceConfig,
]
# Shared validator for ServerInstanceConfig payloads; used by `maybe_get_server_instance_config`.
ServerInstanceConfigTypeAdapter = TypeAdapter(ServerInstanceConfig)
def maybe_get_server_instance_config(
    name: str, server_type_config_dict: Any
) -> Tuple[Optional[ServerInstanceConfig], Optional[ValidationError]]:
    """Returns ServerInstanceConfig if a valid server, otherwise None with error details

    Args:
        name: Instance name placed under the ``name`` key of the candidate payload.
        server_type_config_dict: Candidate config; anything that is not a ``DictConfig`` is rejected
            without a validation attempt.

    Returns:
        ``(config, None)`` on success, ``(None, error)`` when validation fails, and ``(None, None)`` when
        the input is not a ``DictConfig``.
    """
    if not isinstance(server_type_config_dict, DictConfig):
        return None, None

    # Start from the instance-level keys, then layer the plain-container form of the config on top
    # (keys coming from the config win on collision).
    candidate = {
        "name": name,
        "server_type_config_dict": server_type_config_dict,
    }
    candidate.update(OmegaConf.to_container(server_type_config_dict))

    try:
        validated = ServerInstanceConfigTypeAdapter.validate_python(candidate)
    except ValidationError as validation_error:
        return None, validation_error
    return validated, None
def is_almost_server(server_type_config_dict: Any) -> bool:
    """Detects if a config looks like a server but might fail validation.

    Args:
        server_type_config_dict: Candidate config; only ``DictConfig`` inputs can qualify.

    Returns:
        ``True`` when the config has one of the server type keys whose value is a ``DictConfig`` containing
        at least one ``DictConfig`` entry with the entrypoint key; ``False`` otherwise.
    """
    # Imported inside the function, as in the original code.
    from nemo_gym.global_config import ENTRYPOINT_KEY_NAME

    if not isinstance(server_type_config_dict, DictConfig):
        return False

    for server_type_key in ("responses_api_models", "resources_servers", "responses_api_agents"):
        # Skip server types that are not present at all.
        if server_type_key not in server_type_config_dict:
            continue

        inner_dict = server_type_config_dict[server_type_key]
        if not isinstance(inner_dict, DictConfig):
            continue

        # An entrypoint on any inner config is enough to call this "almost a server".
        if any(
            isinstance(inner_config, DictConfig) and ENTRYPOINT_KEY_NAME in inner_config
            for inner_config in inner_dict.values()
        ):
            return True

    return False
########################################
# Train dataset configs
########################################

# Config key name for an agent reference. The assignment sits on its own line so that it is executable
# code rather than part of the banner comment.
AGENT_REF_KEY = "agent_ref"