nemo_gym.train_data_utils#

Module Contents#

Classes#

TrainDataProcessorConfig

Accumulator

AvgMinMax

StringMetrics

DatasetMetrics

DatasetValidatorState

TrainDataProcessor

Functions#

aggregate_other_metrics

Combines miscellaneous items (those other than the response/response create params) into the current metrics.

postprocess_other_metrics

Aggregates the metrics and merges the current metrics (containing only AvgMinMax) with StringMetrics.

compute_sample_metrics

prepare_data

API#

class nemo_gym.train_data_utils.TrainDataProcessorConfig#

Bases: nemo_gym.config_types.BaseNeMoGymCLIConfig

output_dirpath: str#

‘Field(…)’

mode: Union[Literal['train_preparation'], Literal['example_validation']]#

‘Field(…)’

should_download: bool#

‘Field(…)’

property in_scope_dataset_types: List[nemo_gym.config_types.DatasetType]#
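As a hedged illustration of how this config might be populated programmatically, the sketch below fills in the three documented fields directly. Whether any of them carry defaults, and how the CLI layer actually instantiates the class, are assumptions here rather than documented behavior.

```python
# Illustrative sketch only: field names, types, and the mode literals come from the
# signatures above; the values and the direct-construction pattern are assumptions.
from nemo_gym.train_data_utils import TrainDataProcessorConfig

config = TrainDataProcessorConfig(
    output_dirpath="/tmp/nemo_gym_train_data",  # hypothetical output location
    mode="train_preparation",                   # or "example_validation"
    should_download=False,
)

# Documented property; its exact contents per mode are not shown on this page.
print(config.in_scope_dataset_types)
```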
class nemo_gym.train_data_utils.Accumulator#

Bases: pydantic.BaseModel

is_aggregated: bool#

‘Field(…)’

add(other: Self) None#
abstractmethod _add(other: Self) None#
aggregate() Self#
abstractmethod _aggregate() Self#
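The signatures read like a small template-method pattern: the public add()/aggregate() presumably wrap the abstract _add()/_aggregate() hooks and maintain is_aggregated. Below is a minimal, self-contained sketch of that pattern under those assumptions, with a hypothetical Count subclass for illustration; it is not the library's actual implementation.

```python
# Self-contained sketch of the accumulator pattern implied by the signatures above.
# `SketchAccumulator` and `Count` are illustrative stand-ins, not nemo_gym classes.
from typing import Self  # Python 3.11+

from pydantic import BaseModel, Field


class SketchAccumulator(BaseModel):
    is_aggregated: bool = Field(default=False)

    def add(self, other: Self) -> None:
        # Public entry point delegates to the subclass hook.
        self._add(other)

    def aggregate(self) -> Self:
        result = self._aggregate()
        result.is_aggregated = True
        return result

    def _add(self, other: Self) -> None:
        raise NotImplementedError

    def _aggregate(self) -> Self:
        raise NotImplementedError


class Count(SketchAccumulator):
    n: int = 0

    def _add(self, other: Self) -> None:
        self.n += other.n

    def _aggregate(self) -> Self:
        return self.model_copy()
```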
class nemo_gym.train_data_utils.AvgMinMax#

Bases: nemo_gym.train_data_utils.Accumulator

model_config#

‘ConfigDict(…)’

total: int#

‘Field(…)’

average: float#

‘Field(…)’

min: float#

‘Field(…)’

max: float#

‘Field(…)’

median: float#

‘Field(…)’

stddev: float#

‘Field(…)’

mean: float#

‘Field(…)’

M2: float#

‘Field(…)’

tdigest: tdigest.TDigest#

‘Field(…)’

T-Digest is used to estimate the median without storing and sorting all observed values. The median is approximated as the digest's 50th percentile, which is typically very close to the true median.

observe(x: float) None#
_add(other: Self) None#
_aggregate() Self#
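The mean and M2 fields suggest a Welford-style running update, and the tdigest field backs the approximate median described above. The following is a self-contained sketch of those two techniques, not the class's actual observe() implementation.

```python
# Sketch: running mean/variance via Welford's algorithm plus a T-Digest median estimate.
# This mirrors what the mean, M2, and tdigest fields above appear to track.
from tdigest import TDigest

total = 0
mean = 0.0
M2 = 0.0
digest = TDigest()

for x in [3.0, 1.0, 4.0, 1.0, 5.0, 9.0]:
    total += 1
    delta = x - mean
    mean += delta / total
    M2 += delta * (x - mean)           # running sum of squared deviations
    digest.update(x)                   # feed the digest for percentile estimates

variance = M2 / total if total else 0.0
approx_median = digest.percentile(50)  # approximate 50th percentile (close to the true median)
print(total, mean, variance ** 0.5, approx_median)
```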
class nemo_gym.train_data_utils.StringMetrics#

Bases: pydantic.BaseModel

unique_count: int#

None

total_count: int#

None
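A brief sketch of the two counts StringMetrics exposes, computed here over an arbitrary list of strings; the field semantics are inferred from the names, so treat this as an assumption.

```python
# Hypothetical illustration of unique_count vs. total_count for a string-valued field.
values = ["a", "a", "b", "c"]
total_count = len(values)          # 4
unique_count = len(set(values))    # 3
```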

class nemo_gym.train_data_utils.DatasetMetrics#

Bases: nemo_gym.train_data_utils.Accumulator

model_config#

‘ConfigDict(…)’

number_of_examples: int#

‘Field(…)’

number_of_tools: nemo_gym.train_data_utils.AvgMinMax#

‘Field(…)’

json_dumped_number_of_words: nemo_gym.train_data_utils.AvgMinMax#

‘Field(…)’

number_of_turns: nemo_gym.train_data_utils.AvgMinMax#

‘Field(…)’

temperature: nemo_gym.train_data_utils.AvgMinMax#

‘Field(…)’

_add(other: Self) None#
_aggregate() Self#
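Because DatasetMetrics is itself an Accumulator, per-shard metrics can presumably be merged with add() and finalized with aggregate(). The sketch below shows that call pattern using only the documented method names; how each instance gets populated and whether the fields have workable defaults are assumptions.

```python
# Hedged usage sketch: merging per-shard DatasetMetrics via the documented
# Accumulator interface. Assumes the instances are already populated.
from nemo_gym.train_data_utils import DatasetMetrics


def merge_shards(shard_metrics: list[DatasetMetrics]) -> DatasetMetrics:
    merged, *rest = shard_metrics
    for other in rest:
        merged.add(other)       # accumulate counts and AvgMinMax fields
    return merged.aggregate()   # finalize and mark as aggregated
```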
nemo_gym.train_data_utils.aggregate_other_metrics(
metrics: Dict[str, Any],
sample: Dict[str, Any],
) None#

Combines miscellaneous items (those other than the response/response create params) into the current metrics.

nemo_gym.train_data_utils.postprocess_other_metrics(
metrics: nemo_gym.train_data_utils.DatasetMetrics,
other_metrics: Dict[str, Any],
) None#

Aggregates the metrics and merges the current metrics (containing only AvgMinMax) with StringMetrics.

nemo_gym.train_data_utils.compute_sample_metrics(
sample_dict_str: str,
) Tuple[nemo_gym.train_data_utils.DatasetMetrics, bool]#
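A hedged sketch of how these three module-level functions might compose over a stream of JSON-lines samples, using only the signatures shown above. The meaning of the returned bool, the contents of the other_metrics dict, the input file, and the ordering of aggregation versus post-processing are all assumptions.

```python
# Illustrative pipeline over JSON-lines samples; not the library's actual control flow.
import json

from nemo_gym.train_data_utils import (
    aggregate_other_metrics,
    compute_sample_metrics,
    postprocess_other_metrics,
)

other_metrics: dict = {}
per_sample = []

with open("samples.jsonl") as f:               # hypothetical input file
    for line in f:
        metrics, is_valid = compute_sample_metrics(line)
        per_sample.append((metrics, is_valid))
        aggregate_other_metrics(other_metrics, json.loads(line))  # fold misc fields in place

# Merge the per-sample AvgMinMax-only metrics, then merge in the string-valued metrics.
final_metrics, _ = per_sample[0]
for metrics, _ in per_sample[1:]:
    final_metrics.add(metrics)
final_metrics = final_metrics.aggregate()
postprocess_other_metrics(final_metrics, other_metrics)
```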
class nemo_gym.train_data_utils.DatasetValidatorState#

Bases: pydantic.BaseModel

model_config#

‘ConfigDict(…)’

metrics: nemo_gym.train_data_utils.DatasetMetrics#

‘Field(…)’

key_counts: collections.Counter#

‘Field(…)’

offending_example_idxs: List[int]#

‘Field(…)’

other_metrics: Dict[str, Any]#

‘Field(…)’
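A hedged sketch of what this validator state appears to carry while a dataset is scanned. The construction arguments, default constructibility of DatasetMetrics, and the update logic are all assumptions made for illustration.

```python
# Hypothetical per-dataset validation loop; field names come from the signatures above.
from collections import Counter

from nemo_gym.train_data_utils import (
    DatasetMetrics,
    DatasetValidatorState,
    compute_sample_metrics,
)

state = DatasetValidatorState(
    metrics=DatasetMetrics(),   # assumes default construction is possible
    key_counts=Counter(),
    offending_example_idxs=[],
    other_metrics={},
)

with open("dataset.jsonl") as f:                  # hypothetical file
    for idx, line in enumerate(f):
        sample_metrics, is_valid = compute_sample_metrics(line)
        if not is_valid:
            state.offending_example_idxs.append(idx)
        state.metrics.add(sample_metrics)
```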

class nemo_gym.train_data_utils.TrainDataProcessor#

Bases: pydantic.BaseModel

run(global_config_dict: omegaconf.DictConfig)#

See the README section “How To: Prepare and validate data for PR submission or RL training”.

_print_title(title: str) None#
load_and_validate_server_instance_configs(
config: nemo_gym.train_data_utils.TrainDataProcessorConfig,
global_config_dict: omegaconf.DictConfig,
) List[nemo_gym.config_types.ServerInstanceConfig]#
load_datasets(
config: nemo_gym.train_data_utils.TrainDataProcessorConfig,
server_instance_configs: List[nemo_gym.config_types.ServerInstanceConfig],
) None#
_validate_samples_and_aggregate_metrics_single_sample(
state: nemo_gym.train_data_utils.DatasetValidatorState,
sample_idx: int,
sample_dict_str: str,
) None#
_iter_dataset_lines(
dataset_config: nemo_gym.config_types.DatasetConfig,
)#
_validate_samples_and_aggregate_metrics_single_dataset(
dataset_config: nemo_gym.config_types.DatasetConfig,
) nemo_gym.train_data_utils.DatasetValidatorState#
_validate_aggregate_metrics(
aggregate_metrics_dict: Dict,
metrics_fpath: pathlib.Path,
) Optional[pathlib.Path]#

Returns the path of the conflicting metrics file if the aggregate metrics are invalid; otherwise returns None.

validate_samples_and_aggregate_metrics(
server_instance_configs: List[nemo_gym.config_types.ServerInstanceConfig],
) Dict[str, nemo_gym.train_data_utils.DatasetMetrics]#
_collate_samples_single_type(
type: nemo_gym.config_types.DatasetType,
server_instance_configs: List[nemo_gym.config_types.ServerInstanceConfig],
) List[pathlib.Path]#
collate_samples(
config: nemo_gym.train_data_utils.TrainDataProcessorConfig,
server_instance_configs: List[nemo_gym.config_types.ServerInstanceConfig],
dataset_type_to_aggregate_metrics: Dict[str, nemo_gym.train_data_utils.DatasetMetrics],
) None#
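A hedged sketch of driving the processor programmatically with an OmegaConf config. The config keys shown are placeholders rather than the schema nemo_gym actually expects, and it is assumed that TrainDataProcessor can be constructed without arguments; in practice the README workflow referenced above is the documented path.

```python
# Illustrative only: the DictConfig contents here are placeholder keys.
from omegaconf import OmegaConf

from nemo_gym.train_data_utils import TrainDataProcessor

global_config_dict = OmegaConf.create(
    {
        "train_data_processor": {       # hypothetical key
            "output_dirpath": "/tmp/out",
            "mode": "train_preparation",
            "should_download": False,
        }
    }
)

TrainDataProcessor().run(global_config_dict)
```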
nemo_gym.train_data_utils.prepare_data()#