core.datasets.retro.query.gpt_chunk_dataset#
A GPTChunkDataset is a wrapper around a regular GPTDataset that sequentially splits each sample's tokens into smaller samples of length retro_chunk_length.
For example, if the GPTDataset has 100 samples with a sequence length of 2048, and retro_chunk_length is 64, then the GPTChunkDataset contains 100 * (2048 / 64) = 3200 samples, each of length 64.
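The chunk-count arithmetic above can be verified with a short sketch (the variable names are illustrative, taken from the example, not from the library):

```python
# Chunk-count arithmetic from the example above (illustrative values).
num_samples = 100        # samples in the underlying GPTDataset
sequence_length = 2048   # tokens per GPT sample
retro_chunk_length = 64  # tokens per Retro chunk

chunks_per_sample = sequence_length // retro_chunk_length  # 2048 / 64 = 32
total_chunks = num_samples * chunks_per_sample             # 100 * 32 = 3200
```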
Module Contents#
Classes#
GPTChunkDataset – Pretraining chunk dataset wraps a standard GPT dataset.
Functions#
build_gpt_chunk_datasets_from_gpt_datasets – Get train, valid, test GPT chunk datasets.
API#
- class core.datasets.retro.query.gpt_chunk_dataset.GPTChunkDataset(
- sample_dataset: megatron.core.datasets.gpt_dataset.GPTDataset,
- sample_length: int,
- chunk_length: int,
- )
Bases: torch.utils.data.Dataset

Pretraining chunk dataset wraps a standard GPT dataset.
This dataset conceptually divides each sample (e.g., length 2048) into chunks (e.g., length 64) and restructures them into a list of chunks (e.g., length num_samples * num_chunks_per_sample).
- Parameters:
sample_dataset (GPTDataset) – Original GPT dataset, with sequence_length-size samples.
sample_length (int) – Alias for sequence_length.
chunk_length (int) – Retro chunk length (e.g., 64).
Initialization
- __len__() → int#
Get dataset length.
- Returns:
Dataset length.
- __getitem__(idx: int) → dict#
Get sample, including represented document IDs.
- Parameters:
idx (int) – Sample index.
- Returns:
A sample dict containing the chunk-length token sample (‘text’) along with all document IDs (‘doc_ids’) contained within the full sequence_length sample.
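The index mapping behind __getitem__ can be sketched as follows. This is an illustrative, self-contained version, not Megatron's actual implementation; the helper name and the dict keys of the underlying sample are assumptions:

```python
# Illustrative sketch of the chunk index mapping (not Megatron's actual code).
def get_chunk(sample_dataset, idx, sample_length, chunk_length):
    """Map a flat chunk index onto (sample index, chunk-within-sample)."""
    chunks_per_sample = sample_length // chunk_length
    sample_idx = idx // chunks_per_sample  # which underlying GPT sample
    chunk_idx = idx % chunks_per_sample    # which chunk within that sample

    sample = sample_dataset[sample_idx]    # assumed: dict with token and doc-id keys
    start = chunk_idx * chunk_length
    return {
        "text": sample["text"][start : start + chunk_length],
        "doc_ids": sample["document_ids"],
    }
```

With a 128-token sample length and 64-token chunks, index 3 resolves to the second chunk of the second sample.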
- core.datasets.retro.query.gpt_chunk_dataset.build_gpt_chunk_datasets_from_gpt_datasets(
- project_dir: str,
- gpt_datasets: dict,
- sample_length: int,
- chunk_length: int,
- )
Get train, valid, test GPT chunk datasets.
- Parameters:
project_dir (str) – Retro project dir.
gpt_datasets (dict) – Mapping of ‘train’, ‘valid’, and ‘test’ GPT datasets (original, unchunked datasets).
sample_length (int) – Alias of sequence_length.
chunk_length (int) – Retro chunk length (e.g., 64).
- Returns:
A mapping of ‘train’, ‘valid’, and ‘test’ GPT chunk datasets.
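The per-split build can be sketched like this. It is a minimal, self-contained approximation under stated assumptions: the real function wraps each split in a GPTChunkDataset, while here each split is represented by a plain dict, and the None pass-through for absent splits is an assumption, not documented behavior:

```python
# Illustrative per-split builder (a sketch, not Megatron's implementation).
def build_chunk_datasets(gpt_datasets, sample_length, chunk_length):
    """Wrap each of the 'train'/'valid'/'test' GPT datasets in a chunk view."""
    chunks_per_sample = sample_length // chunk_length
    return {
        split: None if ds is None else {
            "sample_dataset": ds,
            # Each GPT sample contributes sample_length // chunk_length chunks.
            "num_chunks": len(ds) * chunks_per_sample,
        }
        for split, ds in gpt_datasets.items()
    }

datasets = build_chunk_datasets(
    {"train": list(range(100)), "valid": list(range(10)), "test": None},
    sample_length=2048,
    chunk_length=64,
)
```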