nemo_gym.train_data_utils#

Module Contents#

Classes#

TrainDataProcessorConfig

Prepare and validate training data, generating metrics and statistics for datasets.

Accumulator

AvgMinMax

StringMetrics

DatasetMetrics

DatasetValidatorState

TrainDataProcessor

Functions#

aggregate_other_metrics

Combines miscellaneous items (those other than the response and response create params) into the current metrics.

postprocess_other_metrics

Aggregates metrics and merges the current metrics (which contain only AvgMinMax values) with StringMetrics.

compute_sample_metrics

validate_backend_credentials

Check whether the required environment variables are present for the chosen backend.

prepare_data

API#

class nemo_gym.train_data_utils.TrainDataProcessorConfig(/, **data: typing.Any)[source]#

Bases: nemo_gym.config_types.BaseNeMoGymCLIConfig

Prepare and validate training data, generating metrics and statistics for datasets.

Examples:

config_paths="resources_servers/example_multi_step/configs/example_multi_step.yaml,\
responses_api_models/openai_model/configs/openai_model.yaml"
ng_prepare_data "+config_paths=[${config_paths}]" \
    +output_dirpath=data/example_multi_step \
    +mode=example_validation

Initialization

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

output_dirpath: str#

‘Field(…)’

mode: Union[Literal[train_preparation], Literal[example_validation]]#

‘Field(…)’

should_download: bool#

‘Field(…)’

data_source: Literal[gitlab, huggingface]#

‘Field(…)’

property in_scope_dataset_types: List[nemo_gym.config_types.DatasetType]#
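
The config can also be built programmatically. A minimal sketch; the values below are illustrative, and any additional required fields inherited from BaseNeMoGymCLIConfig are omitted:

from nemo_gym.train_data_utils import TrainDataProcessorConfig

# Illustrative values only; mode and data_source must match the Literal types above.
config = TrainDataProcessorConfig(
    output_dirpath="data/example_multi_step",
    mode="example_validation",
    should_download=False,
    data_source="huggingface",
)
print(config.in_scope_dataset_types)
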
class nemo_gym.train_data_utils.Accumulator(/, **data: typing.Any)[source]#

Bases: pydantic.BaseModel

is_aggregated: bool#

‘Field(…)’

add(other: Self) None[source]#
abstractmethod _add(other: Self) None[source]#
aggregate() Self[source]#
abstractmethod _aggregate() Self[source]#
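
Accumulator follows a template-method pattern: add() and aggregate() are the public entry points, and subclasses implement _add() and _aggregate(). A minimal sketch with a hypothetical subclass (the delegation and the handling of is_aggregated are assumptions):

from typing_extensions import Self

from nemo_gym.train_data_utils import Accumulator

class CountAccumulator(Accumulator):  # hypothetical subclass for illustration
    count: int = 0

    def _add(self, other: Self) -> None:
        # merge another accumulator's state into this one
        self.count += other.count

    def _aggregate(self) -> Self:
        # nothing to finalize for a plain counter
        return self

a = CountAccumulator(count=2, is_aggregated=False)  # is_aggregated passed explicitly (assumption)
b = CountAccumulator(count=3, is_aggregated=False)
a.add(b)
print(a.aggregate().count)  # 5, assuming add()/aggregate() delegate to _add()/_aggregate()
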
class nemo_gym.train_data_utils.AvgMinMax(/, **data: typing.Any)[source]#

Bases: nemo_gym.train_data_utils.Accumulator

model_config#

‘ConfigDict(…)’

total: int#

‘Field(…)’

average: float#

‘Field(…)’

min: float#

‘Field(…)’

max: float#

‘Field(…)’

stddev: float#

‘Field(…)’

mean: float#

‘Field(…)’

M2: float#

‘Field(…)’

observe(x: float) None[source]#
_add(other: Self) None[source]#
_aggregate() Self[source]#
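
The total/mean/M2 fields suggest a Welford-style running computation of the mean and standard deviation, with _add() merging two partial accumulators. A minimal usage sketch under that assumption (field defaults are assumed to initialize an empty accumulator):

from nemo_gym.train_data_utils import AvgMinMax

acc = AvgMinMax()             # assumes defaults give an empty accumulator
for value in [3.0, 5.0, 7.0]:
    acc.observe(value)        # update the running min/max/mean/M2

other = AvgMinMax()
other.observe(11.0)
acc.add(other)                # merge partial statistics from another accumulator

summary = acc.aggregate()     # finalize average/stddev from the running state
print(summary.average, summary.min, summary.max, summary.stddev)
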
class nemo_gym.train_data_utils.StringMetrics(/, **data: typing.Any)[source]#

Bases: pydantic.BaseModel

unique_count: int#

None

total_count: int#

None
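
StringMetrics records cardinality for string-valued fields. A minimal sketch, assuming it is populated from the observed values of a single field:

from nemo_gym.train_data_utils import StringMetrics

values = ["alpha", "beta", "alpha"]
metrics = StringMetrics(unique_count=len(set(values)), total_count=len(values))
print(metrics.unique_count, metrics.total_count)  # 2 3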

class nemo_gym.train_data_utils.DatasetMetrics(/, **data: typing.Any)[source]#

Bases: nemo_gym.train_data_utils.Accumulator

model_config#

‘ConfigDict(…)’

number_of_examples: int#

‘Field(…)’

number_of_tools: nemo_gym.train_data_utils.AvgMinMax#

‘Field(…)’

json_dumped_number_of_words: nemo_gym.train_data_utils.AvgMinMax#

‘Field(…)’

number_of_turns: nemo_gym.train_data_utils.AvgMinMax#

‘Field(…)’

temperature: nemo_gym.train_data_utils.AvgMinMax#

‘Field(…)’

_add(other: Self) None[source]#
_aggregate() Self[source]#
nemo_gym.train_data_utils.aggregate_other_metrics(
metrics: Dict[str, Any],
sample: Dict[str, Any],
) None[source]#

Combines miscellaneous items (those other than the response and response create params) into the current metrics.

nemo_gym.train_data_utils.postprocess_other_metrics(
metrics: nemo_gym.train_data_utils.DatasetMetrics,
other_metrics: Dict[str, Any],
) None[source]#

Aggregates metrics and merges the current metrics (which contain only AvgMinMax values) with StringMetrics.

nemo_gym.train_data_utils.compute_sample_metrics(
sample_dict_str: str,
) Tuple[nemo_gym.train_data_utils.DatasetMetrics, bool][source]#
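
compute_sample_metrics parses one JSON-serialized sample and returns its DatasetMetrics together with a boolean (assumed here to flag sample validity). Because DatasetMetrics is an Accumulator, per-sample results can be merged into dataset-level statistics. A minimal sketch, assuming each line of a JSONL dataset is one sample; the path is illustrative:

from nemo_gym.train_data_utils import DatasetMetrics, compute_sample_metrics

dataset_metrics = DatasetMetrics()                       # assumes empty-accumulator defaults
with open("data/example_multi_step/train.jsonl") as f:  # hypothetical path
    for line in f:
        sample_metrics, is_valid = compute_sample_metrics(line)
        if is_valid:
            dataset_metrics.add(sample_metrics)          # merge this sample's statistics

final = dataset_metrics.aggregate()                      # finalize averages and stddevs
print(final.number_of_examples, final.number_of_turns.average)
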
class nemo_gym.train_data_utils.DatasetValidatorState(/, **data: typing.Any)[source]#

Bases: pydantic.BaseModel

model_config#

‘ConfigDict(…)’

metrics: nemo_gym.train_data_utils.DatasetMetrics#

‘Field(…)’

key_counts: collections.Counter#

‘Field(…)’

offending_example_idxs: List[int]#

‘Field(…)’

other_metrics: Dict[str, Any]#

‘Field(…)’
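
DatasetValidatorState bundles what the validator tracks while scanning a single dataset: the accumulated metrics, a counter of keys seen across samples, the indices of offending samples, and any other metrics. A minimal construction sketch (whether every field must be passed explicitly or has a default is an assumption):

import collections

from nemo_gym.train_data_utils import DatasetMetrics, DatasetValidatorState

state = DatasetValidatorState(
    metrics=DatasetMetrics(),                # assumes empty-accumulator defaults
    key_counts=collections.Counter(),
    offending_example_idxs=[],
    other_metrics={},
)
state.key_counts.update(["example_key"])     # e.g. tally a key seen in a sample (hypothetical key)
state.offending_example_idxs.append(7)       # e.g. record sample index 7 as invalid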

class nemo_gym.train_data_utils.TrainDataProcessor(/, **data: typing.Any)[source]#

Bases: pydantic.BaseModel

run(global_config_dict: omegaconf.DictConfig)[source]#

See the README section “How To: Prepare and validate data for PR submission or RL training” for the end-to-end workflow; a minimal programmatic sketch follows the method list below.

_print_title(title: str) None[source]#
load_and_validate_server_instance_configs(
config: nemo_gym.train_data_utils.TrainDataProcessorConfig,
global_config_dict: omegaconf.DictConfig,
) List[nemo_gym.config_types.ServerInstanceConfig][source]#
load_datasets(
config: nemo_gym.train_data_utils.TrainDataProcessorConfig,
server_instance_configs: List[nemo_gym.config_types.ServerInstanceConfig],
) None[source]#
_validate_samples_and_aggregate_metrics_single_sample(
state: nemo_gym.train_data_utils.DatasetValidatorState,
sample_idx: int,
sample_dict_str: str,
) None[source]#
_iter_dataset_lines(
dataset_config: nemo_gym.config_types.DatasetConfig,
)[source]#
_validate_samples_and_aggregate_metrics_single_dataset(
dataset_config: nemo_gym.config_types.DatasetConfig,
) nemo_gym.train_data_utils.DatasetValidatorState[source]#
_validate_aggregate_metrics(
aggregate_metrics_dict: Dict,
metrics_fpath: pathlib.Path,
) Optional[pathlib.Path][source]#

Returns the path of the conflicting metrics file if the aggregate metrics are invalid; otherwise returns None.

validate_samples_and_aggregate_metrics(
server_instance_configs: List[nemo_gym.config_types.ServerInstanceConfig],
) Dict[str, nemo_gym.train_data_utils.DatasetMetrics][source]#
_collate_samples_single_type(
type: nemo_gym.config_types.DatasetType,
server_instance_configs: List[nemo_gym.config_types.ServerInstanceConfig],
) List[pathlib.Path][source]#
collate_samples(
config: nemo_gym.train_data_utils.TrainDataProcessorConfig,
server_instance_configs: List[nemo_gym.config_types.ServerInstanceConfig],
dataset_type_to_aggregate_metrics: Dict[str, nemo_gym.train_data_utils.DatasetMetrics],
) None[source]#
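
A minimal programmatic sketch of driving the processor, assuming run() accepts the Hydra/OmegaConf configuration that the ng_prepare_data command assembles on the command line. The keys below mirror the CLI example above and are illustrative, not an exhaustive configuration:

from omegaconf import OmegaConf

from nemo_gym.train_data_utils import TrainDataProcessor

global_config = OmegaConf.create(
    {
        "config_paths": [
            "resources_servers/example_multi_step/configs/example_multi_step.yaml",
            "responses_api_models/openai_model/configs/openai_model.yaml",
        ],
        "output_dirpath": "data/example_multi_step",
        "mode": "example_validation",
    }
)
TrainDataProcessor().run(global_config)
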
nemo_gym.train_data_utils.validate_backend_credentials(backend: str) tuple[bool, str][source]#

Check whether the required environment variables are present for the chosen backend.
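
A minimal usage sketch, assuming the backend name matches the data_source literals above (“gitlab” or “huggingface”) and that the returned tuple is (is_valid, message):

from nemo_gym.train_data_utils import validate_backend_credentials

ok, message = validate_backend_credentials("huggingface")
if not ok:
    # message is assumed to describe which environment variables are missing
    raise RuntimeError(f"Missing credentials for backend: {message}")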

nemo_gym.train_data_utils.prepare_data()[source]#
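
A minimal sketch, assuming prepare_data is the console entry point behind the ng_prepare_data command shown in the TrainDataProcessorConfig example and that it reads its configuration from the command line:

from nemo_gym.train_data_utils import prepare_data

if __name__ == "__main__":
    prepare_data()  # assumed equivalent to invoking ng_prepare_data from the shell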