nemo_automodel.components.datasets.llm.megatron_dataset


Module Contents

Classes

| Name | Description |
| --- | --- |
| MegatronPretraining | Build Megatron pretraining datasets and dataloaders. |

Functions

| Name | Description |
| --- | --- |
| get_list_of_files | Get the list of unique dataset prefixes (full paths without extension) from a glob pattern. |
| is_number_tryexcept | Returns True if string is a number. |
| is_zipped_list | Check if the paths are zipped. |
| try_load_blend_from_json | Load a data blend configuration from a JSON file. |
| validate_dataset_asset_accessibility | Validate the accessibility of the dataset assets. |

Data

logger

API

class nemo_automodel.components.datasets.llm.megatron_dataset.MegatronPretraining(
paths: pathlib.Path | typing.List | typing.Dict[str, typing.List],
seq_length: int = 2048,
tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None,
micro_batch_size: int = 4,
global_batch_size: int = 8,
create_attention_mask: bool = False,
seed: int = 1234,
split: str = '900,50,50',
index_mapping_dir: typing.Optional[str] = None,
num_dataset_builder_threads: int = 1,
num_train_samples: typing.Optional[int] = None,
num_val_samples: typing.Optional[int] = None,
num_test_samples: typing.Optional[int] = None,
trainer_max_steps: typing.Optional[int] = None,
trainer_val_check_interval: int = 1000,
trainer_limit_val_batches: typing.Union[int, float] = 1,
trainer_limit_test_batches: typing.Union[int, float] = 1,
mmap_bin_files: bool = True,
splits_to_build: typing.Optional[typing.Union[str, typing.List[str]]] = None,
object_storage_config: typing.Optional[typing.Union[typing.Dict, nemo_automodel.components.datasets.llm.megatron.indexed_dataset.ObjectStorageConfig]] = None
)

Build Megatron pretraining datasets and dataloaders.
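The split argument appears to follow the Megatron-LM convention of comma-separated relative weights for the train, validation, and test splits. As a minimal sketch, assuming the weights are normalized by their sum and missing trailing entries count as zero (parse_split is a hypothetical helper, not part of this module), the default "900,50,50" corresponds to a 90% / 5% / 5% allocation:

```python
from typing import List


def parse_split(split: str) -> List[float]:
    """Hypothetical helper: turn "900,50,50" into normalized train/val/test fractions."""
    weights = [float(part) for part in split.split(",")]
    if len(weights) > 3 or any(w < 0 for w in weights):
        raise ValueError(f"expected up to 3 non-negative weights, got {split!r}")
    # Missing trailing splits get zero weight, e.g. "99,1" -> no test split.
    weights += [0.0] * (3 - len(weights))
    total = sum(weights)
    if total == 0:
        raise ValueError("split weights must not all be zero")
    # Relative weights, so "900,50,50" and "18,1,1" describe the same split.
    return [w / total for w in weights]


print(parse_split("900,50,50"))  # [0.9, 0.05, 0.05]
```

Because only the ratios matter under this convention, the weights do not need to sum to 100 or 1000.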

gpt_dataset_config: GPTDatasetConfig

Get the GPT dataset configuration.

nemo_automodel.components.datasets.llm.megatron_dataset.MegatronPretraining.build()

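Build the datasets using the trainer parameters (trainer_max_steps, trainer_val_check_interval, trainer_limit_val_batches, trainer_limit_test_batches) provided during initialization.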
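The trainer parameters are presumably used to work out how many samples each split must provide when num_train_samples, num_val_samples, and num_test_samples are not given. One plausible mapping, modeled on the convention used by NeMo's PreTrainingDataModule and assuming integer batch limits, is sketched below; num_samples_from_trainer is a hypothetical helper, and the exact arithmetic inside build() may differ:

```python
from typing import Dict


def num_samples_from_trainer(
    trainer_max_steps: int,
    global_batch_size: int,
    trainer_val_check_interval: int = 1000,
    trainer_limit_val_batches: int = 1,
    trainer_limit_test_batches: int = 1,
) -> Dict[str, int]:
    """Hypothetical helper: samples per split implied by trainer settings (integer limits only)."""
    # Each optimizer step consumes one global batch of samples.
    train = trainer_max_steps * global_batch_size
    # One validation run every `val_check_interval` steps, plus one extra run,
    # each limited to `limit_val_batches` global batches.
    eval_iters = (trainer_max_steps // trainer_val_check_interval + 1) * trainer_limit_val_batches
    val = eval_iters * global_batch_size
    # The test split is evaluated once, for `limit_test_batches` global batches.
    test = trainer_limit_test_batches * global_batch_size
    return {"train": train, "validation": val, "test": test}


# With the class defaults (global_batch_size=8, val_check_interval=1000, limits=1):
print(num_samples_from_trainer(trainer_max_steps=10_000, global_batch_size=8))
# {'train': 80000, 'validation': 88, 'test': 8}
```

Float limits (for example trainer_limit_val_batches=0.5, meaning a fraction of the validation set) are not handled by this sketch.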

nemo_automodel.components.datasets.llm.megatron_dataset.MegatronPretraining.get_dataset(
split: str
)

Get the dataset for a given split.

nemo_automodel.components.datasets.llm.megatron_dataset.get_list_of_files(
path: str
)

Get the list of unique dataset prefixes (full paths without extension) from a glob pattern.
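A minimal sketch of the glob-and-strip behavior described above, assuming each Megatron dataset prefix is stored as a .bin/.idx file pair so that stripping the extension collapses both files into one prefix. list_dataset_prefixes is a hypothetical name, and sorting the result is an assumption rather than documented behavior:

```python
import glob
import os
import tempfile
from typing import List


def list_dataset_prefixes(pattern: str) -> List[str]:
    """Hypothetical sketch: unique full paths matching `pattern`, extensions stripped."""
    # Both "corpus.bin" and "corpus.idx" map to the single prefix ".../corpus".
    prefixes = {os.path.abspath(os.path.splitext(p)[0]) for p in glob.glob(pattern)}
    return sorted(prefixes)


# Usage: two datasets, each stored as a .bin/.idx pair.
with tempfile.TemporaryDirectory() as root:
    for name in ("a.bin", "a.idx", "b.bin", "b.idx"):
        open(os.path.join(root, name), "wb").close()
    found = list_dataset_prefixes(os.path.join(root, "*"))
    expected = [os.path.join(os.path.abspath(root), n) for n in ("a", "b")]
    print(found == expected)  # True
```

Deduplicating through a set is what turns four files into two prefixes; the resulting prefixes are the form Megatron-style blends expect.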

nemo_automodel.components.datasets.llm.megatron_dataset.is_number_tryexcept(
s
)

Returns True if string is a number.

nemo_automodel.components.datasets.llm.megatron_dataset.is_zipped_list(
paths
)

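Check whether paths is a zipped list, that is, a flat list of alternating weights and dataset prefixes such as ["30", "path/to/dataset1", "70", "path/to/dataset2"].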
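A zipped list interleaves numeric weights with dataset prefixes, which is where a number check such as is_number_tryexcept is useful. The sketch below shows one way to detect that shape; is_number and is_zipped are hypothetical stand-ins, and raising on a partially numeric list is an assumed design choice, not documented behavior of is_zipped_list:

```python
from typing import Any, Sequence


def is_number(s: Any) -> bool:
    """Hypothetical stand-in for is_number_tryexcept: True if float(s) succeeds."""
    try:
        float(s)
        return True
    except (TypeError, ValueError):
        return False


def is_zipped(paths: Sequence[Any]) -> bool:
    """Hypothetical check for the [weight, prefix, weight, prefix, ...] layout."""
    weights = list(paths[::2])  # entries at even positions should be weights
    if not weights:
        return False
    flags = [is_number(w) for w in weights]
    if not any(flags):
        return False  # plain list of prefixes, no weights
    if not all(flags) or len(paths) % 2 != 0:
        # Some weights present but the pattern is broken: treat as malformed.
        raise ValueError(f"malformed zipped list: {list(paths)!r}")
    return True


print(is_zipped(["30", "path/to/dataset1", "70", "path/to/dataset2"]))  # True
print(is_zipped(["path/to/dataset1", "path/to/dataset2"]))  # False
```

Failing loudly on a half-weighted list avoids silently treating a weight string as a dataset path.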

nemo_automodel.components.datasets.llm.megatron_dataset.try_load_blend_from_json(
path: typing.Union[str, pathlib.Path]
) -> typing.Optional[typing.Union[typing.Dict[str, typing.List], typing.List]]

Load a data blend configuration from a JSON file.

Two top-level JSON shapes are accepted:

  1. Dict-of-splits (Automodel native form): keys are split names (‘train’, ‘valid’, ‘test’); values are path lists. Common aliases ‘valid’ / ‘val’ / ‘dev’ are normalized to ‘validation’.
  2. Flat list (Megatron-LM canonical form): a single zipped list of alternating weights and dataset prefixes. The caller uses the split= parameter to allocate this blend across train / validation / test splits.

Example flat-list JSON (Megatron-LM convention, paired with split=): ["30", "path/to/dataset1", "70", "path/to/dataset2"]

Parameters:

path
Union[str, Path]

Path to a JSON file containing the blend configuration.

Returns: Optional[Union[Dict[str, List], List]]

Dictionary or list containing the blend configuration if path is a JSON file, or None otherwise.

Raises:

  • FileNotFoundError: If the JSON file does not exist.
  • PermissionError: If the JSON file cannot be read.
  • ValueError: If the JSON is invalid or is neither a list nor a dict.
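The two accepted shapes and the alias normalization can be sketched as follows. load_blend is a hypothetical stand-in for try_load_blend_from_json; returning None for paths that do not end in .json, and wrapping JSON decode errors in ValueError, are assumptions consistent with the documented return type and Raises list:

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Union

# Aliases taken from the description above; all map to the "validation" split.
SPLIT_ALIASES = {"valid": "validation", "val": "validation", "dev": "validation"}


def load_blend(path: Union[str, Path]) -> Optional[Union[Dict[str, List], List]]:
    """Hypothetical sketch of try_load_blend_from_json."""
    path_str = str(path)
    if not path_str.endswith(".json"):
        return None  # assumption: non-JSON paths are not treated as blend files
    # open() raises FileNotFoundError / PermissionError on its own.
    with open(path_str, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)
        except json.JSONDecodeError as err:
            raise ValueError(f"invalid JSON in {path_str}: {err}") from err
    if isinstance(data, list):
        return data  # flat Megatron-LM form; split= allocates it later
    if isinstance(data, dict):
        return {SPLIT_ALIASES.get(key, key): value for key, value in data.items()}
    raise ValueError(f"{path_str}: expected a JSON list or object, got {type(data).__name__}")


# Usage: a dict-of-splits file using the "valid" alias.
with tempfile.TemporaryDirectory() as root:
    blend_file = os.path.join(root, "blend.json")
    with open(blend_file, "w", encoding="utf-8") as f:
        json.dump({"train": ["1", "data/train"], "valid": ["1", "data/val"]}, f)
    blend = load_blend(blend_file)

print(sorted(blend))  # ['train', 'validation']
```

Normalizing aliases at load time means downstream code only has to look up a single "validation" key.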
nemo_automodel.components.datasets.llm.megatron_dataset.validate_dataset_asset_accessibility(
paths,
object_storage_config = None
)

Validate that the dataset assets referenced by paths are accessible. Local-filesystem checks are skipped for S3/MSC paths when object_storage_config is provided.
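A minimal sketch of this kind of check, assuming a plain list of dataset prefixes (the real function also accepts zipped lists and dict-of-splits blends), assuming each local prefix must have readable .bin and .idx files, and assuming remote paths are recognized by hypothetical "s3://" and "msc://" prefixes. check_prefixes_accessible is a hypothetical name:

```python
import os
import tempfile
from typing import Any, Iterable, Optional

# Assumed URI schemes for object-storage paths (S3 and MSC are named above;
# the exact prefixes the real function recognizes are an assumption).
REMOTE_SCHEMES = ("s3://", "msc://")


def check_prefixes_accessible(
    prefixes: Iterable[str], object_storage_config: Optional[Any] = None
) -> None:
    """Hypothetical sketch: raise if a local dataset prefix is missing or unreadable."""
    for prefix in prefixes:
        if object_storage_config is not None and prefix.startswith(REMOTE_SCHEMES):
            continue  # remote object: skip local-filesystem checks
        # Megatron indexed datasets keep each prefix as a .bin (data) / .idx (index) pair.
        for ext in (".bin", ".idx"):
            fname = prefix + ext
            if not os.path.isfile(fname):
                raise FileNotFoundError(f"missing dataset file: {fname}")
            if not os.access(fname, os.R_OK):
                raise PermissionError(f"dataset file is not readable: {fname}")


with tempfile.TemporaryDirectory() as root:
    prefix = os.path.join(root, "corpus")
    for ext in (".bin", ".idx"):
        open(prefix + ext, "wb").close()
    check_prefixes_accessible([prefix])  # passes silently
    # Remote paths are skipped only when a storage config is supplied.
    check_prefixes_accessible([prefix, "s3://bucket/corpus"], object_storage_config={})
```

Checking up front turns a missing shard into an immediate, descriptive error instead of a failure deep inside dataset construction.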

nemo_automodel.components.datasets.llm.megatron_dataset.logger = logging.getLogger(__name__)