core.datasets.retro.query.retro_dataset#

A RetroDataset wraps both:

  • A GPTDataset (which is nested as GPTChunkDataset -> MultiSplitGPTDataset -> GPTDataset).

  • Neighbor IDs of chunks in the chunk database, which were saved during preprocessing.

Both the GPT sample data and the neighbor IDs are returned within a sample from this dataset.

Module Contents#

Classes#

RetroDataset

Dataset of retro samples.

Functions#

get_retro_datasets

Get train, valid, test retro datasets.

API#

class core.datasets.retro.query.retro_dataset.RetroDataset(
num_queried_samples: int,
num_neighbors: int,
num_retrieved_chunks: int,
block_size: int,
db_dataset: megatron.core.datasets.retro.db.dataset.DBDataset,
chunk_dataset: core.datasets.retro.query.gpt_chunk_dataset.GPTChunkDataset,
neighbor_path_map: megatron.core.datasets.retro.utils.BlockPathMap,
)#

Bases: torch.utils.data.Dataset

Dataset of retro samples.

Each sample contains the original GPT sample, along with the token IDs of each neighbor of each chunk within the sequence. The neighbor array has shape (num_chunks_per_sample, num_neighbors, num_retrieved_tokens).

Note: the chunk dataset wraps the original GPT dataset (see gpt_chunk_dataset.py).
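The shape arithmetic above can be sketched as follows. All concrete numbers here (sequence length, chunk length, neighbor counts) are illustrative assumptions, not values read from any config:

```python
# Hedged sketch: deriving the neighbor array shape described above.
# Every concrete number below is an assumed example value.
sequence_length = 2048      # assumed GPT sample length
chunk_length = 64           # assumed Retro chunk length
num_neighbors = 2           # neighbors saved per chunk
num_retrieved_chunks = 2    # e.g., neighbor chunk + its continuation

# Each GPT sample is split into fixed-length chunks.
num_chunks_per_sample = sequence_length // chunk_length

# Each neighbor contributes `num_retrieved_chunks` chunks' worth of tokens.
num_retrieved_tokens = num_retrieved_chunks * chunk_length

neighbor_shape = (num_chunks_per_sample, num_neighbors, num_retrieved_tokens)
print(neighbor_shape)  # -> (32, 2, 128)
```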

Parameters:
  • num_queried_samples (int) – Total number of queried samples.

  • num_neighbors (int) – Total number of saved neighbors.

  • num_retrieved_chunks (int) – Number of retrieved chunks (e.g., 2 for neighbor + continuation).

  • block_size (int) – Number of neighbor entries per file.

  • db_dataset (DBDataset) – Chunk database used for retrieval.

  • chunk_dataset (GPTChunkDataset) – GPT chunk dataset, which is a wrapper around a standard GPT dataset that breaks each sample into chunks.

  • neighbor_path_map (BlockPathMap) – Mapping of neighbor ID to file path.
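Given the block_size and neighbor_path_map parameters above, a neighbor entry can be located by rounding its index down to a block boundary. The following is a minimal sketch of that bucketing; the file-naming scheme is a hypothetical illustration, not the actual BlockPathMap format:

```python
# Hedged sketch: mapping a chunk index to the neighbor file holding it,
# assuming each file stores `block_size` consecutive neighbor entries.
# The path format below is hypothetical.
def neighbor_block_path(chunk_idx: int, block_size: int) -> str:
    block_start = (chunk_idx // block_size) * block_size
    block_end = block_start + block_size
    return f"{block_start:08d}-{block_end:08d}.hdf5"

print(neighbor_block_path(123_456, block_size=100_000))
# -> '00100000-00200000.hdf5'
```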

Initialization

__len__() → int#

Dataset length.

Returns:

Number of samples in dataset.

__getitem__(sample_idx: int) → dict#

Get dataset sample.

Parameters:

sample_idx (int) – Index of sample in dataset.

Returns:

A dict containing the GPT sample (key ‘text’), the corresponding neighbor chunk IDs (‘neighbor_chunks’, used to index the chunk database), and the neighbor token IDs (the corresponding GPT tokens from the chunk database).
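A mock of the sample structure described above, standing in for a real RetroDataset. The key names follow the description; the exact keys, dtypes, and layouts in the actual implementation are assumptions here, and plain lists are used in place of numpy arrays to keep the sketch self-contained:

```python
# Hedged mock: a stand-in with the same __len__/__getitem__ contract as
# the RetroDataset described above. Shapes and key names are assumptions.
class MockRetroDataset:
    def __init__(self, num_samples=4, num_chunks=32, num_neighbors=2,
                 num_retrieved_tokens=128, seq_length=2048):
        self.num_samples = num_samples
        self.num_chunks = num_chunks
        self.num_neighbors = num_neighbors
        self.num_retrieved_tokens = num_retrieved_tokens
        self.seq_length = seq_length

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, sample_idx: int) -> dict:
        # Keys follow the description above: GPT tokens, neighbor chunk
        # IDs per chunk, and the retrieved neighbor tokens.
        return {
            "text": [0] * self.seq_length,
            "neighbor_chunks": [[0] * self.num_neighbors
                                for _ in range(self.num_chunks)],
            "neighbor_tokens": [[[0] * self.num_retrieved_tokens
                                 for _ in range(self.num_neighbors)]
                                for _ in range(self.num_chunks)],
        }

dataset = MockRetroDataset()
sample = dataset[0]
print(len(sample["neighbor_tokens"]),        # num_chunks_per_sample
      len(sample["neighbor_tokens"][0]),     # num_neighbors
      len(sample["neighbor_tokens"][0][0]))  # num_retrieved_tokens
# -> 32 2 128
```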

core.datasets.retro.query.retro_dataset.get_retro_datasets(
config: megatron.core.models.retro.RetroConfig,
gpt_datasets: dict,
sample_length: int,
eod_token_id: int,
) → Tuple[Optional[core.datasets.retro.query.retro_dataset.RetroDataset], Optional[core.datasets.retro.query.retro_dataset.RetroDataset], Optional[core.datasets.retro.query.retro_dataset.RetroDataset]]#

Get train, valid, test retro datasets.

Parameters:
  • config (RetroConfig) – Retro preprocessing config.

  • gpt_datasets (dict) – Mapping of data split key (‘train’, ‘valid’, or ‘test’) to the original sequence-length GPT dataset (i.e., not the chunk dataset).

  • sample_length (int) – Alias of sequence_length.

  • eod_token_id (int) – GPT EOD token ID.

Returns:

A tuple of ‘train’, ‘valid’, and ‘test’ RetroDatasets.
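Since each element of the returned tuple is Optional, callers generally need to guard against None for splits that were not queried. A hedged sketch of that pattern, using a hypothetical stand-in function that only mimics the tuple shape of get_retro_datasets:

```python
from typing import Optional, Tuple

# Hypothetical stand-in mimicking get_retro_datasets' return shape:
# one Optional dataset per split, None when a split was not queried.
def fake_get_retro_datasets() -> Tuple[Optional[object],
                                       Optional[object],
                                       Optional[object]]:
    return object(), object(), None  # e.g., no test split available

train_ds, valid_ds, test_ds = fake_get_retro_datasets()
available = {name: ds
             for name, ds in zip(("train", "valid", "test"),
                                 (train_ds, valid_ds, test_ds))
             if ds is not None}
print(sorted(available))  # -> ['train', 'valid']
```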