core.datasets.retro.query.gpt_chunk_dataset

A GPTChunkDataset is a wrapper around a regular GPTDataset that sequentially chunks each sample's tokens into smaller samples of length retro_chunk_length.

For example, if the GPTDataset has 100 samples and a sequence length of 2048, and retro_chunk_length is 64, then the GPTChunkDataset will contain 100*(2048/64) = 3200 samples, each with length 64.
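The arithmetic above can be checked with a short sketch. The helper name `num_chunk_samples` is hypothetical and not part of the module; it also assumes that the sequence length is an exact multiple of the chunk length, as in the example.

```python
def num_chunk_samples(num_samples: int, sample_length: int, chunk_length: int) -> int:
    """Count the chunk-sized samples produced by chunking every sample.

    Hypothetical helper for illustration only; not part of the module.
    """
    # Assumption: each sample splits evenly into chunks.
    if sample_length % chunk_length != 0:
        raise ValueError("sample_length must be divisible by chunk_length")
    chunks_per_sample = sample_length // chunk_length
    return num_samples * chunks_per_sample


# 100 samples of length 2048, chunked into pieces of length 64.
print(num_chunk_samples(100, 2048, 64))  # 3200
```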

Module Contents

Classes

GPTChunkDataset

Pretraining chunk dataset that wraps a standard GPT dataset.

Functions

build_gpt_chunk_datasets_from_gpt_datasets

Get train, valid, test GPT chunk datasets.

API

class core.datasets.retro.query.gpt_chunk_dataset.GPTChunkDataset(
sample_dataset: megatron.core.datasets.gpt_dataset.GPTDataset,
sample_length: int,
chunk_length: int,
)

Bases: torch.utils.data.Dataset

Pretraining chunk dataset that wraps a standard GPT dataset.

This dataset conceptually divides each sample (e.g., length 2048) into chunks (e.g., length 64) and restructures them into a list of chunks (e.g., length num_samples * num_chunks_per_sample).

Parameters:
  • sample_dataset (GPTDataset) – Original GPT dataset, whose samples each have length sequence_length.

  • sample_length (int) – Alias for sequence_length.

  • chunk_length (int) – Retro chunk length (e.g., 64).

__len__() → int

Get dataset length.

Returns:

Dataset length.

__getitem__(idx: int) → dict

Get sample, including represented document IDs.

Parameters:

idx (int) – Sample index.

Returns:

A sample containing both the chunk-length token sample (‘text’) and all document IDs (‘doc_ids’) contained within the full sequence_length sample.
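As a rough illustration of how a chunk index can map back to a parent sample, here is a minimal sketch. `ToyChunkDataset` is a hypothetical stand-in and not the Megatron class: it wraps a plain list of dicts, and the `document_ids` key on the underlying samples is an assumption made for this example. Only the `text` and `doc_ids` keys of the returned dict come from the description above.

```python
from typing import Dict, List, Sequence


class ToyChunkDataset:
    """Hypothetical stand-in for a chunk dataset; not the Megatron class."""

    def __init__(
        self,
        samples: Sequence[Dict[str, List[int]]],
        sample_length: int,
        chunk_length: int,
    ) -> None:
        # Assumption: every sample splits evenly into chunks.
        assert sample_length % chunk_length == 0
        self.samples = samples
        self.chunk_length = chunk_length
        self.chunks_per_sample = sample_length // chunk_length

    def __len__(self) -> int:
        # One entry per chunk across all samples.
        return len(self.samples) * self.chunks_per_sample

    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        # Split the flat chunk index into (parent sample, chunk within sample).
        sample_idx, chunk_idx = divmod(idx, self.chunks_per_sample)
        sample = self.samples[sample_idx]
        start = chunk_idx * self.chunk_length
        return {
            # Chunk-length slice of the parent sample's tokens.
            "text": sample["text"][start:start + self.chunk_length],
            # Document IDs of the whole parent sample, not just this chunk.
            "doc_ids": sample["document_ids"],
        }


# Two toy samples of length 8, chunked into pieces of length 4.
samples = [
    {"text": list(range(0, 8)), "document_ids": [10, 11]},
    {"text": list(range(8, 16)), "document_ids": [12]},
]
toy = ToyChunkDataset(samples, sample_length=8, chunk_length=4)
print(len(toy))  # 4
print(toy[3])    # {'text': [12, 13, 14, 15], 'doc_ids': [12]}
```

In this sketch every chunk carries the document IDs of its entire parent sample, which matches the return description: ‘doc_ids’ covers the full sequence_length sample even though ‘text’ holds only one chunk.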

core.datasets.retro.query.gpt_chunk_dataset.build_gpt_chunk_datasets_from_gpt_datasets(
project_dir: str,
gpt_datasets: dict,
sample_length: int,
chunk_length: int,
) → dict

Get train, valid, test GPT chunk datasets.

Parameters:
  • project_dir (str) – Retro project directory.

  • gpt_datasets (dict) – Mapping of ‘train’, ‘valid’, and ‘test’ GPT datasets (original, unchunked datasets).

  • sample_length (int) – Alias for sequence_length.

  • chunk_length (int) – Retro chunk length (e.g., 64).

Returns:

A dict of the train, valid, and test GPT chunk datasets.