core.datasets.retro.query.retro_dataset#
A RetroDataset wraps both:
- A GPTDataset (nested as GPTChunkDataset -> MultiSplitGPTDataset -> GPTDataset).
- The neighbor IDs of each chunk in the chunk database, saved during preprocessing.
A sample from this dataset returns both the GPT sample data and the neighbor IDs.
Module Contents#
Classes#
RetroDataset – Dataset of retro samples.
Functions#
get_retro_datasets – Get train, valid, test retro datasets.
API#
- class core.datasets.retro.query.retro_dataset.RetroDataset(
- num_queried_samples: int,
- num_neighbors: int,
- num_retrieved_chunks: int,
- block_size: int,
- db_dataset: megatron.core.datasets.retro.db.dataset.DBDataset,
- chunk_dataset: core.datasets.retro.query.gpt_chunk_dataset.GPTChunkDataset,
- neighbor_path_map: megatron.core.datasets.retro.utils.BlockPathMap,
)
Bases: torch.utils.data.Dataset

Dataset of retro samples.
Each sample contains the original GPT sample, along with the token IDs of each neighbor of each chunk within the sequence. Neighbor array has shape (num_chunks_per_sample, num_neighbors, num_retrieved_tokens).
Note: the chunk dataset wraps the original GPT dataset (see gpt_chunk_dataset.py).
- Parameters:
num_queried_samples (int) – Total number of queried samples.
num_neighbors (int) – Total number of saved neighbors.
num_retrieved_chunks (int) – Number of retrieved chunks (e.g., 2 for neighbor + continuation).
block_size (int) – Number of neighbor entries per file.
db_dataset (DBDataset) – Chunk database used for retrieval.
chunk_dataset (GPTChunkDataset) – GPT chunk dataset, which is a wrapper around a standard GPT dataset that breaks each sample into chunks.
neighbor_path_map (BlockPathMap) – Mapping of neighbor ID to file path.
Initialization
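The neighbor array shape described above can be sketched numerically. A minimal sketch follows; the chunk length of 64 and the other sizes are illustrative assumptions (Retro commonly uses 64-token chunks), not values fixed by this class:

```python
# Illustrative sizes (assumptions, not API defaults).
sequence_length = 2048      # tokens per GPT sample
chunk_length = 64           # tokens per chunk
num_neighbors = 2           # saved neighbors per chunk
num_retrieved_chunks = 2    # neighbor chunk + its continuation

# Each sample is split into fixed-length chunks.
num_chunks_per_sample = sequence_length // chunk_length

# Each retrieved entry concatenates num_retrieved_chunks chunks of tokens.
num_retrieved_tokens = num_retrieved_chunks * chunk_length

# Shape of the per-sample neighbor array, as documented above.
neighbor_shape = (num_chunks_per_sample, num_neighbors, num_retrieved_tokens)
print(neighbor_shape)  # (32, 2, 128)
```

With these assumed sizes, each 2048-token sample yields 32 chunks, each carrying 2 neighbors of 128 retrieved tokens.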
- __len__() → int#
Dataset length.
- Returns:
Number of samples in dataset.
- __getitem__(sample_idx: int) → dict#
Get dataset sample.
- Parameters:
sample_idx (int) – Index of sample in dataset.
- Returns:
A dict containing the GPT sample (key ‘text’), the corresponding neighbor chunk IDs (key ‘neighbor_chunks’, used to index the chunk database), and the neighbor token IDs (the corresponding GPT tokens from the chunk database).
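A minimal sketch of the returned structure, using NumPy placeholders. The sizes and the key name `neighbor_tokens` are assumptions for illustration; only `text` and `neighbor_chunks` are named in the docstring above:

```python
import numpy as np

# Illustrative sizes (assumed, not fixed by the API).
seq_len, n_chunks, n_neighbors, n_retrieved = 2048, 32, 2, 128

# Stand-in for a RetroDataset.__getitem__ result.
sample = {
    # Token IDs of the original GPT sample.
    "text": np.zeros(seq_len, dtype=np.int64),
    # Chunk-database indices of each chunk's neighbors.
    "neighbor_chunks": np.zeros((n_chunks, n_neighbors), dtype=np.int64),
    # Retrieved GPT tokens per neighbor ("neighbor_tokens" is an assumed key name).
    "neighbor_tokens": np.zeros((n_chunks, n_neighbors, n_retrieved), dtype=np.int64),
}
print(sample["neighbor_tokens"].shape)  # (32, 2, 128)
```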
- core.datasets.retro.query.retro_dataset.get_retro_datasets(
- config: megatron.core.models.retro.RetroConfig,
- gpt_datasets: dict,
- sample_length: int,
- eod_token_id: int,
)
Get train, valid, test retro datasets.
- Parameters:
config (RetroConfig) – Retro preprocessing config.
gpt_datasets (dict) – Mapping of data split key (‘train’, ‘valid’, or ‘test’) to the original sequence-length GPT dataset (i.e., not the chunk dataset).
sample_length (int) – Alias of sequence_length.
eod_token_id (int) – GPT EOD token ID.
- Returns:
A tuple of ‘train’, ‘valid’, and ‘test’ RetroDatasets.
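The shape of the gpt_datasets argument and the ordering of the returned tuple can be sketched with plain placeholders. This is a hedged sketch only; the string values stand in for real GPT dataset objects:

```python
# gpt_datasets maps each split key to the original sequence-length GPT
# dataset (placeholders here, not real dataset objects).
gpt_datasets = {
    "train": "train_gpt_dataset",
    "valid": "valid_gpt_dataset",
    "test": "test_gpt_dataset",
}

# get_retro_datasets returns one RetroDataset per split, in this order.
train_ds, valid_ds, test_ds = (
    gpt_datasets[k] for k in ("train", "valid", "test")
)
```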