nemo_gym.train_data_utils

View as Markdown

Module Contents

Classes

| Name | Description |
| --- | --- |
| Accumulator | - |
| AvgMinMax | - |
| DatasetMetrics | - |
| DatasetValidatorState | - |
| StringMetrics | - |
| TrainDataProcessor | - |
| TrainDataProcessorConfig | Prepare and validate training data, generating metrics and statistics for datasets. |

Functions

| Name | Description |
| --- | --- |
| aggregate_other_metrics | Combines miscellaneous items (those other than response/response create params) into the current metrics. |
| compute_sample_metrics | - |
| postprocess_other_metrics | Aggregates metrics and merges the current metrics (containing only AvgMinMax) with StringMetrics. |
| prepare_data | - |
| validate_backend_credentials | Check whether the required environment variables are present for the chosen backend. |

API

class nemo_gym.train_data_utils.Accumulator()

Bases: BaseModel

is_aggregated
bool = Field(default=False, exclude=True)
nemo_gym.train_data_utils.Accumulator._add(
other: typing.Self
) -> None
abstract
nemo_gym.train_data_utils.Accumulator._aggregate() -> typing.Self
abstract
nemo_gym.train_data_utils.Accumulator.add(
other: typing.Self
) -> None
nemo_gym.train_data_utils.Accumulator.aggregate() -> typing.Self
class nemo_gym.train_data_utils.AvgMinMax()

Bases: Accumulator

M2
float = Field(default=0, exclude=True)
average
float = Field(serialization_alias='Average', default=0)
max
float
mean
float = Field(default=0, exclude=True)
min
float
model_config
= ConfigDict(arbitrary_types_allowed=True)
stddev
float
total
int
nemo_gym.train_data_utils.AvgMinMax._add(
other: typing.Self
) -> None
nemo_gym.train_data_utils.AvgMinMax._aggregate() -> typing.Self
nemo_gym.train_data_utils.AvgMinMax.observe(
x: float
) -> None
class nemo_gym.train_data_utils.DatasetMetrics()

Bases: Accumulator

json_dumped_number_of_words
AvgMinMax
model_config
= ConfigDict(extra='allow')
number_of_examples
int
number_of_tools
AvgMinMax
number_of_turns
AvgMinMax
temperature
AvgMinMax
nemo_gym.train_data_utils.DatasetMetrics._add(
other: typing.Self
) -> None
nemo_gym.train_data_utils.DatasetMetrics._aggregate() -> typing.Self
class nemo_gym.train_data_utils.DatasetValidatorState()

Bases: BaseModel

key_counts
Counter = Field(default_factory=Counter)
metrics
DatasetMetrics = Field(default_factory=DatasetMetrics)
model_config
= ConfigDict(arbitrary_types_allowed=True)
offending_example_idxs
List[int] = Field(default_factory=list)
other_metrics
Dict[str, Any] = Field(default_factory=dict)
class nemo_gym.train_data_utils.StringMetrics()

Bases: BaseModel

total_count
int
unique_count
int
class nemo_gym.train_data_utils.TrainDataProcessor()

Bases: BaseModel

nemo_gym.train_data_utils.TrainDataProcessor._collate_samples_single_type(
type: nemo_gym.config_types.DatasetType,
server_instance_configs: typing.List[nemo_gym.config_types.ServerInstanceConfig]
) -> typing.List[pathlib.Path]
nemo_gym.train_data_utils.TrainDataProcessor._iter_dataset_lines(
dataset_config: nemo_gym.config_types.DatasetConfig
)
nemo_gym.train_data_utils.TrainDataProcessor._print_title(
title: str
) -> None
nemo_gym.train_data_utils.TrainDataProcessor._validate_aggregate_metrics(
aggregate_metrics_dict: typing.Dict,
metrics_fpath: pathlib.Path
) -> typing.Optional[pathlib.Path]

Returns the conflicting metrics file path if the metrics are invalid; otherwise returns None.

nemo_gym.train_data_utils.TrainDataProcessor._validate_samples_and_aggregate_metrics_single_dataset(
dataset_config: nemo_gym.config_types.DatasetConfig
) -> nemo_gym.train_data_utils.DatasetValidatorState
nemo_gym.train_data_utils.TrainDataProcessor._validate_samples_and_aggregate_metrics_single_sample(
state: nemo_gym.train_data_utils.DatasetValidatorState,
sample_idx: int,
sample_dict_str: str
) -> None
nemo_gym.train_data_utils.TrainDataProcessor.collate_samples(
config: nemo_gym.train_data_utils.TrainDataProcessorConfig,
server_instance_configs: typing.List[nemo_gym.config_types.ServerInstanceConfig],
dataset_type_to_aggregate_metrics: typing.Dict[str, nemo_gym.train_data_utils.DatasetMetrics]
) -> None
nemo_gym.train_data_utils.TrainDataProcessor.load_and_validate_server_instance_configs(
config: nemo_gym.train_data_utils.TrainDataProcessorConfig,
global_config_dict: omegaconf.DictConfig
) -> typing.List[nemo_gym.config_types.ServerInstanceConfig]
nemo_gym.train_data_utils.TrainDataProcessor.load_datasets(
config: nemo_gym.train_data_utils.TrainDataProcessorConfig,
server_instance_configs: typing.List[nemo_gym.config_types.ServerInstanceConfig]
) -> None
nemo_gym.train_data_utils.TrainDataProcessor.run(
global_config_dict: omegaconf.DictConfig
)

See the README section “How To: Prepare and validate data for PR submission or RL training”.

nemo_gym.train_data_utils.TrainDataProcessor.validate_samples_and_aggregate_metrics(
server_instance_configs: typing.List[nemo_gym.config_types.ServerInstanceConfig],
overwrite_metrics_conflicts: bool
) -> typing.Dict[str, nemo_gym.train_data_utils.DatasetMetrics]
class nemo_gym.train_data_utils.TrainDataProcessorConfig()

Bases: BaseNeMoGymCLIConfig

Prepare and validate training data, generating metrics and statistics for datasets.

Examples:

```bash
config_paths="resources_servers/example_multi_step/configs/example_multi_step.yaml,\
responses_api_models/openai_model/configs/openai_model.yaml"
ng_prepare_data "+config_paths=[${config_paths}]" +output_dirpath=data/example_multi_step +mode=example_validation
```
data_source
Literal['gitlab', 'huggingface']
in_scope_dataset_types
List[DatasetType]
mode
Union[Literal['train_preparation'], Literal['example_validation']]
output_dirpath
str
overwrite_metrics_conflicts
bool
should_download
bool
nemo_gym.train_data_utils.aggregate_other_metrics(
metrics: typing.Dict[str, typing.Any],
sample: typing.Dict[str, typing.Any]
) -> None

Combines miscellaneous items (those other than response/response create params) into the current metrics.

nemo_gym.train_data_utils.compute_sample_metrics(
sample_dict_str: str
) -> typing.Tuple[nemo_gym.train_data_utils.DatasetMetrics, bool]
nemo_gym.train_data_utils.postprocess_other_metrics(
metrics: nemo_gym.train_data_utils.DatasetMetrics,
other_metrics: typing.Dict[str, typing.Any]
) -> None

Aggregates metrics and merges the current metrics (containing only AvgMinMax) with StringMetrics.

nemo_gym.train_data_utils.prepare_data()
nemo_gym.train_data_utils.validate_backend_credentials(
backend: str
) -> tuple[bool, str]

Check whether the required environment variables are present for the chosen backend.