datasets package#
Data Pipeline#
Data pre-processing#
Data preprocessing is built around the following classes:
IndexedDatasetBuilder
IndexedDataset
At the moment, an end-to-end data preprocessing implementation is left to the user. See the class docstring(s) for more details.
IndexedDatasetBuilder#
The IndexedDatasetBuilder is capable of building and merging IndexedDataset instances.
IndexedDataset#
The IndexedDataset class is the lowest-level data interface in Megatron Core. Internally, an IndexedDataset instance references two binaries: the data file (.bin) contains document/sequence data and the index file (.idx) contains document/sequence metadata.
The index file stores dataset-level metadata first:
The index header, for backward compatibility
The index version, for backward compatibility
A numeric code corresponding to the data type used to write data to the data file
The number of sequences in the dataset
The number of documents in the dataset
The index file stores document-level and sequence-level metadata second:
In order, the number of elements per sequence
In order, the byte offset (pointer) per sequence
In order, the consecutive sequence index range [...) per document
In order, the mode per sequence (in the multimodal case)
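To make the two-part layout concrete, here is a minimal sketch that writes and reads such an index in memory. The magic bytes, the little-endian field widths, and the int32/int64 array dtypes are assumptions chosen for illustration, not a statement of Megatron Core's exact on-disk format; write_index and read_index are hypothetical names, and the multimodal mode array is omitted. The per-document [...) ranges are stored as a bookend array with one more entry than there are documents.

```python
import io
import struct
from typing import List, Tuple

import numpy as np

# Hypothetical header bytes and version, chosen for this sketch only.
MAGIC = b"MMIDIDX\x00\x00"
VERSION = 1


def write_index(buf: io.BytesIO, dtype_code: int, itemsize: int,
                seq_lengths: List[int], doc_bounds: List[int]) -> None:
    lengths = np.asarray(seq_lengths, dtype=np.int32)
    # Byte offset (pointer) of each sequence in the .bin file: exclusive prefix sum.
    pointers = np.zeros(len(lengths), dtype=np.int64)
    pointers[1:] = np.cumsum(lengths[:-1], dtype=np.int64) * itemsize
    bounds = np.asarray(doc_bounds, dtype=np.int64)
    # Dataset-level metadata first.
    buf.write(MAGIC)                                # header
    buf.write(struct.pack("<Q", VERSION))           # version
    buf.write(struct.pack("<B", dtype_code))        # data type code
    buf.write(struct.pack("<Q", len(lengths)))      # number of sequences
    buf.write(struct.pack("<Q", len(bounds) - 1))   # number of documents
    # Sequence-level and document-level metadata second.
    for array in (lengths, pointers, bounds):
        buf.write(array.tobytes(order="C"))


def read_index(data: bytes) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
    if not data.startswith(MAGIC):
        raise ValueError("bad index header")
    offset = len(MAGIC)
    (version,) = struct.unpack_from("<Q", data, offset)
    offset += 8
    (dtype_code,) = struct.unpack_from("<B", data, offset)
    offset += 1
    seq_count, doc_count = struct.unpack_from("<QQ", data, offset)
    offset += 16
    lengths = np.frombuffer(data, dtype=np.int32, count=seq_count, offset=offset)
    offset += lengths.nbytes
    pointers = np.frombuffer(data, dtype=np.int64, count=seq_count, offset=offset)
    offset += pointers.nbytes
    # Document d covers sequences bounds[d] .. bounds[d + 1] (half-open).
    bounds = np.frombuffer(data, dtype=np.int64, count=doc_count + 1, offset=offset)
    return version, dtype_code, lengths, pointers, bounds


# Two documents: sequences 0-1 form the first, sequence 2 the second (int32 data, code 4).
buf = io.BytesIO()
write_index(buf, dtype_code=4, itemsize=4, seq_lengths=[3, 2, 4], doc_bounds=[0, 2, 3])
version, dtype_code, lengths, pointers, bounds = read_index(buf.getvalue())
```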
Data loading: construction#
Building the data loaders is a distributed-aware process built around the following classes:
BlendedMegatronDatasetConfig
BlendedMegatronDatasetBuilder
IndexedDataset
MegatronDataset
BlendedDataset
See the class docstrings for more details.
BlendedMegatronDatasetConfig (extendable)#
The BlendedMegatronDatasetConfig class parameterizes the BlendedMegatronDatasetBuilder and, in turn, the MegatronDataset and BlendedDataset.
Different training/inference regimes will require different extensions, e.g. the GPTDatasetConfig.
BlendedMegatronDatasetBuilder#
The BlendedMegatronDatasetBuilder class builds the highest-level data interfaces in Megatron Core.
NB: All ranks should attempt to build the dataset via the BlendedMegatronDatasetBuilder or the program will hang. Which ranks follow through on their attempts can be controlled via the BlendedMegatronDatasetConfig.
IndexedDataset#
The IndexedDataset class is the lowest-level data interface in Megatron Core.
The IndexedDataset should already exist on disk before attempting to build any of the high-level data interfaces.
MegatronDataset (extendable)#
The MegatronDataset abstract class is a high-level data interface in Megatron Core. It is an abstraction built upon the IndexedDataset.
Different training/inference regimes will require different extensions, e.g. the GPTDataset.
BlendedDataset#
The BlendedDataset class is a high-level data interface in Megatron Core. It is an abstraction built upon the MegatronDataset.
The BlendedDataset is only necessary when a blend of multiple data distributions, i.e. multiple MegatronDataset instances, should contribute to a certain dataset split. The blend can be controlled via the BlendedMegatronDatasetConfig.
Data loading: implementation#
GPTDataset#
The GPTDataset is parameterized by the following variables: the underlying IndexedDataset instance indexed_dataset, the split indices indexed_indices (the contiguous subset of document or sequence indices used for training, validation, and testing), the number of samples N, the sequence length S, and the random seed R.
The GPTDataset creates three index mappings to facilitate lookup: (1) the document index, (2) the sample index, and (3) the shuffle index.
The document index Do_idx is a 1-D array mapping from i to document index of length E * |indexed_indices|, where E corresponds to the minimum number of epochs such that E * |indexed_indices| >= N. The document index is shuffled according to R.
Given:
    N = 15
    indexed_indices = [5, 6, 7, 8, 9]
    E = 3
Then, for example:
    Do_idx = [8, 8, 9, 6, 7, 5, 8, 5, 6, 6, 5, 9, 7, 7, 9]
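A minimal sketch of building the document index under the simplified definition of E given above: repeat the split's document indices E times, then shuffle the result with a NumPy RNG seeded by R. The helper name build_document_index is hypothetical, and the library's own implementation may count epochs differently (for instance in tokens rather than documents), so the shuffled order will not match the example.

```python
import math
from typing import Sequence

import numpy as np


def build_document_index(indexed_indices: Sequence[int], num_samples: int, seed: int) -> np.ndarray:
    # E: the minimum number of epochs such that E * |indexed_indices| >= N
    # (the simplified definition used in the text above).
    num_epochs = math.ceil(num_samples / len(indexed_indices))
    # Repeat the split's document indices once per epoch ...
    document_index = np.tile(np.asarray(indexed_indices, dtype=np.int64), num_epochs)
    # ... then shuffle the whole array in place, deterministically according to R.
    np.random.RandomState(seed).shuffle(document_index)
    return document_index


Do_idx = build_document_index([5, 6, 7, 8, 9], num_samples=15, seed=1234)
```

Because the RNG is seeded, every rank that rebuilds this array with the same R gets the same order.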
The sample index Sa_idx is a 2-D array mapping from j to pairs of (i, Do_idx[i] offset) of shape [N + 1, 2]. The rows j and j + 1 serve as the left and right bounds for the j-th sample.
Given:
    S = 1024
Then, for example:
    Sa_idx[0] = (0, 0)
    Sa_idx[1] = (0, 1024) => Do_idx[0] has length greater than S
    Sa_idx[2] = (1, 512) => Do_idx[0] has length 1536
    Sa_idx[3] = (2, 0) => Do_idx[1] has length 1536
    Sa_idx[4] = (5, 300) => Do_idx[2:5] are shorter documents relative to Do_idx[0:2]
    Sa_idx[5] = (6, 24) => Do_idx[5] has length 1300
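The sample index can be built by walking the document index and consuming S tokens per sample. The pure-Python sketch below reproduces the worked example when given document lengths consistent with it (Do_idx[2:5] summing to 724 tokens). build_sample_index is a hypothetical name; the library's own implementation may differ, for example by accounting for an extra token per sample (see add_extra_token_to_sequence).

```python
from typing import Sequence

import numpy as np


def build_sample_index(doc_lengths: Sequence[int], seq_length: int, num_samples: int) -> np.ndarray:
    """doc_lengths[i] is the token count of document Do_idx[i]."""
    sample_index = np.zeros((num_samples + 1, 2), dtype=np.int64)
    i, offset = 0, 0  # position in Do_idx and token offset within that document
    for j in range(1, num_samples + 1):
        remaining = seq_length
        while remaining > 0:
            if i >= len(doc_lengths):
                raise ValueError("document index too short for the requested samples")
            available = doc_lengths[i] - offset
            if available > remaining:
                # The sample ends inside the current document.
                offset += remaining
                remaining = 0
            else:
                # Consume the rest of the current document and move to the next one.
                remaining -= available
                i, offset = i + 1, 0
        sample_index[j] = (i, offset)  # right bound of sample j - 1, left bound of sample j
    return sample_index


Sa_idx = build_sample_index([1536, 1536, 200, 224, 300, 1300, 2000], seq_length=1024, num_samples=5)
```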
The shuffle index Sh_idx is a 1-D array mapping from k to j of length N. The shuffle index is shuffled according to R.
Given:
    N = 10
Then, for example:
    Sh_idx = [4, 0, 2, 6, 1, 9, 5, 8, 7, 3]
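In NumPy terms, the shuffle index is a seeded permutation of the sample-index rows. This is a sketch assuming a legacy RandomState seeded with R; build_shuffle_index is a hypothetical name and the resulting order will not match the example.

```python
import numpy as np


def build_shuffle_index(num_samples: int, seed: int) -> np.ndarray:
    # A random permutation of the sample-index rows 0 .. N-1, seeded by R.
    return np.random.RandomState(seed).permutation(num_samples)


Sh_idx = build_shuffle_index(10, seed=1234)
```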
To query the GPTDataset for the k-th sample, we do the following:
Use the shuffle index to get the index j into the sample index.
    j = Sh_idx[k]
Use the sample index to get the left and right sample-bounding indices into the document index and the starting token offset for each document.
    i, offset = Sa_idx[j]
    i_next, offset_next = Sa_idx[j + 1]
Use the document index to retrieve S tokens from consecutive (in the document index) documents. When both bounds fall in the same document, the sample is a single slice; otherwise it is the tail of the first document, every document strictly between the bounds, and the head of the last document.
    if i == i_next:
        sample = list(indexed_dataset[Do_idx[i]][offset:offset_next])
    else:
        sample = list(indexed_dataset[Do_idx[i]][offset:])
        for d in Do_idx[i + 1:i_next]:
            sample += list(indexed_dataset[d])
        sample += list(indexed_dataset[Do_idx[i_next]][:offset_next])
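The three lookup steps can be exercised end to end on toy data. In this sketch a Python list of NumPy arrays stands in for indexed_dataset, and Do_idx, Sa_idx, and Sh_idx are written by hand for S = 4 and N = 3 rather than built; get_sample is a hypothetical helper, not a Megatron Core function.

```python
from typing import List, Sequence

import numpy as np


def get_sample(k: int, indexed_dataset: Sequence[np.ndarray], Do_idx: np.ndarray,
               Sa_idx: np.ndarray, Sh_idx: np.ndarray) -> np.ndarray:
    j = Sh_idx[k]                         # shuffle index -> row of the sample index
    i, offset = Sa_idx[j]                 # left bound: (position in Do_idx, token offset)
    i_next, offset_next = Sa_idx[j + 1]   # right bound
    if i == i_next:
        # The whole sample lies inside a single document.
        return indexed_dataset[Do_idx[i]][offset:offset_next]
    pieces: List[np.ndarray] = [indexed_dataset[Do_idx[i]][offset:]]      # tail of first doc
    pieces += [indexed_dataset[d] for d in Do_idx[i + 1:i_next]]          # whole middle docs
    pieces.append(indexed_dataset[Do_idx[i_next]][:offset_next])          # head of last doc
    return np.concatenate(pieces)


# Toy documents whose tokens encode (document id * 100 + position).
docs = [np.arange(n) + 100 * d for d, n in enumerate([5, 3, 6, 1])]
Do_idx = np.array([2, 0, 3, 1])                       # shuffled document order
Sa_idx = np.array([[0, 0], [0, 4], [1, 2], [3, 0]])   # S = 4, N = 3
Sh_idx = np.array([2, 0, 1])
samples = [get_sample(k, docs, Do_idx, Sa_idx, Sh_idx) for k in range(3)]
```

Sample k = 0 maps to row j = 2, which spans the tail of document 0, all of document 3, and zero tokens of document 1.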
To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the MegatronDataset.__init__ function.
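One way to derive such a hash is to serialize the parameters that determine the indices deterministically and digest the result, so that every rank computes the same cache key. The sketch below assumes a JSON description hashed with MD5; the fields included, the digest, and the cache file name pattern are all hypothetical and not taken from Megatron Core.

```python
import hashlib
import json
from typing import Any, Dict


def cache_key(description: Dict[str, Any]) -> str:
    # Serialize with sorted keys so identical parameters always produce
    # identical text, and therefore the same hash, on every rank.
    text = json.dumps(description, sort_keys=True, indent=4)
    return hashlib.md5(text.encode("utf-8")).hexdigest()


desc = {"class": "GPTDataset", "dataset_path": "/data/corpus_text_document",
        "index_split": "train", "num_samples": 15, "sequence_length": 1024, "random_seed": 1234}
key = cache_key(desc)
# Hypothetical cache file name for the document index.
path = f"{key}-GPTDataset-train-document_index.npy"
```

Changing any parameter (for example R) changes the key, so stale indices are never reused for a different configuration.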
BlendedDataset#
The BlendedDataset is parameterized by the following variables: the underlying MegatronDataset instances D, the weights W (one per dataset), and the size S. The BlendedDataset will draw samples from contributing datasets in proportion to the weights until achieving a composite dataset of the desired size. During each sampling step, we draw a single sample from the dataset which has the greatest sampling error, i.e. the dataset whose sample count lags furthest behind its weighted share of the samples drawn so far.
The BlendedDataset creates two “blending” indices to facilitate lookup: (1) the dataset index and (2) the dataset sample index.
The dataset index Da_idx is a 1-D array mapping from i to dataset index of length S.
Given:
    D = [d0, d1, d2]
    W = [1/2, 1/4, 1/4]
    S = 4
Then, for example:
    Da_idx = [0, 1, 2, 0]
The dataset sample index Sa_idx is a 1-D array mapping from i to the sample index for dataset Da_idx[i] of length S.
Given:
    Da_idx = [0, 1, 2, 0]
Then, for example:
    Sa_idx = [0, 0, 0, 1]
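The greedy "greatest sampling error" rule can be sketched in a few lines and reproduces both examples above. The error formula used here (normalized weight times max(i, 1), minus the samples already drawn from that dataset) is an assumption chosen to match the described behavior, and build_blending_indices is a hypothetical name.

```python
from typing import Sequence, Tuple

import numpy as np


def build_blending_indices(weights: Sequence[float], size: int) -> Tuple[np.ndarray, np.ndarray]:
    w = np.asarray(weights, dtype=np.float64)
    w = w / w.sum()
    dataset_index = np.zeros(size, dtype=np.int16)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    drawn = np.zeros(len(w), dtype=np.int64)  # samples taken from each dataset so far
    for i in range(size):
        # Sampling error: how far each dataset lags behind its weighted target.
        error = w * max(i, 1) - drawn
        d = int(np.argmax(error))  # ties go to the lowest dataset index
        dataset_index[i] = d
        dataset_sample_index[i] = drawn[d]  # next unused sample of dataset d
        drawn[d] += 1
    return dataset_index, dataset_sample_index


Da_idx, Sa_idx = build_blending_indices([1 / 2, 1 / 4, 1 / 4], size=4)
# Lookup then reduces to: sample = D[Da_idx[k]][Sa_idx[k]]
```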
To query the BlendedDataset for the k-th sample, we do the following:
Use the dataset index to retrieve the corresponding dataset from D and the dataset sample index to retrieve the corresponding sample from that dataset.
    sample = D[Da_idx[k]][Sa_idx[k]]
To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the BlendedDataset.__init__ function.
Submodules#
datasets.blended_megatron_dataset_config module#
- class core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig(
- random_seed: int,
- sequence_length: int,
- blend: Tuple[List[str], List[float] | None] | None = None,
- blend_per_split: List[Tuple[List[str], List[float] | None] | None] | None = None,
- multiple_validation_sets: bool | None = None,
- full_validation: bool | None = None,
- split: str | None = None,
- num_dataset_builder_threads: int = 1,
- path_to_cache: str | None = None,
- mmap_bin_files: bool = True,
- tokenizer: megatron.core.tokenizers.MegatronTokenizerBase | None = None,
- mid_level_dataset_surplus: float = 0.005,
Bases:
object
Configuration object for Megatron Core datasets
- blend: Tuple[List[str], List[float] | None] | None = None#
The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights. For example, [[“dataset-path1”, “dataset-path2”], [0.3, 0.7]]. When the weights are None, they are inferred from the lengths of the contributing datasets. Not to be used with ‘blend_per_split’. Defaults to None.
- blend_per_split: List[Tuple[List[str], List[float] | None] | None] | None = None#
A set of blends, as defined above, one for each split distribution. Not to be used with ‘blend’. Defaults to None.
- full_validation: bool | None = None#
Whether to run a full epoch of validation each time validation occurs.
- mid_level_dataset_surplus: float = 0.005#
The sample surplus to build for the mid-level dataset(s). Defaults arbitrarily to 0.005. This value is irrelevant for single source data blends. This value may need to be increased if the top-level dataset oversamples the mid-level dataset(s). This value may be set to 0.0 in the future if the top-level dataset is constrained to not oversample the mid-level dataset(s).
- mmap_bin_files: bool = True#
Whether to mmap the .bin files or use file pointers.
- mock: bool = False#
Whether to bypass real data loading and validation in favor of mock data generation. Created automatically from ‘blend’ and ‘blend_per_split’. Not to be passed in to the constructor.
- multiple_validation_sets: bool | None = None#
Whether the validation split should be treated as multiple separate datasets.
- num_dataset_builder_threads: int = 1#
The number of threads to use for dataset building.
- path_to_cache: str | None = None#
Where all re-useable dataset indices are to be cached.
- random_seed: int#
The seed for all RNG during dataset creation.
- sequence_length: int#
The sequence length.
- split: str | None = None#
The split string, a comma-separated weighting for the dataset splits when drawing samples from a single distribution. Not to be used with ‘blend_per_split’. Defaults to None.
- split_matrix: List[Tuple[float, float]] | None = None#
The split matrix consisting of non-overlapping book-ends of each split in order. For more information, refer to ‘convert_split_vector_to_split_matrix’. Created automatically from ‘split’. Not to be passed in to the constructor.
- tokenizer: megatron.core.tokenizers.MegatronTokenizerBase | None = None#
The MegatronTokenizerBase instance. Required for datasets that do online tokenization.
- core.datasets.blended_megatron_dataset_config.convert_split_vector_to_split_matrix(
- vector_a: List[float],
- vector_b: List[float] | None = None,
Build the split matrix from one or optionally two contributing split vectors.
Ex. a standard conversion:
[0.99, 0.01, 0.0] -> [(0, 0.99), (0.99, 1.0), None]
Ex. a conversion for Retro when Retro pretraining uses a [0.99, 0.01, 0.0] split and Retro preprocessing used a [0.98, 0.02, 0.0] split:
[0.99, 0.01, 0.0], [0.98, 0.02, 0.0] -> [(0, 0.98), (0.99, 1.0), None]
- Parameters:
vector_a (List[float]) – The primary split vector
vector_b (Optional[List[float]]) – An optional secondary split vector which constrains the primary split vector. Defaults to None.
- Returns:
The split matrix consisting of book-ends of each split in order
- Return type:
List[Tuple[float, float]]
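The conversion can be read as: turn each split vector into cumulative book-ends, then keep the per-split overlap of the two sets of book-ends (or None when the overlap is empty). The sketch below reproduces both documented examples; the _sketch suffix marks it as an illustration rather than the library function.

```python
import itertools
from typing import List, Optional, Tuple


def convert_split_vector_to_split_matrix_sketch(
    vector_a: List[float], vector_b: Optional[List[float]] = None
) -> List[Optional[Tuple[float, float]]]:
    if vector_b is None:
        vector_b = vector_a
    # [0.99, 0.01, 0.0] -> [0.0, 0.99, 1.0, 1.0]
    expansion_a = [0.0, *itertools.accumulate(vector_a)]
    expansion_b = [0.0, *itertools.accumulate(vector_b)]
    # -> [(0.0, 0.99), (0.99, 1.0), (1.0, 1.0)]
    bookends_a = list(zip(expansion_a[:-1], expansion_a[1:]))
    bookends_b = list(zip(expansion_b[:-1], expansion_b[1:]))
    matrix: List[Optional[Tuple[float, float]]] = []
    for (a_lo, a_hi), (b_lo, b_hi) in zip(bookends_a, bookends_b):
        lo, hi = max(a_lo, b_lo), min(a_hi, b_hi)
        # Keep the overlap of the two book-ends, or None if it is empty.
        matrix.append((lo, hi) if hi > lo else None)
    return matrix


m1 = convert_split_vector_to_split_matrix_sketch([0.99, 0.01, 0.0])
m2 = convert_split_vector_to_split_matrix_sketch([0.99, 0.01, 0.0], [0.98, 0.02, 0.0])
```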
- core.datasets.blended_megatron_dataset_config.parse_and_normalize_split(split: str) → List[float] #
Parse the dataset split ratios from a string
- Parameters:
split (str) – The train/valid/test split string, e.g. “99,1,0”
- Returns:
The train/valid/test split ratios, e.g. [0.99, 0.01, 0.0]
- Return type:
List[float]
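Parsing and normalizing amounts to extracting the numbers, padding any missing trailing splits with zeros, and dividing by the total. In this sketch the fixed count of three splits (train/valid/test) and the zero-padding are assumptions, and the function name is hypothetical.

```python
import re
from typing import List


def parse_and_normalize_split_sketch(split: str, num_splits: int = 3) -> List[float]:
    # Pull out the numbers, e.g. "99,1,0" -> [99.0, 1.0, 0.0]
    ratios = [float(x) for x in re.findall(r"[.0-9]+", split)]
    # Pad missing trailing splits with zeros, e.g. "99,1" -> [99.0, 1.0, 0.0]
    ratios += [0.0] * (num_splits - len(ratios))
    total = sum(ratios)
    if len(ratios) != num_splits or total <= 0:
        raise ValueError(f"invalid split string: {split!r}")
    # Normalize so the ratios sum to 1
    return [r / total for r in ratios]


ratios = parse_and_normalize_split_sketch("99,1,0")
```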
datasets.blended_megatron_dataset_builder module#
- class core.datasets.blended_megatron_dataset_builder.BlendedMegatronDatasetBuilder(
- cls: Type[megatron.core.datasets.megatron_dataset.MegatronDataset],
- sizes: List[int],
- is_built_on_rank: Callable,
- config: megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig,
Bases:
object
Builder class for the BlendedDataset and MegatronDataset classes
- Parameters:
cls (Type[MegatronDataset]) – The class to instantiate, must inherit from MegatronDataset
sizes (List[Optional[int]]) – The minimum total number of samples to draw, or None, per split
is_built_on_rank (Callable) – A callable which returns True if the dataset should be built on the current rank and False otherwise. It should be Megatron Core parallelism aware i.e. global rank, local group rank, and virtual rank may inform its return value. Should return true for exactly one process on global rank 0.
config (BlendedMegatronDatasetConfig) – The config object which informs dataset creation
- build() → List[megatron.core.datasets.blended_dataset.BlendedDataset | megatron.core.datasets.megatron_dataset.MegatronDataset | None] #
Build all dataset splits according to the provided blend(s)
This method is distributed-aware and must be called on all ranks.
The dataset splits returned can vary according to the config. Supply config.blend and config.split to build BlendedDataset and/or MegatronDataset splits from the same distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset splits from separate distributions. In either case, for each split, handle the following cases:
The split is None - do nothing
The split has one contributing dataset, and…
‘size’ is not None - Build a mid-level dataset with low-level dataset sampling in proportion to the size
‘size’ is None - Build mid-level datasets with no excess low-level dataset sampling
The split has multiple contributing datasets, and…
‘weights’ is not None and ‘size’ is not None
    - Build mid-level datasets with low-level dataset sampling in proportion to their weights and the size
    - Build a top-level dataset of length marginally greater than ‘size’ with mid-level dataset sampling in proportion to their weights and the size
‘weights’ is not None and ‘size’ is None - Error
‘weights’ is None and ‘size’ is not None
    - Build mid-level datasets with no excess low-level dataset sampling
    - Build a top-level dataset of length ‘size’ (capped at the sum of the mid-level dataset lengths) with mid-level dataset sampling in proportion to their lengths and the size
‘weights’ is None and ‘size’ is None
    - Build mid-level datasets with no excess low-level dataset sampling
    - Build a top-level dataset with no excess mid-level dataset sampling
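For the weighted case, each mid-level dataset is built with its weighted share of ‘size’ plus a small surplus (mid_level_dataset_surplus) so that the top-level blend does not run it dry. The sketch below assumes the share is rounded up and then inflated by (1 + surplus); the exact rounding in the library is an assumption, and mid_level_sizes is a hypothetical name.

```python
import math
from typing import List, Sequence


def mid_level_sizes(weights: Sequence[float], size: int, surplus: float = 0.005) -> List[int]:
    total = sum(weights)
    # Weighted share of 'size' per contributing dataset, rounded up, then
    # inflated by the surplus (assumed rounding; not taken from the library).
    return [math.ceil(math.ceil(size * w / total) * (1 + surplus)) for w in weights]


sizes = mid_level_sizes([1, 3], size=1000)
```

With weights 1:3 and a requested size of 1000, the shares of 250 and 750 become 252 and 754 after the 0.5% surplus.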
- Returns:
- A list containing a dataset instance (or None) per split
- Return type:
List[Optional[TopLevelDataset]]
- static build_generic_dataset(
- cls: Type[megatron.core.datasets.blended_dataset.BlendedDataset | megatron.core.datasets.megatron_dataset.MegatronDataset | megatron.core.datasets.megatron_dataset.LowLevelDataset | torch.utils.data.Dataset] | Callable,
- is_built_on_rank: Callable,
- synchronize_ranks: bool,
- *args: Any,
Build the DistributedDataset
Return None if and only if the underlying dataset class is not built on the current rank and torch.distributed is initialized.
- Parameters:
cls (Union[Type[DistributedDataset], Callable]) – The DistributedDataset class to be built. In special cases, e.g. when we are building the low level dataset for a RawMegatronDataset instance, we can accept a Callable which returns an Iterable.
synchronize_ranks (bool) – Whether to call barrier for rank-0 / barrier / other-ranks behavior. Set to False when we enforce this behavior at higher level.
args (Tuple[Any]) – The positional arguments used to build the provided DistributedDataset class
- Raises:
Exception – When the dataset constructor raises an OSError
- Returns:
- The DistributedDataset instantiation, the Iterable instantiation, or None
- Return type:
Optional[Union[DistributedDataset, Iterable]]
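The rank-0 / barrier / other-ranks behavior can be simulated in a single process, with threads standing in for ranks and threading.Barrier standing in for torch.distributed.barrier(). Rank 0 builds and caches first; after the barrier the other ranks "build" again, which now only loads the cached result. All names here are hypothetical; the sketch also shows why every rank must reach the barrier (a rank that skips it would leave the others waiting).

```python
import threading
from typing import Callable, Dict, List, Optional


def build_on_all_ranks(world_size: int, build: Callable[[int], str],
                       is_built_on_rank: Callable[[int], bool]) -> List[Optional[str]]:
    barrier = threading.Barrier(world_size, timeout=10)  # stand-in for torch.distributed.barrier()
    results: List[Optional[str]] = [None] * world_size

    def rank_main(rank: int) -> None:
        # First, build (and cache) on rank 0 only.
        if rank == 0 and is_built_on_rank(rank):
            results[rank] = build(rank)
        # Every rank must reach the barrier, or the others wait.
        barrier.wait()
        # Then the remaining ranks build, which now means loading from the cache.
        if rank != 0 and is_built_on_rank(rank):
            results[rank] = build(rank)

    threads = [threading.Thread(target=rank_main, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


cache: Dict[str, str] = {}
log: List[str] = []
lock = threading.Lock()


def cached_build(rank: int) -> str:
    with lock:
        if "document_index" not in cache:
            cache["document_index"] = f"built by rank {rank}"
            log.append(f"rank {rank}: built")
        else:
            log.append(f"rank {rank}: loaded")
        return cache["document_index"]


results = build_on_all_ranks(4, cached_build, is_built_on_rank=lambda rank: True)
```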
datasets.megatron_tokenizer module#
- class core.datasets.megatron_tokenizer.MegatronLegacyTokenizer(
- *tokenizer_paths: str,
- **tokenizer_options: Any,
Bases:
ABC
Abstract class for tokenizer
Absent a config or class-specific tracking of which objects are uniquely identifying, we must include all keyword arguments as unique identifiers.
- Parameters:
tokenizer_paths (Tuple[str]) – All tokenizer source paths or prefixes
tokenizer_options (Dict[str, Any]) – All tokenizer options
- property bos#
The BOS token id
- Raises:
NotImplementedError – Non-abstract, optional attribute
- property cls#
The CLS token id
- Raises:
NotImplementedError – Non-abstract, optional attribute
- detokenize(ids: numpy.ndarray) → str #
Convert embedding ids to text
- Parameters:
ids (numpy.ndarray) – The ids to convert
- Returns:
The converted text
- Return type:
str
- Raises:
NotImplementedError – Non-abstract, optional method
- property eod#
The EOD token id
- Raises:
NotImplementedError – Non-abstract, optional attribute
- property eos#
The EOS token id
- Raises:
NotImplementedError – Non-abstract, optional attribute
- abstract property inv_vocab#
Dictionary from vocab id token to text token
- property mask#
The MASK token id
- Raises:
NotImplementedError – Non-abstract, optional attribute
- offsets(
- ids: list[int],
- text: str,
Convert embedding ids to text offsets
- Parameters:
ids (list[int]) – The ids to convert
text (str) – The text to convert
- Returns:
The converted offsets
- Return type:
list[int]
- Raises:
NotImplementedError – Non-abstract, optional method
- property pad#
The PAD token id
- Raises:
NotImplementedError – Non-abstract, optional attribute
- property sep#
The SEP token id
- Raises:
NotImplementedError – Non-abstract, optional attribute
- abstract tokenize(text: str) → numpy.ndarray #
Convert text to embedding ids
- Parameters:
text (str) – The text to convert
- Returns:
The converted embedding ids
- Return type:
numpy.ndarray
- abstract property vocab#
Dictionary from vocab text token to id token
- abstract property vocab_size#
The vocabulary size
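To illustrate the contract, here is a toy whitespace tokenizer that provides the members marked abstract above (tokenize, vocab, inv_vocab, vocab_size) plus the optional detokenize and eod. It is a standalone, hypothetical class so the example runs without Megatron Core; in practice such a class would subclass MegatronLegacyTokenizer.

```python
from typing import Dict, List

import numpy as np


class WhitespaceTokenizer:
    """Toy tokenizer; hypothetical, mirrors the documented members only."""

    def __init__(self, corpus: List[str]) -> None:
        words = sorted({w for line in corpus for w in line.split()})
        self._vocab: Dict[str, int] = {"<eod>": 0}  # reserve id 0 for end of document
        for w in words:
            self._vocab[w] = len(self._vocab)
        self._inv_vocab: Dict[int, str] = {i: t for t, i in self._vocab.items()}

    @property
    def vocab(self) -> Dict[str, int]:      # text token -> id
        return self._vocab

    @property
    def inv_vocab(self) -> Dict[int, str]:  # id -> text token
        return self._inv_vocab

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    @property
    def eod(self) -> int:
        return self._vocab["<eod>"]

    def tokenize(self, text: str) -> np.ndarray:
        return np.array([self._vocab[w] for w in text.split()], dtype=np.int32)

    def detokenize(self, ids: np.ndarray) -> str:
        return " ".join(self._inv_vocab[int(i)] for i in ids)


tok = WhitespaceTokenizer(["the cat sat", "the dog"])
```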
datasets.indexed_dataset module#
- class core.datasets.indexed_dataset.DType(value)#
Bases:
Enum
The NumPy data type Enum for writing/reading the IndexedDataset indices
- classmethod code_from_dtype(value: Type[numpy.number]) → int #
Get the code from the dtype
- Parameters:
value (Type[numpy.number]) – The dtype
- Returns:
The code
- Return type:
int
- classmethod dtype_from_code(value: int) → Type[numpy.number] #
Get the dtype from the code
- Parameters:
value (int) – The code
- Returns:
The dtype
- Return type:
Type[numpy.number]
- float32 = 7#
- float64 = 6#
- int16 = 3#
- int32 = 4#
- int64 = 5#
- int8 = 2#
- static optimal_dtype(
- cardinality: int | None,
Get the dtype to use for an index of a certain cardinality
- Parameters:
cardinality (Optional[int]) – The number of elements to be indexed
- Returns:
The dtype to use for the index
- Return type:
Type[numpy.number]
- static size(key: int | Type[numpy.number]) → int #
Get the size of the dtype/code in bytes
- Parameters:
key (Union[int, Type[numpy.number]]) – The dtype or code
- Raises:
ValueError – If the key is neither dtype nor integer code
- Returns:
The size of the dtype/code in bytes
- Return type:
int
- uint16 = 8#
- uint8 = 1#
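The code table above can be mirrored with a small Enum whose member names match NumPy type names, which makes both lookups one-liners. This is a hypothetical sketch, not the library class; in particular the optimal_dtype threshold (uint16 below 65500 elements, else int32) is an assumption.

```python
from enum import Enum
from typing import Optional, Type, Union

import numpy as np


class DTypeSketch(Enum):
    """Hypothetical mirror of the documented dtype codes."""
    uint8 = 1
    int8 = 2
    int16 = 3
    int32 = 4
    int64 = 5
    float64 = 6
    float32 = 7
    uint16 = 8

    @classmethod
    def code_from_dtype(cls, value: Type[np.number]) -> int:
        return cls[value.__name__].value       # e.g. np.int32 -> "int32" -> 4

    @classmethod
    def dtype_from_code(cls, value: int) -> Type[np.number]:
        return getattr(np, cls(value).name)    # e.g. 8 -> "uint16" -> np.uint16

    @staticmethod
    def size(key: Union[int, Type[np.number]]) -> int:
        if isinstance(key, int):
            return np.dtype(DTypeSketch.dtype_from_code(key)).itemsize
        if isinstance(key, type) and issubclass(key, np.number):
            return np.dtype(key).itemsize
        raise ValueError(f"{key!r} is neither a dtype nor an integer code")

    @staticmethod
    def optimal_dtype(cardinality: Optional[int]) -> Type[np.number]:
        # Assumed threshold: uint16 when every index fits comfortably below 2**16.
        if cardinality is not None and cardinality < 65500:
            return np.uint16
        return np.int32
```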
- class core.datasets.indexed_dataset.IndexedDataset(*args: Any, **kwargs: Any)#
Bases:
Dataset
The low-level interface dataset class
- Parameters:
path_prefix (str) – The index (.idx) and data (.bin) prefix
multimodal (bool) – Whether the dataset is multimodal. Defaults to False.
mmap (bool) – Whether to mmap the .bin files. Defaults to True.
object_storage_config (Optional[ObjectStorageConfig]) – Supplied only for data stored on S3 or MSC. IndexedDataset downloads the index (.idx) file to object_storage_config.path_to_idx_cache and streams data from the data (.bin) file in object_storage_config.bin_chunk_nbytes blocks. Note that mmap must be disabled for S3 data loading. Defaults to None.
- property document_indices: numpy.ndarray#
Get the document indices
- Returns:
The document indices
- Return type:
numpy.ndarray
- static exists(path_prefix: str) → bool #
Return whether the IndexedDataset exists on disk at the prefix
- Parameters:
path_prefix (str) – The prefix to the index (.idx) and data (.bin) files
- Returns:
Whether the IndexedDataset exists on disk at the prefix
- Return type:
bool
- get(
- idx: int,
- offset: int = 0,
- length: int | None = None,
Retrieve a single item from the dataset with the option to only return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
- Parameters:
idx (Union[int, numpy.integer]) – The index into the dataset
offset (int) – The integer token offset in the sequence
length (int) – The number of tokens to grab from the sequence
- Returns:
- The sequence tokens at the index, plus the sequence mode in the multimodal case
- Return type:
Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.number]]
- get_document_indices() → numpy.ndarray #
Get the document indices
This method is slated for deprecation.
- Returns:
The document indices
- Return type:
numpy.ndarray
- initialize(
- path_prefix: str,
- multimodal: bool,
- mmap: bool,
- object_storage_config: megatron.core.datasets.object_storage_utils.ObjectStorageConfig | None,
Initialize the dataset
This method is called by IndexedDataset.__init__ during object creation and by IndexedDataset.__setstate__ during un-pickling
- Parameters:
path_prefix (str) – The index (.idx) and data (.bin) prefix
multimodal (bool) – Whether the dataset is multimodal
mmap (bool) – Whether to mmap the .bin file
object_storage_config (Optional[ObjectStorageConfig]) – See IndexedDataset docstring for details.
- property sequence_lengths: numpy.ndarray#
Get the sequence lengths
- Returns:
The sequence lengths
- Return type:
numpy.ndarray
- property sequence_modes: numpy.ndarray#
Get the sequence modes
- Returns:
The sequence modes
- Return type:
numpy.ndarray
- set_document_indices(
- document_indices: numpy.ndarray,
Set the document indices
This method is slated for deprecation.
- Parameters:
document_indices (numpy.ndarray) – The document indices
- class core.datasets.indexed_dataset.IndexedDatasetBuilder(
- bin_path: str,
- dtype: Type[numpy.number] = numpy.int32,
- multimodal: bool = False,
Bases:
object
Builder class for the IndexedDataset class
- Parameters:
bin_path (str) – The path to the data (.bin) file
dtype (Type[numpy.number], optional) – The dtype of the data written to the data (.bin) file; its numeric code is recorded in the index (.idx) file. Defaults to numpy.int32.
multimodal (bool, optional) – Whether the dataset is multimodal. Defaults to False.
- add_document(
- tensor: torch.Tensor,
- lengths: List[int],
- modes: List[int] | None = None,
Add an entire document to the dataset
- Parameters:
tensor (torch.Tensor) – The document to add
lengths (List[int]) – The lengths of each item in the document
modes (Optional[List[int]], optional) – The modes for each item in the document. Defaults to None.
- add_index(path_prefix: str) → None #
Add an entire IndexedDataset to the dataset
- Parameters:
path_prefix (str) – The index (.idx) and data (.bin) prefix
- add_item(
- tensor: torch.Tensor,
- mode: int = 0,
Add a single item to the dataset
- Parameters:
tensor (torch.Tensor) – The item to add to the data file
mode (int, optional) – The mode for the item. Defaults to 0.
- end_document() → None #
Finalize the document, for use with IndexedDatasetBuilder.add_item
- finalize(idx_path: str) → None #
Clean up and write the index (.idx) file
- Parameters:
idx_path (str) – The path to the index file
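The builder/reader pair can be pictured with an in-memory stand-in: a flat byte buffer plays the data (.bin) file, and Python lists play the index (sequence lengths, byte pointers, document bookends). InMemoryIndexedDataset is hypothetical; it mirrors the documented add_item, end_document, add_document, and get semantics but does not write files, and finalize is omitted.

```python
from typing import List, Optional

import numpy as np


class InMemoryIndexedDataset:
    """Hypothetical stand-in: one flat token buffer plus an index."""

    def __init__(self, dtype=np.int32) -> None:
        self.dtype = np.dtype(dtype)
        self.data = bytearray()                  # plays the role of the .bin file
        self.sequence_lengths: List[int] = []
        self.sequence_pointers: List[int] = []   # byte offsets into self.data
        self.document_indices: List[int] = [0]   # half-open sequence ranges per document

    # Builder side (cf. add_item / end_document / add_document)
    def add_item(self, tokens) -> None:
        array = np.asarray(tokens, dtype=self.dtype)
        self.sequence_pointers.append(len(self.data))
        self.sequence_lengths.append(int(array.size))
        self.data += array.tobytes(order="C")

    def end_document(self) -> None:
        self.document_indices.append(len(self.sequence_lengths))

    def add_document(self, tokens, lengths: List[int]) -> None:
        array = np.asarray(tokens, dtype=self.dtype)
        start = 0
        for length in lengths:                   # one item per entry in lengths
            self.add_item(array[start:start + length])
            start += length
        self.end_document()

    # Reader side (cf. IndexedDataset.get)
    def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> np.ndarray:
        if length is None:
            length = self.sequence_lengths[idx] - offset
        start = self.sequence_pointers[idx] + offset * self.dtype.itemsize
        nbytes = length * self.dtype.itemsize
        # Copy the slice so later appends to self.data stay legal.
        return np.frombuffer(bytes(self.data[start:start + nbytes]), dtype=self.dtype)


ds = InMemoryIndexedDataset(np.int32)
ds.add_item([1, 2, 3])
ds.add_item([4, 5])
ds.end_document()
ds.add_document([6, 7, 8, 9], lengths=[1, 3])
```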
- core.datasets.indexed_dataset.get_bin_path(path_prefix: str) → str #
Get the path to the data file from the prefix
- Parameters:
path_prefix (str) – The prefix
- Returns:
The path to the data file
- Return type:
str
- core.datasets.indexed_dataset.get_idx_path(path_prefix: str) → str #
Get the path to the index file from the prefix
- Parameters:
path_prefix (str) – The prefix
- Returns:
The path to the index file
- Return type:
str
datasets.megatron_dataset module#
- class core.datasets.megatron_dataset.MegatronDataset(*args: Any, **kwargs: Any)#
Bases:
ABC, Dataset
The highest level wrapper class from which all dataset classes should inherit
- Parameters:
dataset (LowLevelDataset) – The dataset around which to build the MegatronDataset
dataset_path (Optional[str]) – The real path on disk to the dataset, for bookkeeping
indices (numpy.ndarray) – The set of the documents indices to expose
num_samples (Optional[int]) – The minimum number of samples to build from the indexed dataset. When None, build as many samples as correspond to one epoch.
index_split (Split) – The indices Split
config (BlendedMegatronDatasetConfig) – The config
- static build_low_level_dataset(
- dataset_path: str,
- config: megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig,
Build the low level dataset via a function to be called from within BlendedMegatronDatasetBuilder.build_generic_dataset
It may be that the low level dataset spans any subset of train/valid/test splits, which is why we define a static “build” function separately from the constructor in the mid level dataset class
- Parameters:
dataset_path (str) – The real path on disk to the dataset
config (BlendedMegatronDatasetConfig) – The dataset config
- Returns:
The low level dataset
- Return type:
LowLevelDataset
- static numel_low_level_dataset(
- low_level_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset | Iterable,
Return the number of elements in the underlying low level dataset for the purpose of segregating the train/valid/test split indices
It may be that the low level dataset can be split any number of ways, depending on the mid level dataset it supports, which is why we define the “number of elements” function separately from the __len__ function here in the mid level dataset class
- Parameters:
low_level_dataset (LowLevelDataset) – The underlying low level dataset
- Returns:
The number of elements in the underlying low level dataset
- Return type:
int
datasets.gpt_dataset module#
- class core.datasets.gpt_dataset.GPTDataset(*args: Any, **kwargs: Any)#
Bases:
MegatronDataset
The base GPT dataset
- Parameters:
indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the GPTDataset
dataset_path (Optional[str]) – The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray) – The set of the documents indices to expose
num_samples (Optional[int]) – The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
index_split (Split) – The indexed_indices Split
config (GPTDatasetConfig) – The config
- static build_low_level_dataset(
- dataset_path: str,
- config: GPTDatasetConfig,
Abstract method implementation
- Parameters:
dataset_path (str) – The real path prefix to the IndexedDataset .bin and .idx files
config (GPTDatasetConfig) – The config
- Returns:
The underlying IndexedDataset
- Return type:
IndexedDataset
- static numel_low_level_dataset(
- low_level_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset,
Abstract method implementation
For GPT, the underlying IndexedDataset should be split by sequence, as opposed to, say, BERT, which should be split by document
- Parameters:
low_level_dataset (IndexedDataset) – The underlying IndexedDataset
- Returns:
The number of unique elements in the underlying IndexedDataset
- Return type:
int
- class core.datasets.gpt_dataset.GPTDatasetConfig(*args: Any, **kwargs: Any)#
Bases:
BlendedMegatronDatasetConfig
Configuration object for Megatron Core GPT datasets
- add_extra_token_to_sequence: bool = True#
Option to draw sequences with one extra token to ensure the sample input tokens and sample output tokens are both of the desired sequence length
- create_attention_mask: bool = True#
Option to enable attention mask generation. Can be disabled if the attention kernel generates masks by itself.
- drop_last_partial_validation_sequence: bool = True#
Option to drop the last partial validation sequence
- eod_mask_loss: bool | None = None#
Option to enable the EOD mask loss
- object_storage_cache_path: str | None = None#
Path for caching indices for S3 or MSC data loading.
- reset_attention_mask: bool | None = None#
Option to reset the attention mask from the dataset
- reset_position_ids: bool | None = None#
Option to reset the position IDs in the dataset at an interval
- class core.datasets.gpt_dataset.MockGPTDataset(*args: Any, **kwargs: Any)#
Bases:
GPTDataset
The mock GPT dataset
- Parameters:
indexed_dataset (MockGPTLowLevelDataset) – The MockGPTLowLevelDataset around which to build the MockGPTDataset
dataset_path (Optional[str]) – This argument is of no consequence for the MockGPTDataset
indices (numpy.ndarray) – The set of the dataset indices to expose
num_samples (int) – The number of samples to draw from the dataset
index_split (Split) – The indices Split
config (GPTDatasetConfig) – The config
- static build_low_level_dataset(
- dataset_path: str | None,
- config: GPTDatasetConfig,
Abstract method implementation
- Parameters:
dataset_path (Optional[str]) – This argument is of no consequence for the MockGPTLowLevelDataset
config (GPTDatasetConfig) – The config
- Returns:
The underlying MockGPTLowLevelDataset
- Return type:
MockGPTLowLevelDataset
- static numel_low_level_dataset(
- low_level_dataset: MockGPTLowLevelDataset,
Abstract method implementation
- Parameters:
low_level_dataset (MockGPTLowLevelDataset) – The underlying MockGPTLowLevelDataset
- Returns:
The number of unique elements in the underlying MockGPTLowLevelDataset
- Return type:
int
- class core.datasets.gpt_dataset.MockGPTLowLevelDataset(
- tokenizer: megatron.core.tokenizers.MegatronTokenizerBase,
Bases:
object
The mock GPT low level dataset
This class is meant to generate tokenized data in the classic “Megatron-LM” GPT style. Notably, we add the end of document token to each element indexed in __getitem__
- Parameters:
tokenizer (MegatronTokenizerBase) – The tokenizer whose special token information we use to augment the mock data.
- get(
- idx: int,
- offset: int = 0,
- length: int | None = None,
This function is an abstraction over __getitem__ with support for slicing
- Parameters:
idx (int) – The index into the dataset
offset (int) – The integer token offset in the sequence
length (Optional[int]) – The number of tokens to grab from the sequence
- Returns:
The sequence tokens at the index
- Return type:
numpy.ndarray
- max_sequence_length: int = 4096#
The hard-coded max sequence length to generate
- seed: int = 0#
The hard-coded random seed to use to set the NumPy RNG
- size: int = 100000#
The hard-coded number of samples to generate
datasets.masked_dataset module#
- class core.datasets.masked_dataset.MaskedWordPieceDataset(*args: Any, **kwargs: Any)#
Bases:
MegatronDataset
The semi-abstract base class for masked WordPiece datasets
This implementation makes the rigid assumption that all inheritor datasets are built upon the IndexedDataset class. This assumption may be pushed down to the inheritors in future if necessary.
NB: WordPiece tokenization prepends a double hash “##” to all tokens/pieces in a word, save the first token/piece.
- Parameters:
indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset
dataset_path (str) – The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray) – The set of the documents indices to expose
num_samples (Optional[int]) – The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
index_split (Split) – The indexed_indices Split
config (MaskedWordPieceDatasetConfig) – The config
- static build_low_level_dataset(
- dataset_path: str,
- config: MaskedWordPieceDatasetConfig,
- static numel_low_level_dataset(
- low_level_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset,
- class core.datasets.masked_dataset.MaskedWordPieceDatasetConfig(*args: Any, **kwargs: Any)#
Bases:
BlendedMegatronDatasetConfig
Configuration object for Megatron Core Masked WordPiece datasets
- masking_do_full_word: bool = None#
Whether we mask the whole word or its component parts
- masking_do_permutation: bool = None#
Whether we shuffle a subset of candidate N-grams in addition
- masking_max_ngram: int = None#
The maximum length N-gram to consider masking or permuting
- masking_probability: float = None#
The probability we mask a candidate N-gram
- masking_use_geometric_distribution: bool = None#
Whether to draw the size of the N-gram from a geometric distribution according to SpanBERT https://arxiv.org/abs/1907.10529 (Section 3.1)
- masking_use_longer_ngrams: bool = None#
Whether to favor longer N-grams over shorter N-grams
- short_sequence_probability: float = None#
The probability we return a sequence shorter than the target sequence length
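The N-gram size options above can be illustrated with a small NumPy sketch. In the geometric case it draws N from a geometric distribution with p = 0.2 (the value used in the SpanBERT paper) and clips it at masking_max_ngram. The non-geometric case is an assumption for illustration: it weights each N in 1..masking_max_ngram by 1/N, so shorter N-grams are favored, and reverses those weights when longer N-grams should be favored. The function sample_ngram_size is hypothetical and is not part of the Megatron Core API.

```python
import numpy as np


def sample_ngram_size(
    rng: np.random.Generator,
    max_ngram: int,
    use_geometric: bool,
    use_longer_ngrams: bool,
    p: float = 0.2,
) -> int:
    """Draw an N-gram size in [1, max_ngram] (hypothetical helper)."""
    if use_geometric:
        # Geometric draw has support 1, 2, 3, ...; clip it at max_ngram.
        return int(min(rng.geometric(p), max_ngram))
    n_values = np.arange(1, max_ngram + 1)
    # Assumed weighting: probability inversely proportional to N.
    probs = 1.0 / n_values
    if use_longer_ngrams:
        # Reverse the weights so that longer N-grams become more likely.
        probs = probs[::-1]
    probs = probs / probs.sum()
    return int(rng.choice(n_values, p=probs))


rng = np.random.default_rng(0)
print([sample_ngram_size(rng, 3, use_geometric=True, use_longer_ngrams=False) for _ in range(5)])
```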
datasets.bert_dataset module#
- class core.datasets.bert_dataset.BERTMaskedWordPieceDataset(*args: Any, **kwargs: Any)#
Bases:
MaskedWordPieceDataset
The BERT dataset that assumes WordPiece tokenization
- Parameters:
indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset
dataset_path (str) – The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray) – The set of document indices to expose
num_samples (Optional[int]) – The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
index_split (Split) – The indexed_indices Split
config (BERTMaskedWordPieceDatasetConfig) – The config
datasets.t5_dataset module#
- class core.datasets.t5_dataset.T5MaskedWordPieceDataset(*args: Any, **kwargs: Any)#
Bases:
MaskedWordPieceDataset
The T5 dataset that assumes WordPiece tokenization
- Parameters:
indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset
dataset_path (str) – The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray) – The set of document indices to expose
num_samples (Optional[int]) – The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
index_split (Split) – The indexed_indices Split
config (T5MaskedWordPieceDatasetConfig) – The config
- static config_attention_mask(encoder_tokens: torch.tensor, decoder_tokens: torch.tensor, encoder_mask: torch.tensor, decoder_mask: torch.tensor, use_local: bool = False, test_te_version: str | None = None)
Configure the encoder, decoder, and encoder-decoder attention masks according to the transformer implementation (TE vs. local), the TE version, and the TE backend
- Parameters:
encoder_tokens (torch.tensor) – A 2-D array of tokens (bs, kv_len)
decoder_tokens (torch.tensor) – A 2-D array of tokens (bs, q_len)
encoder_mask (torch.tensor) – A 2-D mask over the encoder tokens (bs, kv_len)
decoder_mask (torch.tensor) – A 2-D mask over the decoder tokens (bs, q_len)
use_local (bool) – Whether the current T5 model uses the local (vs. TE) transformer implementation
- Returns:
The configured masks, in order:
torch.tensor: configured encoder attention mask
torch.tensor: configured decoder attention mask
torch.tensor: configured encoder-decoder attention mask
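The three returned masks are related to the 2-D inputs by broadcasting. The NumPy sketch below shows the generic construction and is not the Megatron implementation. It assumes padding masks where 1 marks a real token and 0 marks padding, and it uses the convention True = "may attend". The actual function may use the opposite convention (True = masked out), add a head dimension, or skip some masks depending on the TE version and backend. The name build_attention_masks is hypothetical. The same indexing works on torch tensors.

```python
import numpy as np


def build_attention_masks(enc_pad: np.ndarray, dec_pad: np.ndarray):
    """Build 3-D attention masks from 2-D padding masks (hypothetical helper).

    enc_pad: (bs, kv_len), dec_pad: (bs, q_len); 1 = real token, 0 = padding.
    Returns boolean masks in which True means "this query may attend to this key".
    """
    enc = enc_pad.astype(bool)
    dec = dec_pad.astype(bool)
    # Encoder self-attention: (bs, kv_len, kv_len), both positions must be real.
    encoder_mask = enc[:, :, None] & enc[:, None, :]
    # Decoder self-attention: padding constraint plus causal (lower-triangular) constraint.
    q_len = dec.shape[1]
    causal = np.tril(np.ones((q_len, q_len), dtype=bool))
    decoder_mask = dec[:, :, None] & dec[:, None, :] & causal[None, :, :]
    # Cross-attention: decoder queries over encoder keys, (bs, q_len, kv_len).
    encoder_decoder_mask = dec[:, :, None] & enc[:, None, :]
    return encoder_mask, decoder_mask, encoder_decoder_mask


enc_mask, dec_mask, enc_dec_mask = build_attention_masks(
    np.array([[1, 1, 0]]), np.array([[1, 1]])
)
print(enc_mask.shape, dec_mask.shape, enc_dec_mask.shape)
# (1, 3, 3) (1, 2, 2) (1, 2, 3)
```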
- class core.datasets.t5_dataset.T5MaskedWordPieceDatasetConfig(*args: Any, **kwargs: Any)#
Bases:
MaskedWordPieceDatasetConfig
Configuration object for Megatron Core T5 WordPiece datasets
NB: This is a temporary holdover from Megatron-LM. The T5 tokenizer has an attribute that defines a number of special sentinel tokens used during sampling. The assert in __post_init__ preserves compatibility with Megatron-LM until the T5 tokenizer is in Megatron Core.
- sequence_length_decoder: int = None#
The sequence length for the decoder
- sequence_length_encoder: int | None = None#
The sequence length for the encoder; an alias for sequence_length
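To see why the tokenizer must supply a pool of sentinel tokens, consider the span-corruption objective introduced with T5. Each masked span in the encoder input is replaced by a distinct sentinel, and the decoder target lists each sentinel followed by the tokens it replaced, so every masked span consumes one sentinel. The sketch below is a minimal illustration on token strings. The sentinel strings "<S0>" and "<S1>" and the function span_corrupt are hypothetical placeholders. A real dataset may also add BOS/EOS tokens and padding, which are not modeled here.

```python
from typing import List, Sequence, Tuple


def span_corrupt(
    tokens: Sequence[str],
    spans: Sequence[Tuple[int, int]],
    sentinels: Sequence[str],
) -> Tuple[List[str], List[str]]:
    """Apply T5-style span corruption (hypothetical helper).

    spans are sorted, non-overlapping [start, end) index pairs.
    """
    if len(spans) > len(sentinels):
        raise ValueError("need one sentinel token per masked span")
    encoder: List[str] = []
    decoder: List[str] = []
    cursor = 0
    for (start, end), sentinel in zip(spans, sentinels):
        encoder.extend(tokens[cursor:start])  # keep the unmasked tokens
        encoder.append(sentinel)              # one sentinel stands in for the span
        decoder.append(sentinel)              # target: the sentinel, then the span
        decoder.extend(tokens[start:end])
        cursor = end
    encoder.extend(tokens[cursor:])
    return encoder, decoder


enc, dec = span_corrupt("a b c d e f".split(), [(1, 3), (4, 5)], ["<S0>", "<S1>"])
print(enc)  # ['a', '<S0>', 'd', '<S1>', 'f']
print(dec)  # ['<S0>', 'b', 'c', '<S1>', 'e']
```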
datasets.blended_dataset module#
- class core.datasets.blended_dataset.BlendedDataset(*args: Any, **kwargs: Any)#
Bases:
Dataset
Conjugating class for a set of MegatronDataset instances
- Parameters:
datasets (List[MegatronDataset]) – The MegatronDataset instances to blend
weights (List[Union[int, float]]) – The weights that determine the dataset blend ratios
size (Optional[int]) – The number of samples to draw from the blend. If None, for each dataset index idx draw exactly weights[idx] samples from datasets[idx].
config (BlendedMegatronDatasetConfig) – The config
- Raises:
RuntimeError – When the blended dataset has more or fewer samples than ‘size’ after initialization
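One simple way to realize the blend ratios for a fixed size is a greedy scheme: at each step, draw the next sample from the dataset whose achieved sample count lags furthest behind its weighted target. The pure-Python sketch below illustrates that idea. It is not a description of how BlendedDataset computes its indices, and build_blend_indices is a hypothetical name. It returns, for each blended sample, which dataset to read and which sample within that dataset.

```python
from typing import List, Tuple


def build_blend_indices(weights: List[float], size: int) -> Tuple[List[int], List[int]]:
    """Greedily assign `size` blended samples to datasets (hypothetical helper).

    Sample i of the blend is sample dataset_sample_index[i] of dataset
    dataset_index[i].
    """
    total = sum(weights)
    weights = [w / total for w in weights]  # normalize so the targets sum to the step count
    counts = [0] * len(weights)
    dataset_index: List[int] = []
    dataset_sample_index: List[int] = []
    for i in range(size):
        step = max(i, 1)
        # Error: how far each dataset falls below its target share at this step.
        errors = [w * step - c for w, c in zip(weights, counts)]
        best = max(range(len(weights)), key=lambda d: errors[d])
        dataset_index.append(best)
        dataset_sample_index.append(counts[best])
        counts[best] += 1
    return dataset_index, dataset_sample_index


ds_idx, sample_idx = build_blend_indices([30, 70], size=10)
print(ds_idx.count(0), ds_idx.count(1))  # 3 7
```

Because the error is recomputed after every draw, the interleaving stays close to the target ratios throughout the index rather than matching them only at the end.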
datasets.utils module#
- core.datasets.utils.compile_helpers()#
Compile C++ helper functions at runtime. Make sure this is invoked on a single process.
- core.datasets.utils.get_blend_from_list(blend: List[str] | None)
Get the blended_megatron_dataset_config.BlendedMegatronDatasetConfig blend from the blend list
- Parameters:
blend (Optional[List[str]]) – The blend list, which can be either (1) a list of prefixes, e.g. [“path/to/dataset_1_prefix”, “path/to/dataset_2_prefix”], or (2) a flattened, zipped list of weights and prefixes, e.g. [“30”, “path/to/dataset_1_prefix”, “70”, “path/to/dataset_2_prefix”]
- Returns:
The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights, e.g. [[“path/to/dataset_1_prefix”, “path/to/dataset_2_prefix”], [30.0, 70.0]].
- Return type:
Optional[Tuple[List[str], Optional[List[float]]]]
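The two accepted list forms can be told apart with a simple rule. The sketch below assumes that a list is a flattened weight/prefix list when it has an even length and every even-position entry parses as a number, and that it is a plain list of prefixes otherwise. This rule is consistent with the documented examples but is an assumption, not a description of get_blend_from_list itself. The function name parse_blend is hypothetical.

```python
from typing import List, Optional, Tuple


def parse_blend(
    blend: Optional[List[str]],
) -> Optional[Tuple[List[str], Optional[List[float]]]]:
    """Split a blend list into (prefixes, weights) (hypothetical helper)."""
    if blend is None:
        return None
    # Assumed rule: a flattened weight/prefix list has an even length...
    if len(blend) % 2 == 0:
        try:
            # ...and every even-position entry parses as a number.
            weights: Optional[List[float]] = [float(w) for w in blend[0::2]]
        except ValueError:
            weights = None
        if weights is not None:
            return [p.strip() for p in blend[1::2]], weights
    # Otherwise, every entry is a dataset prefix with no explicit weights.
    return [p.strip() for p in blend], None


print(parse_blend(["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]))
# (['path/to/dataset_1_prefix', 'path/to/dataset_2_prefix'], [30.0, 70.0])
```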
- core.datasets.utils.normalize(weights: List[float]) → List[float]#
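Do non-exponentiated normalization: divide each weight by the sum of all weights so that the results sum to 1 (unlike softmax, no exponential is applied first)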
- Parameters:
weights (List[float]) – The weights
- Returns:
The normalized weights
- Return type:
List[float]
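The difference between non-exponentiated normalization and softmax matters for blend weights, since the weights are meant to be used as ratios. The sketch below re-implements the idea in plain Python under the name normalize_weights (hypothetical). The check for a positive sum is an addition of this sketch and is not documented behavior of normalize. A softmax helper is included only for contrast.

```python
import math
from typing import List


def normalize_weights(weights: List[float]) -> List[float]:
    """Scale weights so they sum to 1 by dividing by their total (no exponentiation)."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    return [w / total for w in weights]


def softmax(weights: List[float]) -> List[float]:
    """Exponentiated normalization, shown only for contrast."""
    peak = max(weights)
    exps = [math.exp(w - peak) for w in weights]  # subtract max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


print(normalize_weights([30, 70]))  # [0.3, 0.7]
print(softmax([30, 70]))            # nearly all mass on the second weight
```

With weights of 30 and 70, the plain normalization preserves the intended 30/70 ratio, while softmax would effectively select only the second dataset.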