datasets package#

Data Pipeline#

Data pre-processing#

Data preprocessing is built around the following classes:

  1. IndexedDatasetBuilder

  2. IndexedDataset

At the moment, an end-to-end data preprocessing implementation is left to the user. See the class docstring(s) for more details.

IndexedDatasetBuilder#

The IndexedDatasetBuilder is capable of building and merging IndexedDataset instances.

IndexedDataset#

The IndexedDataset class is the lowest-level data interface in Megatron Core. Internally, an IndexedDataset instance references two binaries: the data file (.bin) contains document/sequence data and the index file (.idx) contains document/sequence metadata.

The index file stores dataset-level metadata first:

  • The index header, for backward compatibility

  • The index version, for backward compatibility

  • A numeric code corresponding to the data type used to write data to the data file

  • The number of sequences in the dataset

  • The number of documents in the dataset

The index file stores document-level and sequence-level metadata second:

  • In order, the number of elements per sequence

  • In order, the byte offset (pointer) per sequence

  • In order, the consecutive sequence index range [...) per document

  • In order, the mode per sequence (in the multimodal case)
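
The relationship between the per-sequence lengths, the per-sequence byte offsets, and the data file can be illustrated with a minimal sketch. It assumes that sequences are stored back to back in the data file, so each byte offset is the running sum of the preceding lengths multiplied by the item size. The helper names `sequence_pointers` and `read_sequence` are hypothetical and are not part of Megatron Core.

```python
import numpy as np


def sequence_pointers(sequence_lengths, dtype):
    """Byte offset of each sequence, assuming sequences are stored contiguously."""
    lengths = np.asarray(sequence_lengths, dtype=np.int64)
    pointers = np.zeros(len(lengths), dtype=np.int64)
    if len(lengths) > 1:
        # Exclusive cumulative sum of element counts, converted to bytes.
        pointers[1:] = np.cumsum(lengths[:-1]) * np.dtype(dtype).itemsize
    return pointers


def read_sequence(data, dtype, sequence_lengths, pointers, index):
    """Read one sequence from the raw bytes using only the index metadata."""
    return np.frombuffer(
        data,
        dtype=dtype,
        count=int(sequence_lengths[index]),
        offset=int(pointers[index]),
    )


# Toy "data file": three sequences written back to back as int32.
dtype = np.int32
sequences = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
data = b"".join(np.asarray(s, dtype=dtype).tobytes() for s in sequences)

sequence_lengths = [len(s) for s in sequences]
pointers = sequence_pointers(sequence_lengths, dtype)

# Two documents: sequences [0, 2) and sequences [2, 3).
document_indices = [0, 2, 3]
second_sequence = read_sequence(data, dtype, sequence_lengths, pointers, 1)
```

Here document `d` covers the half-open sequence range `[document_indices[d], document_indices[d + 1])`, which mirrors the `[...)` range described in the list above.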

Data loading: construction#

Building the data loaders is a distributed-aware process built around the following classes:

  1. BlendedMegatronDatasetConfig

  2. BlendedMegatronDatasetBuilder

  3. IndexedDataset

  4. MegatronDataset

  5. BlendedDataset

See the class docstrings for more details.

BlendedMegatronDatasetConfig (extendable)#

The BlendedMegatronDatasetConfig class parameterizes the BlendedMegatronDatasetBuilder and in turn the MegatronDataset and BlendedDataset.

Different training/inference regimes will require different extensions, e.g. the GPTDatasetConfig.

BlendedMegatronDatasetBuilder#

The BlendedMegatronDatasetBuilder class builds the highest-level data interfaces in Megatron Core.

NB: All ranks should attempt to build the dataset via the BlendedMegatronDatasetBuilder or the program will hang. Which ranks follow through on their attempts can be controlled via the BlendedMegatronDatasetConfig.

IndexedDataset#

The IndexedDataset class is the lowest-level data interface in Megatron Core.

The IndexedDataset should already exist on disk before attempting to build any of the high-level data interfaces.

MegatronDataset (extendable)#

The MegatronDataset abstract class is a high-level data interface in Megatron Core. It is an abstraction built upon the IndexedDataset.

Different training/inference regimes will require different extensions, e.g. the GPTDataset.

BlendedDataset#

The BlendedDataset class is a high-level data interface in Megatron Core. It is an abstraction built upon the MegatronDataset.

The BlendedDataset is only necessary when a blend of multiple data distributions, i.e. multiple MegatronDataset instances, should contribute to a certain dataset split. The blend can be controlled via the BlendedMegatronDatasetConfig.

Data loading: implementation#

GPTDataset#

The GPTDataset is parameterized by the following variables: the underlying IndexedDataset instance indexed_dataset, the split indices indexed_indices (the contiguous subset of document or sequence indices used for training, validation, and testing), the number of samples N, the sequence length S, and the random seed R.

The GPTDataset creates three index mappings to facilitate lookup: (1) the document index, (2) the sample index, and (3) the shuffle index.

  1. The document index Do_idx is a 1-D array mapping from i to document index of length E * |indexed_indices| where E corresponds to the minimum number of epochs such that E * |indexed_indices| >= N. The document index is shuffled according to R.

    Given:
    
    N = 15
    indexed_indices = [5, 6, 7, 8, 9]
    E = 3
    
    Then, for example:
    
    Do_idx = [8, 8, 9, 6, 7, 5, 8, 5, 6, 6, 5, 9, 7, 7, 9]
    
  2. The sample index Sa_idx is a 2-D array mapping from j to pairs of (i, Do_idx[i] offset) of shape [N + 1, 2]. The rows j and j + 1 serve as the left and right bounds for the j-th sample.

    Given:
    
    S = 1024
    
    Then, for example:
    
    Sa_idx[0] = (0, 0)
    Sa_idx[1] = (0, 1024)       => Do_idx[0] has length greater than S
    Sa_idx[2] = (1, 512)        => Do_idx[0] has length 1536
    Sa_idx[3] = (2, 0)          => Do_idx[1] has length 1536
    Sa_idx[4] = (5, 300)        => Do_idx[2:5] are shorter documents relative to Do_idx[0:2]
    Sa_idx[5] = (6, 24)         => Do_idx[5] has length 1300
    
  3. The shuffle index Sh_idx is a 1-D array mapping from k to j of length N. The shuffle index is shuffled according to R.

    Given:
    
    N = 10
    
    Then, for example:
    
    Sh_idx = [4, 0, 2, 6, 1, 9, 5, 8, 7, 3]
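
The three index mappings described above can be built for a toy dataset with the following minimal sketch. It is not the Megatron Core implementation. The function name `build_gpt_indices` and the `doc_lengths` dictionary (a stand-in for the lengths stored in an IndexedDataset) are hypothetical, and the sketch assumes the number of epochs E is chosen by token count so that the walk over the document index never runs out of tokens.

```python
import numpy as np


def build_gpt_indices(doc_lengths, indexed_indices, num_samples, seq_len, seed):
    """Return (Do_idx, Sa_idx, Sh_idx) for a toy dataset.

    doc_lengths maps a document id to its token count (hypothetical stand-in
    for the metadata held by an IndexedDataset).
    """
    rng = np.random.default_rng(seed)
    indexed_indices = np.asarray(indexed_indices)

    # Sketch assumption: pick E so that E epochs hold at least N * S tokens.
    tokens_per_epoch = sum(doc_lengths[int(d)] for d in indexed_indices)
    needed = num_samples * seq_len
    num_epochs = (needed + tokens_per_epoch - 1) // tokens_per_epoch

    # (1) Document index: every document once per epoch, then shuffled.
    do_idx = np.tile(indexed_indices, num_epochs)
    rng.shuffle(do_idx)

    # (2) Sample index: row j holds the (i, offset) at which sample j starts.
    sa_idx = np.zeros((num_samples + 1, 2), dtype=np.int64)
    i, offset = 0, 0
    for j in range(1, num_samples + 1):
        remaining = seq_len
        while remaining > 0:
            available = doc_lengths[int(do_idx[i])] - offset
            if available > remaining:
                # The sample ends inside the current document.
                offset += remaining
                remaining = 0
            else:
                # Consume the rest of this document and move to the next one.
                remaining -= available
                i, offset = i + 1, 0
        sa_idx[j] = (i, offset)

    # (3) Shuffle index: a random permutation of the sample ids.
    sh_idx = rng.permutation(num_samples)
    return do_idx, sa_idx, sh_idx


doc_lengths = {5: 1536, 6: 1536, 7: 200, 8: 300, 9: 1300}
do_idx, sa_idx, sh_idx = build_gpt_indices(
    doc_lengths, indexed_indices=[5, 6, 7, 8, 9], num_samples=10, seq_len=1024, seed=1234
)
```

With these toy lengths one epoch holds 4872 tokens, so three epochs are needed for 10 samples of 1024 tokens and the document index has 15 entries.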
    

To query the GPTDataset for the k-th sample, we do the following:

  • Use the shuffle index to get the index j into the sample index.

    j = Sh_idx[k]
    
  • Use the sample index to get the left and right sample-bounding indices into the document index and the starting token offset for each document.

    i, offset = Sa_idx[j]
    i_next, offset_next = Sa_idx[j + 1]
    
  • Use the document index to retrieve S tokens from consecutive (in the document index) documents.

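    sample = []
    if i == i_next:
        # The sample begins and ends inside the same document
        sample += indexed_dataset[Do_idx[i]][offset:offset_next]
    else:
        sample += indexed_dataset[Do_idx[i]][offset:]
        for d in Do_idx[i + 1:i_next]:
            sample += indexed_dataset[d]
        sample += indexed_dataset[Do_idx[i_next]][:offset_next]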
    

To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the MegatronDataset.__init__ function.
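
The three-step GPTDataset query described above can be written as a small runnable sketch. It uses a plain list of NumPy arrays as a stand-in for the IndexedDataset and hand-written index arrays. The function name `get_sample` and the toy values are hypothetical, and the real implementation may differ in details such as extra-token handling.

```python
import numpy as np


def get_sample(k, indexed_dataset, do_idx, sa_idx, sh_idx):
    """Return the tokens of the k-th sample using the three index mappings."""
    j = sh_idx[k]
    i, offset = sa_idx[j]
    i_next, offset_next = sa_idx[j + 1]

    if i == i_next:
        # The sample begins and ends inside the same document.
        return indexed_dataset[do_idx[i]][offset:offset_next]

    parts = [indexed_dataset[do_idx[i]][offset:]]
    # Whole documents strictly between the two bounds.
    parts += [indexed_dataset[d] for d in do_idx[i + 1:i_next]]
    if offset_next > 0:
        # Leading tokens of the document that holds the right bound.
        parts.append(indexed_dataset[do_idx[i_next]][:offset_next])
    return np.concatenate(parts)


# Toy data: three documents with 6, 3, and 7 tokens; sequence length S = 4.
indexed_dataset = [np.arange(0, 6), np.arange(100, 103), np.arange(200, 207)]
do_idx = np.array([0, 1, 2])
sa_idx = np.array([[0, 0], [0, 4], [1, 2], [2, 3], [3, 0]])
sh_idx = np.array([2, 0, 3, 1])

first = get_sample(0, indexed_dataset, do_idx, sa_idx, sh_idx)
last = get_sample(3, indexed_dataset, do_idx, sa_idx, sh_idx)
```

In this toy setup every sample has exactly four tokens, and samples 1 and 2 (before shuffling) each span two documents.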

BlendedDataset#

The BlendedDataset is parameterized by the following variables: the underlying MegatronDataset instances D, the weights W (one per dataset), and the size S. The BlendedDataset will draw samples from contributing datasets in proportion to the weights until achieving a composite dataset of the desired size. During each sampling step, we draw a single sample from the dataset which has the greatest sampling error.

The BlendedDataset creates two “blending” indices to facilitate lookup: (1) the dataset index and (2) the dataset sample index.

  1. The dataset index Da_idx is a 1-D array mapping from i to dataset index of length S.

    Given:
    
    D = [d0, d1, d2]
    W = [1/2, 1/4, 1/4]
    S = 4
    
    Then, for example:
    
    Da_idx = [0, 1, 2, 0]
    
  2. The dataset sample index Sa_idx is a 1-D mapping from i to the sample index for dataset Da_idx[i] of length S.

    Given:
    
    Da_idx = [0, 1, 2, 0]
    
    Then, for example:
    
    Sa_idx = [0, 0, 0, 1]
    

To query the BlendedDataset for the k-th sample, we do the following:

  • Use the dataset index to retrieve the corresponding dataset from D and the dataset sample index to retrieve the corresponding sample from that dataset.

    sample = D[Da_idx[k]][Sa_idx[k]]
    

To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the BlendedDataset.__init__ function.
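
The blending procedure and lookup described above can be sketched as follows. This is a minimal illustration rather than the library's code: the function name `build_blending_indices` is hypothetical, and the sampling error is assumed to be the weight times the current step minus the number of samples already drawn from that dataset, with ties broken in favor of the lowest dataset index. Under that assumption the sketch reproduces the example values given above.

```python
import numpy as np


def build_blending_indices(weights, size):
    """Return (Da_idx, Sa_idx) by greedily drawing from the most under-sampled dataset."""
    weights = np.asarray(weights, dtype=np.float64)
    counts = np.zeros(len(weights), dtype=np.int64)
    da_idx = np.zeros(size, dtype=np.int64)
    sa_idx = np.zeros(size, dtype=np.int64)
    for i in range(size):
        # Assumed error definition: expected draws so far minus actual draws.
        errors = weights * max(i, 1) - counts
        d = int(np.argmax(errors))  # first maximum wins on ties
        da_idx[i] = d
        sa_idx[i] = counts[d]  # next unused sample of dataset d
        counts[d] += 1
    return da_idx, sa_idx


da_idx, sa_idx = build_blending_indices([1 / 2, 1 / 4, 1 / 4], size=4)

# Toy stand-ins for the contributing MegatronDataset instances.
D = [["a0", "a1", "a2"], ["b0"], ["c0"]]
samples = [D[da_idx[k]][sa_idx[k]] for k in range(4)]
```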

Submodules#

datasets.blended_megatron_dataset_config module#

class core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig(
random_seed: int,
sequence_length: int,
blend: Tuple[List[str], List[float] | None] | None = None,
blend_per_split: List[Tuple[List[str], List[float] | None] | None] | None = None,
multiple_validation_sets: bool | None = None,
full_validation: bool | None = None,
split: str | None = None,
num_dataset_builder_threads: int = 1,
path_to_cache: str | None = None,
mmap_bin_files: bool = True,
tokenizer: megatron.core.tokenizers.MegatronTokenizerBase | None = None,
mid_level_dataset_surplus: float = 0.005,
)#

Bases: object

Configuration object for Megatron Core datasets

blend: Tuple[List[str], List[float] | None] | None = None#

The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights. For example, [[“dataset-path1”, “dataset-path2”], [0.3, 0.7]]. When the weights are None, they are inferred from the lengths of the contributing datasets. Not to be used with ‘blend_per_split’. Defaults to None.

blend_per_split: List[Tuple[List[str], List[float] | None] | None] | None = None#

A set of blends, as defined above, one for each split distribution. Not to be used with ‘blend’. Defaults to None.

full_validation: bool | None = None#

Whether to run a full epoch of validation each time validation occurs.

mid_level_dataset_surplus: float = 0.005#

The sample surplus to build for the mid-level dataset(s). Defaults arbitrarily to 0.005. This value is irrelevant for single-source data blends. This value may need to be increased if the top-level dataset oversamples the mid-level dataset(s). This value may be set to 0.0 in the future if the top-level dataset is constrained to not oversample the mid-level dataset(s).

mmap_bin_files: bool = True#

Whether to mmap the .bin files or use file pointers.

mock: bool = False#

Whether to bypass real data loading and validation in favor of mock data generation. Created automatically from ‘blend’ and ‘blend_per_split’. Not to be passed in to the constructor.

multiple_validation_sets: bool | None = None#

Whether the validation split should be treated as multiple separate datasets.

num_dataset_builder_threads: int = 1#

The number of threads to use for dataset building.

path_to_cache: str | None = None#

Where all reusable dataset indices are to be cached.

random_seed: int#

The seed for all RNG during dataset creation.

sequence_length: int#

The sequence length.

split: str | None = None#

The split string, a comma separated weighting for the dataset splits when drawing samples from a single distribution. Not to be used with ‘blend_per_split’. Defaults to None.

split_matrix: List[Tuple[float, float]] | None = None#

The split matrix consisting of non-overlapping book-ends of each split in order. For more information, refer to ‘convert_split_vector_to_split_matrix’. Created automatically from ‘split’. Not to be passed in to the constructor.

tokenizer: megatron.core.tokenizers.MegatronTokenizerBase | None = None#

The MegatronTokenizerBase instance. Required for datasets that do online tokenization.

core.datasets.blended_megatron_dataset_config.convert_split_vector_to_split_matrix(
vector_a: List[float],
vector_b: List[float] | None = None,
) List[Tuple[float, float] | None]#

Build the split matrix from one or optionally two contributing split vectors.

Ex. a standard conversion:

[0.99, 0.01, 0.0] -> [(0, 0.99), (0.99, 1.0), None]

Ex. a conversion for Retro when Retro pretraining uses a [0.99, 0.01, 0.0] split and Retro preprocessing used a [0.98, 0.02, 0.0] split:

[0.99, 0.01, 0.0], [0.98, 0.02, 0.0] -> [(0, 0.98), (0.99, 1.0), None]

Parameters:
  • vector_a (List[float]) – The primary split vector

  • vector_b (Optional[List[float]]) – An optional secondary split vector which constrains the primary split vector. Defaults to None.

Returns:

The split matrix consisting of book-ends of each split in order

Return type:

List[Tuple[float, float]]
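
The two conversions shown above can be reproduced with a minimal sketch. It is an independent re-implementation for illustration, not the library function: the name `split_vector_to_matrix` is hypothetical, and the sketch assumes that the secondary vector constrains the primary one by intersecting the corresponding book-ends.

```python
from itertools import accumulate
from typing import List, Optional, Tuple

Bookends = Optional[Tuple[float, float]]


def split_vector_to_matrix(
    vector_a: List[float], vector_b: Optional[List[float]] = None
) -> List[Bookends]:
    """Turn split ratios into (start, end) book-ends; empty splits become None."""

    def bookends(vector: List[float]) -> List[Bookends]:
        ends = list(accumulate(vector))
        starts = [0.0] + ends[:-1]
        return [(s, e) if e > s else None for s, e in zip(starts, ends)]

    matrix_a = bookends(vector_a)
    if vector_b is None:
        return matrix_a

    matrix = []
    for a, b in zip(matrix_a, bookends(vector_b)):
        if a is None or b is None:
            matrix.append(None)
            continue
        # Assumed constraint: keep only the overlap of the two ranges.
        low, high = max(a[0], b[0]), min(a[1], b[1])
        matrix.append((low, high) if high > low else None)
    return matrix


standard = split_vector_to_matrix([0.99, 0.01, 0.0])
constrained = split_vector_to_matrix([0.99, 0.01, 0.0], [0.98, 0.02, 0.0])
```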

core.datasets.blended_megatron_dataset_config.parse_and_normalize_split(split: str) List[float]#

Parse the dataset split ratios from a string

Parameters:

split (str) – The train valid test split string e.g. “99,1,0”

Returns:

The train valid test split ratios e.g. [0.99, 0.01, 0.0]

Return type:

List[float]
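
As an illustration of the parsing and normalization step, the following sketch splits the string on commas and divides each weight by the total. The name `parse_split` is hypothetical and the sketch does not claim to match the library's handling of malformed input.

```python
from typing import List


def parse_split(split: str) -> List[float]:
    """Parse a comma-separated weighting such as '99,1,0' into normalized ratios."""
    weights = [float(piece) for piece in split.split(",")]
    if any(w < 0 for w in weights):
        raise ValueError("split weights must be non-negative")
    total = sum(weights)
    if total <= 0:
        raise ValueError("at least one split weight must be positive")
    return [w / total for w in weights]


ratios = parse_split("99,1,0")
```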

datasets.blended_megatron_dataset_builder module#

class core.datasets.blended_megatron_dataset_builder.BlendedMegatronDatasetBuilder(
cls: Type[megatron.core.datasets.megatron_dataset.MegatronDataset],
sizes: List[int],
is_built_on_rank: Callable,
config: megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig,
)#

Bases: object

Builder class for the BlendedDataset and MegatronDataset classes

Parameters:
  • cls (Type[MegatronDataset]) – The class to instantiate, must inherit from MegatronDataset

  • sizes (List[Optional[int]]) – The minimum total number of samples to draw, or None, per split

  • is_built_on_rank (Callable) – A callable which returns True if the dataset should be built on the current rank and False otherwise. It should be Megatron Core parallelism aware i.e. global rank, local group rank, and virtual rank may inform its return value. Should return true for exactly one process on global rank 0.

  • config (BlendedMegatronDatasetConfig) – The config object which informs dataset creation

build() List[megatron.core.datasets.blended_dataset.BlendedDataset | megatron.core.datasets.megatron_dataset.MegatronDataset | None]#

Build all dataset splits according to the provided blend(s)

This method is distributed-aware and must be called on all ranks.

The dataset splits returned can vary according to the config. Supply config.blend and config.split to build BlendedDataset and/or MegatronDataset splits from the same distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset splits from separate distributions. In either case, for each split, handle the following cases:

  1. The split is None - do nothing

  2. The split has one contributing dataset, and…

    1. ‘size’ is not None - Build a mid-level dataset with low-level dataset sampling in proportion to the size

    2. ‘size’ is None - Build mid-level datasets with no excess low-level dataset sampling

  3. The split has multiple contributing datasets, and…

    1. ‘weights’ is not None and ‘size’ is not None - Build mid-level datasets with low-level dataset sampling in proportion to their weights and the size - Build a top-level dataset of length marginally greater than ‘size’ with mid-level dataset sampling in proportion to their weights and the size

    2. ‘weights’ is not None and ‘size’ is None - Error

    3. ‘weights’ is None and ‘size’ is not None - Build mid-level datasets with no excess low-level dataset sampling - Build a top-level dataset of length ‘size’ (capped at the sum of the mid-level dataset lengths) with mid-level dataset sampling in proportion to their lengths and the size

    4. ‘weights’ is None and ‘size’ is None - Build mid-level datasets with no excess low-level dataset sampling - Build a top-level dataset with no excess mid-level dataset sampling

Returns:

A list containing a dataset instance (or None) per split

Return type:

List[Optional[TopLevelDataset]]

static build_generic_dataset(
cls: Type[megatron.core.datasets.blended_dataset.BlendedDataset | megatron.core.datasets.megatron_dataset.MegatronDataset | megatron.core.datasets.megatron_dataset.LowLevelDataset | torch.utils.data.Dataset] | Callable,
is_built_on_rank: Callable,
synchronize_ranks: bool,
*args: Any,
) megatron.core.datasets.blended_dataset.BlendedDataset | megatron.core.datasets.megatron_dataset.MegatronDataset | megatron.core.datasets.megatron_dataset.LowLevelDataset | torch.utils.data.Dataset | Iterable | None#

Build the DistributedDataset

Return None if and only if the underlying dataset class is not built on the current rank and torch.distributed is initialized.

Parameters:
  • cls (Union[Type[DistributedDataset], Callable]) – The DistributedDataset class to be built. In special cases, e.g. when we are building the low level dataset for a RawMegatronDataset instance, we can accept a Callable which returns an Iterable.

  • synchronize_ranks (bool) – Whether to call barrier for rank-0 / barrier / other-ranks behavior. Set to False when we enforce this behavior at higher level.

  • args (Tuple[Any]) – The positional arguments used to build the provided DistributedDataset class

Raises:

Exception – When the dataset constructor raises an OSError

Returns:

The DistributedDataset instantiation, the Iterable instantiation, or None

Return type:

Optional[Union[DistributedDataset, Iterable]]

datasets.megatron_tokenizer module#

class core.datasets.megatron_tokenizer.MegatronLegacyTokenizer(
*tokenizer_paths: str,
**tokenizer_options: Any,
)#

Bases: ABC

Abstract class for tokenizer

Absent a config or class-specific tracking of which objects are uniquely identifying, we must include all key word arguments as unique identifiers

Parameters:
  • tokenizer_paths (Tuple[str]) – All tokenizer source paths or prefixes

  • tokenizer_options (Dict[str, Any]) – All tokenizer options

property bos#

The BOS token id

Raises:

NotImplementedError – Non-abstract, optional attribute

property cls#

The CLS token id

Raises:

NotImplementedError – Non-abstract, optional attribute

detokenize(ids: numpy.ndarray) str#

Convert embedding ids to text

Parameters:

ids (numpy.ndarray) – The ids to convert

Returns:

The converted text

Return type:

str

Raises:

NotImplementedError – Non-abstract, optional method

property eod#

The EOD token id

Raises:

NotImplementedError – Non-abstract, optional attribute

property eos#

The EOS token id

Raises:

NotImplementedError – Non-abstract, optional attribute

abstract property inv_vocab#

Dictionary from vocab id token to text token

property mask#

The MASK token id

Raises:

NotImplementedError – Non-abstract, optional attribute

offsets(
ids: list[int],
text: str,
) list[int]#

Convert embedding ids to text offsets

Parameters:
  • ids (list[int]) – The ids to convert

  • text (str) – The text to convert

Returns:

The converted offsets

Return type:

list[int]

Raises:

NotImplementedError – Non-abstract, optional method

property pad#

The PAD token id

Raises:

NotImplementedError – Non-abstract, optional attribute

property sep#

The SEP token id

Raises:

NotImplementedError – Non-abstract, optional attribute

abstract tokenize(text: str) numpy.ndarray#

Convert text to embedding ids

Parameters:

text (str) – The text to convert

Returns:

The converted embedding ids

Return type:

numpy.ndarray

abstract property vocab#

Dictionary from vocab text token to id token

abstract property vocab_size#

The vocabulary size
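
The split between abstract members (tokenize, vocab, inv_vocab, vocab_size) and optional members that raise NotImplementedError can be illustrated with a standalone sketch. The classes `ToyTokenizerBase` and `WhitespaceTokenizer` below are hypothetical and only mirror the shape of the interface documented above. They do not import or subclass the Megatron Core class.

```python
from abc import ABC, abstractmethod

import numpy as np


class ToyTokenizerBase(ABC):
    """Standalone mirror of the interface shape, for illustration only."""

    @abstractmethod
    def tokenize(self, text: str) -> np.ndarray:
        """Convert text to ids (required)."""

    def detokenize(self, ids: np.ndarray) -> str:
        raise NotImplementedError("optional method")

    @property
    @abstractmethod
    def vocab(self) -> dict:
        """Text token to id (required)."""

    @property
    @abstractmethod
    def inv_vocab(self) -> dict:
        """Id to text token (required)."""

    @property
    @abstractmethod
    def vocab_size(self) -> int:
        """Vocabulary size (required)."""

    @property
    def eod(self) -> int:
        raise NotImplementedError("optional attribute")


class WhitespaceTokenizer(ToyTokenizerBase):
    """Toy tokenizer that splits on whitespace and looks words up in a fixed vocab."""

    def __init__(self, words):
        self._vocab = {word: i for i, word in enumerate(words)}
        self._inv_vocab = {i: word for word, i in self._vocab.items()}

    def tokenize(self, text):
        return np.array([self._vocab[w] for w in text.split()], dtype=np.int64)

    def detokenize(self, ids):
        return " ".join(self._inv_vocab[int(i)] for i in ids)

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab

    @property
    def vocab_size(self):
        return len(self._vocab)


tok = WhitespaceTokenizer(["hello", "world"])
ids = tok.tokenize("world hello")
```

Because the toy subclass does not override `eod`, accessing `tok.eod` raises NotImplementedError, which matches the "non-abstract, optional attribute" behavior described for the special-token properties.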

datasets.indexed_dataset module#

class core.datasets.indexed_dataset.DType(value)#

Bases: Enum

The NumPy data type Enum for writing/reading the IndexedDataset indices

classmethod code_from_dtype(value: Type[numpy.number]) int#

Get the code from the dtype

Parameters:

value (Type[numpy.number]) – The dtype

Returns:

The code

Return type:

int

classmethod dtype_from_code(value: int) Type[numpy.number]#

Get the dtype from the code

Parameters:

value (int) – The code

Returns:

The dtype

Return type:

Type[numpy.number]

float32 = 7#
float64 = 6#
int16 = 3#
int32 = 4#
int64 = 5#
int8 = 2#
static optimal_dtype(
cardinality: int | None,
) Type[numpy.number]#

Get the dtype to use for an index of a certain cardinality

Parameters:

cardinality (Optional[int]) – The number of elements to be indexed

Returns:

The dtype to use for the index

Return type:

Type[numpy.number]

static size(key: int | Type[numpy.number]) int#

Get the size of the dtype/code in bytes

Parameters:

key (Union[int, Type[numpy.number]]) – The dtype or code

Raises:

ValueError – If the key is neither dtype nor integer code

Returns:

The size of the dtype/code in bytes

Return type:

int

uint16 = 8#
uint8 = 1#
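
The mapping between numeric codes and NumPy dtypes can be sketched with a standalone Enum that reuses the member values listed above. The class name `ToyDType` is hypothetical, and the method bodies are one plausible way to implement the documented behavior, not the library's source.

```python
from enum import Enum

import numpy as np


class ToyDType(Enum):
    """Standalone illustration using the codes documented for DType."""

    uint8 = 1
    int8 = 2
    int16 = 3
    int32 = 4
    int64 = 5
    float64 = 6
    float32 = 7
    uint16 = 8

    @classmethod
    def code_from_dtype(cls, value):
        # Look the member up by the dtype's canonical name, e.g. "int32".
        return cls[np.dtype(value).name].value

    @classmethod
    def dtype_from_code(cls, value):
        # Resolve the member name back to the NumPy scalar type.
        return getattr(np, cls(value).name)

    @staticmethod
    def size(key):
        if isinstance(key, int):
            dtype = ToyDType.dtype_from_code(key)
        elif isinstance(key, type) and issubclass(key, np.number):
            dtype = key
        else:
            raise ValueError(f"{key!r} is neither a dtype nor an integer code")
        return np.dtype(dtype).itemsize


code = ToyDType.code_from_dtype(np.uint16)
dtype = ToyDType.dtype_from_code(4)
```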
class core.datasets.indexed_dataset.IndexedDataset(*args: Any, **kwargs: Any)#

Bases: Dataset

The low-level interface dataset class

Parameters:
  • path_prefix (str) – The index (.idx) and data (.bin) prefix

  • multimodal (bool) – Whether the dataset is multimodal. Defaults to False.

  • mmap (bool) – Whether to mmap the .bin files. Defaults to True.

  • object_storage_config (Optional[ObjectStorageConfig]) – Supplied only for data stored on S3 or MSC. IndexedDataset downloads the index (.idx) file to object_storage_config.path_to_idx_cache and streams data from the data (.bin) file in object_storage_config.bin_chunk_nbytes blocks. Note that mmap must be disabled for S3 data loading. Defaults to None.

property document_indices: numpy.ndarray#

Get the document indices

Returns:

The document indices

Return type:

numpy.ndarray

static exists(path_prefix: str) bool#

Return whether the IndexedDataset exists on disk at the prefix

Parameters:

path_prefix (str) – The prefix to the index (.idx) and data (.bin) files

Returns:

Whether the IndexedDataset exists on disk at the prefix

Return type:

bool

get(
idx: int,
offset: int = 0,
length: int | None = None,
) numpy.ndarray | Tuple[numpy.ndarray, numpy.number]#

Retrieve a single item from the dataset with the option to only return a portion of the item.

get(idx) is the same as [idx] but get() does not support slicing.

Parameters:
  • idx (Union[int, numpy.integer]) – The index into the dataset

  • offset (int) – The integer token offset in the sequence

  • length (int) – The number of tokens to grab from the sequence

Returns:

The sequence tokens and mode at the index

Return type:

Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.number]]

get_document_indices() numpy.ndarray#

Get the document indices

This method is slated for deprecation.

Returns:

The document indices

Return type:

numpy.ndarray

initialize(
path_prefix: str,
multimodal: bool,
mmap: bool,
object_storage_config: megatron.core.datasets.object_storage_utils.ObjectStorageConfig | None,
) None#

Initialize the dataset

This method is called by IndexedDataset.__init__ during object creation and by IndexedDataset.__setstate__ during un-pickling

Parameters:
  • path_prefix (str) – The index (.idx) and data (.bin) prefix

  • multimodal (bool) – Whether the dataset is multimodal

  • mmap (bool) – Whether to mmap the .bin file

  • object_storage_config (Optional[ObjectStorageConfig]) – See IndexedDataset docstring for details.

property sequence_lengths: numpy.ndarray#

Get the sequence lengths

Returns:

The sequence lengths

Return type:

numpy.ndarray

property sequence_modes: numpy.ndarray#

Get the sequence modes

Returns:

The sequence modes

Return type:

numpy.ndarray

set_document_indices(
document_indices: numpy.ndarray,
) None#

Set the document indices

This method is slated for deprecation.

Parameters:

document_indices (numpy.ndarray) – The document indices

class core.datasets.indexed_dataset.IndexedDatasetBuilder(
bin_path: str,
dtype: Type[numpy.number] = numpy.int32,
multimodal: bool = False,
)#

Bases: object

Builder class for the IndexedDataset class

Parameters:
  • bin_path (str) – The path to the data (.bin) file

  • dtype (Type[numpy.number], optional) – The dtype of the index file. Defaults to numpy.int32.

  • multimodal (bool, optional) – Whether the dataset is multimodal. Defaults to False.

add_document(
tensor: torch.Tensor,
lengths: List[int],
modes: List[int] | None = None,
) None#

Add an entire document to the dataset

Parameters:
  • tensor (torch.Tensor) – The document to add

  • lengths (List[int]) – The lengths of each item in the document

  • modes (Optional[List[int]], optional) – The modes for each item in the document. Defaults to None.

add_index(path_prefix: str) None#

Add an entire IndexedDataset to the dataset

Parameters:

path_prefix (str) – The index (.idx) and data (.bin) prefix

add_item(
tensor: torch.Tensor,
mode: int = 0,
) None#

Add a single item to the dataset

Parameters:
  • tensor (torch.Tensor) – The item to add to the data file

  • mode (int, optional) – The mode for the item. Defaults to 0.

end_document() None#

Finalize the document, for use with IndexedDatasetBuilder.add_item

finalize(idx_path: str) None#

Clean up and write the index (.idx) file

Parameters:

idx_path (str) – The path to the index file
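
The interplay between add_item, end_document, add_document, and finalize can be illustrated with an in-memory sketch. The class `ToyIndexedDatasetBuilder` is hypothetical: it keeps everything in Python lists instead of writing .bin and .idx files, accepts plain sequences instead of torch tensors, and assumes that add_document closes the document it adds.

```python
import numpy as np


class ToyIndexedDatasetBuilder:
    """In-memory illustration of the builder's bookkeeping; writes no files."""

    def __init__(self, dtype=np.int32):
        self.dtype = dtype
        self.data = []  # stand-in for the data (.bin) contents
        self.sequence_lengths = []  # one entry per item/sequence
        self.document_indices = [0]  # sequence index where each document starts

    def add_item(self, tokens):
        array = np.asarray(tokens, dtype=self.dtype)
        self.data.append(array)
        self.sequence_lengths.append(len(array))

    def end_document(self):
        # Close the current document at the current sequence count.
        self.document_indices.append(len(self.sequence_lengths))

    def add_document(self, tokens, lengths):
        array = np.asarray(tokens, dtype=self.dtype)
        if sum(lengths) != len(array):
            raise ValueError("lengths must sum to the document length")
        self.data.append(array)
        self.sequence_lengths.extend(lengths)
        self.end_document()  # sketch assumption: the document is closed here

    def finalize(self):
        flat = np.concatenate(self.data) if self.data else np.empty(0, self.dtype)
        return flat, np.array(self.sequence_lengths), np.array(self.document_indices)


builder = ToyIndexedDatasetBuilder()
builder.add_item([1, 2, 3])
builder.add_item([4, 5])
builder.end_document()
builder.add_document([6, 7, 8, 9], lengths=[1, 3])
flat, lengths, documents = builder.finalize()
```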

core.datasets.indexed_dataset.get_bin_path(path_prefix: str) str#

Get the path to the data file from the prefix

Parameters:

path_prefix (str) – The prefix

Returns:

The path to the data file

Return type:

str

core.datasets.indexed_dataset.get_idx_path(path_prefix: str) str#

Get the path to the index file from the prefix

Parameters:

path_prefix (str) – The prefix

Returns:

The path to the index file

Return type:

str

datasets.megatron_dataset module#

class core.datasets.megatron_dataset.MegatronDataset(*args: Any, **kwargs: Any)#

Bases: ABC, Dataset

The highest level wrapper class from which all dataset classes should inherit

Parameters:
  • dataset (LowLevelDataset) – The dataset around which to build the MegatronDataset

  • dataset_path (Optional[str]) – The real path on disk to the dataset, for bookkeeping

  • indices (numpy.ndarray) – The set of the documents indices to expose

  • num_samples (Optional[int]) – The minimum number of samples to build from the indexed dataset. When None, build as many samples as correspond to one epoch.

  • index_split (Split) – The indices Split

  • config (BlendedMegatronDatasetConfig) – The config

static build_low_level_dataset(
dataset_path: str,
config: megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig,
) megatron.core.datasets.indexed_dataset.IndexedDataset | Iterable#

Build the low level dataset via a function to be called from within BlendedMegatronDatasetBuilder.build_generic_dataset

It may be that the low level dataset spans any subset of train/valid/test splits, which is why we define a static “build” function separately from the constructor in the mid level dataset class

Parameters:
Returns:

The low level dataset

Return type:

LowLevelDataset

static numel_low_level_dataset(
low_level_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset | Iterable,
) int#

Return the number of elements in the underlying low level dataset for the purpose of segregating the train/valid/test split indices

It may be that the low level dataset can be split any number of ways, depending on the mid level dataset it supports, which is why we define the “number of elements” function separately from the __len__ function here in the mid level dataset class

Parameters:

low_level_dataset (LowLevelDataset) – The underlying low level dataset

Returns:

The number of elements in the underlying low level dataset

Return type:

int

datasets.gpt_dataset module#

class core.datasets.gpt_dataset.GPTDataset(*args: Any, **kwargs: Any)#

Bases: MegatronDataset

The base GPT dataset

Parameters:
  • indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the GPTDataset

  • dataset_path (Optional[str]) – The real path on disk to the dataset, for bookkeeping

  • indexed_indices (numpy.ndarray) – The set of the documents indices to expose

  • num_samples (Optional[int]) – The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.

  • index_split (Split) – The indexed_indices Split

  • config (GPTDatasetConfig) – The config

static build_low_level_dataset(
dataset_path: str,
config: GPTDatasetConfig,
) megatron.core.datasets.indexed_dataset.IndexedDataset#

Abstract method implementation

Parameters:
  • dataset_path (str) – The real path prefix to the IndexedDataset .bin and .idx files

  • config (GPTDatasetConfig) – The config

Returns:

The underlying IndexedDataset

Return type:

IndexedDataset

static numel_low_level_dataset(
low_level_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset,
) int#

Abstract method implementation

For GPT, the underlying IndexedDataset should be split by sequence, as opposed to, say, BERT, which should be split by document

Parameters:

low_level_dataset (IndexedDataset) – The underlying IndexedDataset

Returns:

The number of unique elements in the underlying IndexedDataset

Return type:

int

class core.datasets.gpt_dataset.GPTDatasetConfig(*args: Any, **kwargs: Any)#

Bases: BlendedMegatronDatasetConfig

Configuration object for Megatron Core GPT datasets

add_extra_token_to_sequence: bool = True#

Option to draw sequences with one extra token to ensure the sample input tokens and sample output tokens are both of the desired sequence length
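
The reason for the extra token can be shown with a short sketch: shifting a sequence of S + 1 tokens by one position yields inputs and next-token targets that both have length S. The function name `split_inputs_and_labels` is hypothetical.

```python
import numpy as np


def split_inputs_and_labels(tokens):
    """Split S + 1 tokens into S input tokens and S next-token labels."""
    tokens = np.asarray(tokens)
    return tokens[:-1], tokens[1:]


sequence_length = 4
tokens = np.arange(sequence_length + 1)  # one extra token
inputs, labels = split_inputs_and_labels(tokens)
```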

create_attention_mask: bool = True#

Option to enable attention mask generation. Can be disabled if the attention kernel generates masks itself.

drop_last_partial_validation_sequence: bool = True#

Option to drop the last partial validation sequence

eod_mask_loss: bool | None = None#

Option to enable the EOD mask loss

object_storage_cache_path: str | None = None#

Path for caching indices for S3 or MSC data loading.

reset_attention_mask: bool | None = None#

Option to reset the attention mask from the dataset

reset_position_ids: bool | None = None#

Option to reset the position IDs in the dataset at an interval
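
One plausible reading of the eod_mask_loss and reset_position_ids options is sketched below. It assumes that the EOD mask loss excludes end-of-document tokens from the loss and that resetting position IDs restarts the position count after each end-of-document token. Both are assumptions of the sketch, and the function name `loss_mask_and_position_ids` is hypothetical.

```python
import numpy as np


def loss_mask_and_position_ids(tokens, eod_token, eod_mask_loss=True, reset_position_ids=True):
    """Build a per-token loss mask and position ids for one packed sample."""
    tokens = np.asarray(tokens)

    loss_mask = np.ones(len(tokens), dtype=np.float32)
    if eod_mask_loss:
        # Assumed behavior: EOD tokens do not contribute to the loss.
        loss_mask[tokens == eod_token] = 0.0

    position_ids = np.arange(len(tokens), dtype=np.int64)
    if reset_position_ids:
        for eod_position in np.flatnonzero(tokens == eod_token):
            # Assumed behavior: positions restart right after each EOD token.
            start = eod_position + 1
            position_ids[start:] = np.arange(len(tokens) - start)

    return loss_mask, position_ids


loss_mask, position_ids = loss_mask_and_position_ids([5, 6, 0, 7, 8, 0, 9], eod_token=0)
```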

class core.datasets.gpt_dataset.MockGPTDataset(*args: Any, **kwargs: Any)#

Bases: GPTDataset

The mock GPT dataset

Parameters:
  • indexed_dataset (MockGPTLowLevelDataset) – The MockGPTLowLevelDataset around which to build the MockGPTDataset

  • dataset_path (Optional[str]) – This argument is of no consequence for the MockGPTDataset

  • indices (numpy.ndarray) – The set of the dataset indices to expose

  • num_samples (int) – The number of samples to draw from the dataset

  • index_split (Split) – The indices Split

  • config (GPTDatasetConfig) – The config

static build_low_level_dataset(
dataset_path: str | None,
config: GPTDatasetConfig,
) MockGPTLowLevelDataset#

Abstract method implementation

Parameters:
  • dataset_path (Optional[str]) – This argument is of no consequence for the MockGPTLowLevelDataset

  • config (GPTDatasetConfig) – The config

Returns:

The underlying MockGPTLowLevelDataset

Return type:

MockGPTLowLevelDataset

static numel_low_level_dataset(
low_level_dataset: MockGPTLowLevelDataset,
) int#

Abstract method implementation

Parameters:

low_level_dataset (MockGPTLowLevelDataset) – The underlying MockGPTLowLevelDataset

Returns:

The number of unique elements in the underlying MockGPTLowLevelDataset

Return type:

int

class core.datasets.gpt_dataset.MockGPTLowLevelDataset(
tokenizer: megatron.core.tokenizers.MegatronTokenizerBase,
)#

Bases: object

The mock GPT low level dataset

This class is meant to generate tokenized data in the classic “Megatron-LM” GPT style. Notably, we add the end of document token to each element indexed in __getitem__

Parameters:
  • tokenizer (MegatronTokenizerBase) – The tokenizer, the special token information of which we use to augment the mock data.

get(
idx: int,
offset: int = 0,
length: int | None = None,
) numpy.ndarray#

This function is an abstraction over __getitem__ with support for slicing

Parameters:
  • idx (int) – The index into the dataset

  • offset (int) – The integer token offset in the sequence

  • length (Optional[int]) – The number of tokens to grab from the sequence

Returns:

The sequence tokens at the index

Return type:

numpy.ndarray

max_sequence_length: int = 4096#

The hard-coded max sequence length to generate

seed: int = 0#

The hard-coded random seed to use to set the NumPy RNG

size: int = 100000#

The hard-coded number of samples to generate

datasets.masked_dataset module#

class core.datasets.masked_dataset.MaskedWordPieceDataset(*args: Any, **kwargs: Any)#

Bases: MegatronDataset

The semi-abstract base class for masked WordPiece datasets

This implementation makes the rigid assumption that all inheritor datasets are built upon the IndexedDataset class. This assumption may be pushed down to the inheritors in future if necessary.

NB: WordPiece tokenization prepends a double hash “##” to all tokens/pieces in a word, save the first token/piece.
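
The "##" convention makes it possible to recover word boundaries from a token sequence, which is what whole-word masking needs. A minimal sketch, with the hypothetical helper name `group_word_pieces`, groups token positions into words by treating every token that starts with "##" as a continuation of the previous word.

```python
from typing import List


def group_word_pieces(tokens: List[str]) -> List[List[int]]:
    """Group token positions into words using the WordPiece '##' prefix."""
    words: List[List[int]] = []
    for position, token in enumerate(tokens):
        if token.startswith("##") and words:
            # Continuation piece: attach to the current word.
            words[-1].append(position)
        else:
            # First piece of a new word.
            words.append([position])
    return words


groups = group_word_pieces(["un", "##break", "##able", "code"])
```

Masking all positions in one group at once corresponds to masking a whole word, whereas masking single positions corresponds to masking its component parts.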

Parameters:
  • indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset

  • dataset_path (str) – The real path on disk to the dataset, for bookkeeping

  • indexed_indices (numpy.ndarray) – The set of the documents indices to expose

  • num_samples (Optional[int]) – The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.

  • index_split (Split) – The indexed_indices Split

  • config (MaskedWordPieceDatasetConfig) – The config

static build_low_level_dataset(
dataset_path: str,
config: MaskedWordPieceDatasetConfig,
) megatron.core.datasets.indexed_dataset.IndexedDataset#

static numel_low_level_dataset(
low_level_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset,
) int#

class core.datasets.masked_dataset.MaskedWordPieceDatasetConfig(*args: Any, **kwargs: Any)#

Bases: BlendedMegatronDatasetConfig

Configuration object for Megatron Core Masked WordPiece datasets

masking_do_full_word: bool = None#

Whether we mask the whole word or its component parts

masking_do_permutation: bool = None#

Whether we shuffle a subset of candidate N-grams in addition

masking_max_ngram: int = None#

The maximum length N-gram to consider masking or permuting

masking_probability: float = None#

The probability we mask a candidate N-gram

masking_use_geometric_distribution: bool = None#

Whether to draw the size of the N-gram from a geometric distribution according to SpanBERT https://arxiv.org/abs/1907.10529 (Section 3.1)

masking_use_longer_ngrams: bool = None#

Whether to favor longer N-grams over shorter N-grams

short_sequence_probability: float = None#

The probability we return a sequence shorter than the target sequence length
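To make the interaction between masking_max_ngram and masking_use_geometric_distribution concrete, the sketch below draws N-gram lengths from a geometric distribution truncated at a maximum length, in the spirit of the SpanBERT reference above. It is an illustration under stated assumptions, not the dataset's implementation: the helper name ngram_length_probabilities, the value p = 0.2, and the renormalization after truncation are all choices made for this example.

```python
import numpy as np


def ngram_length_probabilities(max_ngram: int, p: float = 0.2) -> np.ndarray:
    """Probabilities for lengths 1..max_ngram under a truncated geometric law.

    Both p and the renormalization step are assumptions of this sketch.
    """
    lengths = np.arange(1, max_ngram + 1)
    # Geometric mass p * (1 - p)^(n - 1), which favors shorter N-grams.
    pvals = p * (1.0 - p) ** (lengths - 1)
    # Renormalize so the truncated distribution sums to one.
    return pvals / pvals.sum()


rng = np.random.default_rng(0)
pvals = ngram_length_probabilities(max_ngram=3)
sampled = rng.choice(np.arange(1, 4), size=1000, p=pvals)
```

Under these assumptions, shorter N-grams are always more likely than longer ones, and no sampled length exceeds the configured maximum.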

datasets.bert_dataset module#

class core.datasets.bert_dataset.BERTMaskedWordPieceDataset(*args: Any, **kwargs: Any)#

Bases: MaskedWordPieceDataset

The BERT dataset that assumes WordPiece tokenization

Parameters:
  • indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset

  • dataset_path (str) – The real path on disk to the dataset, for bookkeeping

  • indexed_indices (numpy.ndarray) – The set of document indices to expose

  • num_samples (Optional[int]) – The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.

  • index_split (Split) – The indexed_indices Split

  • config (BERTMaskedWordPieceDatasetConfig) – The config

class core.datasets.bert_dataset.BERTMaskedWordPieceDatasetConfig(*args: Any, **kwargs: Any)#

Bases: MaskedWordPieceDatasetConfig

Configuration object for Megatron Core BERT WordPiece datasets

classification_head: bool = None#

Option to perform the next sequence prediction during sampling

datasets.t5_dataset module#

class core.datasets.t5_dataset.T5MaskedWordPieceDataset(*args: Any, **kwargs: Any)#

Bases: MaskedWordPieceDataset

The T5 dataset that assumes WordPiece tokenization

Parameters:
  • indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset

  • dataset_path (str) – The real path on disk to the dataset, for bookkeeping

  • indexed_indices (numpy.ndarray) – The set of document indices to expose

  • num_samples (Optional[int]) – The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.

  • index_split (Split) – The indexed_indices Split

  • config (T5MaskedWordPieceDatasetConfig) – The config

static config_attention_mask(
encoder_tokens: torch.tensor,
decoder_tokens: torch.tensor,
encoder_mask: torch.tensor,
decoder_mask: torch.tensor,
use_local: bool = False,
test_te_version: str | None = None,
) torch.tensor#

Configure the attention masks (encoder_mask, decoder_mask, encoder_decoder_mask) conditioned on the transformer implementation (e.g. TE vs local), the TE version, and the TE backend

Parameters:
  • encoder_tokens (torch.tensor) – A 2-D array of tokens (bs, kv_len)

  • decoder_tokens (torch.tensor) – A 2-D array of tokens (bs, q_len)

  • encoder_mask (torch.tensor) – A 2-D mask array for the encoder tokens (bs, kv_len)

  • decoder_mask (torch.tensor) – A 2-D mask array for the decoder tokens (bs, q_len)

  • use_local (bool) – Whether the current T5 model uses the local (vs TE) transformer implementation

Returns:

The configured encoder_mask, decoder_mask, and encoder_decoder_mask:

  • torch.tensor – The configured encoder attention mask

  • torch.tensor – The configured decoder attention mask

  • torch.tensor – The configured encoder-decoder attention mask
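As background for the shapes involved, the sketch below shows one common way to expand 2-D padding masks of shape (bs, kv_len) and (bs, q_len) into 3-D attention masks. It uses NumPy for portability and does not reproduce config_attention_mask: the helper outer_mask is hypothetical, the causal decoder mask is an assumption of the example, and none of the TE-version or backend-specific handling is modeled.

```python
import numpy as np


def outer_mask(query_mask: np.ndarray, key_mask: np.ndarray) -> np.ndarray:
    """Expand two 2-D padding masks into a (bs, q_len, kv_len) attention mask."""
    # (bs, q_len, 1) * (bs, 1, kv_len) broadcasts to (bs, q_len, kv_len).
    return query_mask[:, :, None] * key_mask[:, None, :]


# One batch element; 1 marks a real token and 0 marks padding.
encoder_mask = np.array([[1, 1, 1, 0]])  # (bs, kv_len)
decoder_mask = np.array([[1, 1, 0]])     # (bs, q_len)

# Encoder self-attention: real tokens attend to real tokens.
encoder_self = outer_mask(encoder_mask, encoder_mask)

# Decoder self-attention: the padding mask combined with a causal mask
# (the causal constraint is an assumption of this sketch).
q_len = decoder_mask.shape[1]
causal = np.tril(np.ones((q_len, q_len), dtype=decoder_mask.dtype))
decoder_self = outer_mask(decoder_mask, decoder_mask) * causal

# Cross-attention: decoder queries attend to encoder keys.
encoder_decoder = outer_mask(decoder_mask, encoder_mask)
```

The cross-attention mask has shape (bs, q_len, kv_len), which is why the encoder inputs are documented with kv_len and the decoder inputs with q_len.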

class core.datasets.t5_dataset.T5MaskedWordPieceDatasetConfig(*args: Any, **kwargs: Any)#

Bases: MaskedWordPieceDatasetConfig

Configuration object for Megatron Core T5 WordPiece datasets

NB: As a temporary holdover from Megatron-LM, the T5 tokenizer has an attribute which defines a number of special sentinel tokens used during sampling. The assert in __post_init__ serves to preserve compatibility with Megatron-LM until the T5 tokenizer is in Megatron Core.

sequence_length_decoder: int = None#

The sequence length for the decoder

sequence_length_encoder: int | None = None#

A sequence_length alias and the sequence length for the encoder

datasets.blended_dataset module#

class core.datasets.blended_dataset.BlendedDataset(*args: Any, **kwargs: Any)#

Bases: Dataset

Conjugating class for a set of MegatronDataset instances

Parameters:
  • datasets (List[MegatronDataset]) – The MegatronDataset instances to blend

  • weights (List[Union[int, float]]) – The weights that determine the dataset blend ratios

  • size (Optional[int]) – The number of samples to draw from the blend. If None, for each dataset index idx draw exactly weights[idx] samples from datasets[idx].

  • config (BlendedMegatronDatasetConfig) – The config

Raises:

RuntimeError – When the dataset has fewer or more samples than ‘size’ post-initialization
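To illustrate how weights can determine blend ratios, the sketch below assigns each of size samples to a dataset by greedily picking the dataset that is furthest behind its target share. This is one possible scheme written for illustration; build_blend_indices is a hypothetical function, and the sketch should not be read as a description of how BlendedDataset builds its indices internally.

```python
from typing import List, Tuple


def build_blend_indices(weights: List[float], size: int) -> Tuple[List[int], List[int]]:
    """Return, per blended sample, the dataset id and the index within it."""
    total = sum(weights)
    ratios = [w / total for w in weights]
    drawn = [0] * len(weights)            # samples taken from each dataset so far
    dataset_index: List[int] = []
    dataset_sample_index: List[int] = []
    for step in range(size):
        target = max(step, 1)
        # Error: how far each dataset lags behind its ideal sample count.
        errors = [ratios[i] * target - drawn[i] for i in range(len(ratios))]
        chosen = errors.index(max(errors))
        dataset_index.append(chosen)
        dataset_sample_index.append(drawn[chosen])
        drawn[chosen] += 1
    return dataset_index, dataset_sample_index


dataset_index, dataset_sample_index = build_blend_indices([30, 70], size=10)
```

With weights 30 and 70 and a blend of 10 samples, this scheme draws 3 samples from the first dataset and 7 from the second, interleaved rather than concatenated.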

datasets.utils module#

class core.datasets.utils.Split(value)#

Bases: Enum

An enumeration.

test = 2#
train = 0#
valid = 1#

core.datasets.utils.compile_helpers()#

Compile C++ helper functions at runtime. Make sure this is invoked on a single process.

core.datasets.utils.get_blend_from_list(
blend: List[str] | None,
) Tuple[List[str], List[float] | None] | None#

Get the blended_megatron_dataset_config.BlendedMegatronDatasetConfig blend from the blend list

Parameters:

blend (Optional[List[str]]) – The blend list, which can be either (1) a list of prefixes, e.g. ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], or (2) a flattened, zipped list of weights and prefixes, e.g. ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]

Returns:

The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights, e.g. [["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], [30.0, 70.0]].

Return type:

Optional[Tuple[List[str], Optional[List[float]]]]
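The two accepted input forms can be told apart by checking whether every other entry parses as a number. The following is a minimal sketch of that idea under the documented input and output shapes; parse_blend_sketch is a hypothetical re-implementation for illustration, and the library function may handle edge cases differently.

```python
from typing import List, Optional, Tuple


def parse_blend_sketch(
    blend: Optional[List[str]],
) -> Optional[Tuple[List[str], Optional[List[float]]]]:
    """Split a blend list into prefixes and optional weights (illustrative only)."""
    if blend is None:
        return None
    if len(blend) % 2 == 1:
        # An odd-length list cannot be weight/prefix pairs: prefixes only.
        return [prefix.strip() for prefix in blend], None
    raw_weights, prefixes = blend[0::2], blend[1::2]
    try:
        weights = [float(weight) for weight in raw_weights]
    except ValueError:
        # Entries in the weight positions are not numeric: prefixes only.
        return [prefix.strip() for prefix in blend], None
    return [prefix.strip() for prefix in prefixes], weights


weighted = parse_blend_sketch(
    ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
)
unweighted = parse_blend_sketch(["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"])
```

In this sketch, weighted carries the weights 30.0 and 70.0 alongside the two prefixes, while unweighted carries the same prefixes with None in place of the weights.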

core.datasets.utils.normalize(weights: List[float]) List[float]#

Do non-exponentiated normalization

Parameters:

weights (List[float]) – The weights

Returns:

The normalized weights

Return type:

List[float]
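"Non-exponentiated" is read here as dividing each weight by the sum of the weights, in contrast to a softmax, which exponentiates first. The sketch below shows that reading; normalize_sketch is a hypothetical stand-in, and this interpretation is an assumption based on the description above.

```python
import math
from typing import List


def normalize_sketch(weights: List[float]) -> List[float]:
    """Scale the weights so they sum to one, without exponentiating them."""
    total = sum(weights)
    return [weight / total for weight in weights]


normalized = normalize_sketch([30.0, 70.0])
```

Applied to the blend weights 30.0 and 70.0, this yields ratios of approximately 0.3 and 0.7.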

Module contents#