datasets package
Data pre-processing
Data preprocessing is built around the following classes:
IndexedDatasetBuilder
IndexedDataset
At the moment, an end-to-end data preprocessing implementation is left to the user. See the class docstring(s) for more details.
IndexedDatasetBuilder
The IndexedDatasetBuilder is capable of building and merging IndexedDataset instances.
IndexedDataset
The IndexedDataset class is the lowest-level data interface in Megatron Core. Internally, an IndexedDataset instance references two binaries: the data file (.bin) contains document/sequence data and the index file (.idx) contains document/sequence metadata.
The index file stores dataset-level metadata first:
The index header, for backward compatibility
The index version, for backward compatibility
A numeric code corresponding to the data type used to write data to the data file
The number of sequences in the dataset
The number of documents in the dataset
The index file stores document-level and sequence-level metadata second:
In order, the number of elements per sequence
In order, the byte offset (pointer) per sequence
In order, the consecutive sequence index range [...) per document
In order, the mode per sequence (in the multimodal case)
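The per-sequence and per-document metadata listed above can be pictured with a small in-memory model. This is only a sketch under stated assumptions: ToyIndex and its fields are hypothetical names, it writes nothing to disk, and it does not reproduce the actual .idx byte layout. It only tracks element counts, byte offsets, and the half-open sequence range of each document.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyIndex:
    """Hypothetical in-memory stand-in for the index metadata (not the real format)."""

    itemsize: int = 4  # assumed bytes per element, e.g. a 32-bit integer dtype
    sequence_lengths: List[int] = field(default_factory=list)
    sequence_pointers: List[int] = field(default_factory=list)
    document_indices: List[int] = field(default_factory=lambda: [0])
    _next_offset: int = 0

    def add_item(self, tokens: List[int]) -> None:
        # Record the element count and the byte offset of this sequence.
        self.sequence_lengths.append(len(tokens))
        self.sequence_pointers.append(self._next_offset)
        self._next_offset += len(tokens) * self.itemsize

    def end_document(self) -> None:
        # Close the current document at the next unused sequence index.
        self.document_indices.append(len(self.sequence_lengths))

    def document_range(self, d: int) -> range:
        # Half-open range [start, end) of sequence indices for document d.
        return range(self.document_indices[d], self.document_indices[d + 1])


index = ToyIndex()
index.add_item([1, 2, 3])
index.add_item([4, 5])
index.end_document()          # document 0 holds sequences 0 and 1
index.add_item([6, 7, 8, 9])
index.end_document()          # document 1 holds sequence 2
```

Here the byte offset of a sequence is the running total of the element counts before it, multiplied by the element size.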
Data loading: construction
Building the data loaders is a distributed-aware process built around the following classes:
BlendedMegatronDatasetConfig
BlendedMegatronDatasetBuilder
IndexedDataset
MegatronDataset
BlendedDataset
See the class docstrings for more details.
BlendedMegatronDatasetConfig (extendable)
The BlendedMegatronDatasetConfig class parameterizes the BlendedMegatronDatasetBuilder and in turn the MegatronDataset and BlendedDataset.
Different training/inference regimes will require different extensions, e.g. the GPTDatasetConfig.
BlendedMegatronDatasetBuilder
The BlendedMegatronDatasetBuilder class builds the highest-level data interfaces in Megatron Core.
NB: All ranks should attempt to build the dataset via the BlendedMegatronDatasetBuilder or the program will hang. Which ranks follow through on their attempts can be controlled via the BlendedMegatronDatasetConfig.
IndexedDataset
The IndexedDataset class is the lowest-level data interface in Megatron Core.
The IndexedDataset should already exist on disk before attempting to build any of the high-level data interfaces.
MegatronDataset (extendable)
The MegatronDataset abstract class is a high-level data interface in Megatron Core. It is an abstraction built upon the IndexedDataset.
Different training/inference regimes will require different extensions, e.g. the GPTDataset.
BlendedDataset
The BlendedDataset class is a high-level data interface in Megatron Core. It is an abstraction built upon the MegatronDataset.
The BlendedDataset is only necessary when a blend of multiple data distributions, i.e. multiple MegatronDataset instances, should contribute to a certain dataset split. The blend can be controlled via the BlendedMegatronDatasetConfig.
Data loading: implementation
GPTDataset
The GPTDataset is parameterized by the following variables: the underlying IndexedDataset instance indexed_dataset, the split indices indexed_indices (the contiguous subset of document or sequence indices used for training, validation, and testing), the number of samples N, the sequence length S, and the random seed R.
The GPTDataset creates three index mappings to facilitate lookup: (1) the document index, (2) the sample index, and (3) the shuffle index.
The document index Do_idx is a 1-D array mapping from i to document index of length E * |indexed_indices|, where E corresponds to the minimum number of epochs such that E * |indexed_indices| >= N. The document index is shuffled according to R.
Given:
N = 15
indexed_indices = [5, 6, 7, 8, 9]
E = 3
Then, for example:
Do_idx = [8, 8, 9, 6, 7, 5, 8, 5, 6, 6, 5, 9, 7, 7, 9]
The sample index Sa_idx is a 2-D array mapping from j to pairs of (i, Do_idx[i] offset) of shape [N + 1, 2]. The rows j and j + 1 serve as the left and right bounds for the j-th sample.
Given:
S = 1024
Then, for example:
Sa_idx[0] = (0, 0)
Sa_idx[1] = (0, 1024) => Do_idx[0] has length greater than S
Sa_idx[2] = (1, 512) => Do_idx[0] has length 1536
Sa_idx[3] = (2, 0) => Do_idx[1] has length 1536
Sa_idx[4] = (5, 300) => Do_idx[2:5] are shorter documents relative to Do_idx[0:2]
Sa_idx[5] = (6, 24) => Do_idx[5] has length 1300
The shuffle index Sh_idx is a 1-D array mapping from k to j of length N. The shuffle index is shuffled according to R.
Given:
N = 10
Then, for example:
Sh_idx = [4, 0, 2, 6, 1, 9, 5, 8, 7, 3]
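The three index mappings described above can be sketched in plain Python. This is a simplified illustration, not the Megatron Core implementation: the function names are hypothetical, the standard-library random module stands in for whatever RNG the library uses, and the document lengths in the usage example (300, 200, 224 and 2000 in particular) are made-up values chosen so that the result matches the sample index example above. The sketch also assumes the document index supplies enough tokens for all N samples.

```python
import math
import random
from typing import List, Sequence, Tuple


def build_document_index(indexed_indices: Sequence[int], n: int, seed: int) -> List[int]:
    # Repeat the split indices for the minimum number of epochs E such that
    # E * len(indexed_indices) >= n, then shuffle according to the seed.
    epochs = math.ceil(n / len(indexed_indices))
    do_idx = list(indexed_indices) * epochs
    random.Random(seed).shuffle(do_idx)
    return do_idx


def build_sample_index(
    do_idx: Sequence[int], lengths: Sequence[int], n: int, s: int
) -> List[Tuple[int, int]]:
    # Row j and row j + 1 bound the j-th sample of s tokens. Each row is a
    # pair (position in do_idx, token offset inside that document).
    sa_idx = [(0, 0)]
    i, offset = 0, 0
    for _ in range(n):
        remaining = s
        while remaining > 0:
            available = lengths[do_idx[i]] - offset
            if available > remaining:
                offset += remaining
                remaining = 0
            else:
                # Consume the rest of this document and move to the next one.
                remaining -= available
                i, offset = i + 1, 0
        sa_idx.append((i, offset))
    return sa_idx


def build_shuffle_index(n: int, seed: int) -> List[int]:
    # A seeded permutation of the sample indices 0..n-1.
    sh_idx = list(range(n))
    random.Random(seed).shuffle(sh_idx)
    return sh_idx


do_idx = build_document_index([5, 6, 7, 8, 9], n=15, seed=1234)
sh_idx = build_shuffle_index(10, seed=1234)

# Hypothetical document lengths, indexed by document id, with an unshuffled
# document index so the walk is easy to follow by hand.
toy_lengths = [1536, 1536, 300, 200, 224, 1300, 2000]
sa_idx = build_sample_index(list(range(7)), toy_lengths, n=5, s=1024)
# sa_idx == [(0, 0), (0, 1024), (1, 512), (2, 0), (5, 300), (6, 24)]
```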
To query the GPTDataset for the k-th sample we do the following:
Use the shuffle index to get the index j into the sample index.
j = Sh_idx[k]
Use the sample index to get the left and right sample-bounding indices into the document index and the starting token offset for each document.
i, offset = Sa_idx[j]
i_next, offset_next = Sa_idx[j + 1]
Use the document index to retrieve S tokens from consecutive (in the document index) documents.
sample = []
sample += indexed_dataset[Do_idx[i]][offset:]
if i != i_next:
    sample += indexed_dataset[Do_idx[i + 1:i_next]]
sample += indexed_dataset[Do_idx[i_next]][:offset_next]
To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the MegatronDataset.__init__ function.
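The three lookup steps above can be made runnable on toy data. The sketch below is an assumption-laden illustration: get_sample is a hypothetical helper, a plain dict of token lists stands in for the IndexedDataset, and the index values are written by hand. Unlike the pseudocode, it handles explicitly the case where both bounds fall inside the same document, and the case where the right bound sits at offset 0.

```python
from typing import Dict, List, Sequence, Tuple


def get_sample(
    k: int,
    documents: Dict[int, List[int]],
    do_idx: Sequence[int],
    sa_idx: Sequence[Tuple[int, int]],
    sh_idx: Sequence[int],
) -> List[int]:
    # Step 1: shuffle index -> sample index row.
    j = sh_idx[k]
    # Step 2: left and right bounds of the sample.
    i, offset = sa_idx[j]
    i_next, offset_next = sa_idx[j + 1]
    # Step 3: gather tokens from consecutive entries of the document index.
    if i == i_next:
        # The whole sample lies inside a single document.
        return list(documents[do_idx[i]][offset:offset_next])
    sample = list(documents[do_idx[i]][offset:])
    for middle in range(i + 1, i_next):
        sample.extend(documents[do_idx[middle]])
    if offset_next > 0:
        sample.extend(documents[do_idx[i_next]][:offset_next])
    return sample


# Toy data: three documents, sequence length S = 4, N = 2 samples.
documents = {5: [50, 51, 52], 6: [60, 61, 62, 63, 64], 7: [70, 71]}
do_idx = [6, 5, 7]
sa_idx = [(0, 0), (0, 4), (2, 0)]
sh_idx = [1, 0]

first = get_sample(0, documents, do_idx, sa_idx, sh_idx)   # [64, 50, 51, 52]
second = get_sample(1, documents, do_idx, sa_idx, sh_idx)  # [60, 61, 62, 63]
```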
BlendedDataset
The BlendedDataset is parameterized by the following variables: the underlying MegatronDataset instances D, the weights W (one per dataset), and the size S. The BlendedDataset will draw samples from contributing datasets in proportion to the weights until achieving a composite dataset of the desired size. During each sampling step, we draw a single sample from the dataset which has the greatest sampling error.
The BlendedDataset creates two “blending” indices to facilitate lookup: (1) the dataset index and (2) the dataset sample index.
The dataset index Da_idx is a 1-D array mapping from i to dataset index of length S.
Given:
D = [d0, d1, d2]
W = [1/2, 1/4, 1/4]
S = 4
Then, for example:
Da_idx = [0, 1, 2, 0]
The dataset sample index Sa_idx is a 1-D mapping from i to the sample index for dataset Da_idx[i] of length S.
Given:
Da_idx = [0, 1, 2, 0]
Then, for example:
Sa_idx = [0, 0, 0, 1]
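The greedy rule described above (at each step, draw from the dataset with the greatest sampling error) can be sketched as follows. The precise error definition is an assumption made for illustration: here it is the weight times the step count, minus the number of samples already drawn from that dataset. toy_blending_indices is a hypothetical name, and ties go to the lowest dataset index. With the weights from the example, this reproduces the two indices shown.

```python
from typing import List, Tuple


def toy_blending_indices(weights: List[float], size: int) -> Tuple[List[int], List[int]]:
    counts = [0] * len(weights)   # samples drawn so far from each dataset
    da_idx: List[int] = []        # dataset index
    sa_idx: List[int] = []        # dataset sample index
    for i in range(size):
        step = max(i, 1)
        # Assumed sampling error: how far each dataset lags its target share.
        errors = [w * step - c for w, c in zip(weights, counts)]
        chosen = max(range(len(weights)), key=errors.__getitem__)
        da_idx.append(chosen)
        sa_idx.append(counts[chosen])
        counts[chosen] += 1
    return da_idx, sa_idx


da_idx, sa_idx = toy_blending_indices([1 / 2, 1 / 4, 1 / 4], size=4)
# da_idx == [0, 1, 2, 0] and sa_idx == [0, 0, 0, 1]

# Lookup of the k-th blended sample, with lists standing in for datasets.
D = [["a0", "a1"], ["b0"], ["c0"]]
k = 3
sample = D[da_idx[k]][sa_idx[k]]  # "a1"
```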
To query the BlendedDataset for the k-th sample we do the following:
Use the dataset index to retrieve the corresponding dataset from D and the dataset sample index to retrieve the corresponding sample from that dataset.
sample = D[Da_idx[k]][Sa_idx[k]]
To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the BlendedDataset.__init__ function.
- class core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig(is_built_on_rank: Callable, random_seed: int, sequence_length: int, blend: Optional[List[str]] = None, blend_per_split: Optional[List[Optional[List[str]]]] = None, split: Optional[str] = None, path_to_cache: Optional[str] = None, mmap_bin_files: bool = False, mock: bool = False, tokenizer: Optional[megatron.core.datasets.megatron_tokenizer.MegatronTokenizer] = None)
Bases:
object
Configuration object for Megatron Core datasets
- Parameters
is_built_on_rank (Callable) – A callable which returns True if the dataset should be built on the current rank. It should be Megatron Core parallelism aware i.e. global rank, group rank, and virtual rank may inform its return value.
random_seed (int) – The seed for all RNG during dataset creation.
sequence_length (int) – The sequence length.
blend (Optional[List[str]]) – The blend string, consisting of either a single dataset or a flattened sequential sequence of weight-dataset pairs. For example, [“dataset-path1”] and [“50”, “dataset-path1”, “50”, “dataset-path2”] are both valid. Not to be used with ‘blend_per_split’. Defaults to None.
blend_per_split (Optional[List[Optional[List[str]]]]) – A set of blend strings, as defined above, one for each split distribution. Not to be used with ‘blend’. Defaults to None.
split (Optional[str]) – The split string, a comma separated weighting for the dataset splits when drawing samples from a single distribution. Not to be used with ‘blend_per_split’. Defaults to None.
split_matrix (Optional[List[Tuple[float, float]]]) – The split matrix consisting of non-overlapping book-ends of each split in order. For more information, refer to ‘convert_split_vector_to_split_matrix’. Created automatically from ‘split’. Not to be passed in to the constructor.
path_to_cache (str) – Where all re-useable dataset indices are to be cached.
mmap_bin_files (bool) – Whether to mmap the .bin files or use file pointer.
mock (bool) – Whether to bypass real data loading and validation in favor of mock data generation.
tokenizer (Optional[MegatronTokenizer]) – The MegatronTokenizer instance or None. Required for datasets which do online tokenization.
- blend: Optional[List[str]] = None
- blend_per_split: Optional[List[Optional[List[str]]]] = None
- is_built_on_rank: Callable
- mmap_bin_files: bool = False
- mock: bool = False
- path_to_cache: Optional[str] = None
- random_seed: int
- sequence_length: int
- split: Optional[str] = None
- split_matrix: Optional[List[Tuple[float, float]]] = None
- tokenizer: Optional[megatron.core.datasets.megatron_tokenizer.MegatronTokenizer] = None
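The two valid shapes of the blend parameter described above can be separated into paths and weights with a few lines of Python. This is a sketch only: parse_blend is a hypothetical helper, not a Megatron Core function, and normalizing the weights to sum to 1 is an assumption made for illustration.

```python
from typing import List, Tuple


def parse_blend(blend: List[str]) -> Tuple[List[str], List[float]]:
    # A single entry is a lone dataset path with an implicit weight of 1.
    if len(blend) == 1:
        return [blend[0]], [1.0]
    # Otherwise expect flattened (weight, path) pairs.
    if len(blend) % 2 != 0:
        raise ValueError("expected a single path or weight-path pairs")
    weights = [float(w) for w in blend[0::2]]
    paths = list(blend[1::2])
    total = sum(weights)
    # Assumption: weights are normalized so that they sum to 1.
    return paths, [w / total for w in weights]


single = parse_blend(["dataset-path1"])
paired = parse_blend(["50", "dataset-path1", "50", "dataset-path2"])
# paired == (["dataset-path1", "dataset-path2"], [0.5, 0.5])
```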
- core.datasets.blended_megatron_dataset_config.convert_split_vector_to_split_matrix(vector_a: List[float], vector_b: Optional[List[float]] = None) → List[Optional[Tuple[float, float]]]
Build the split matrix from one or optionally two contributing split vectors.
Ex. a standard conversion:
[0.99, 0.01, 0.0] -> [(0, 0.99), (0.99, 1.0), None]
Ex. a conversion for Retro when Retro pretraining uses a [0.99, 0.01, 0.0] split and Retro preprocessing used a [0.98, 0.02, 0.0] split:
[0.99, 0.01, 0.0], [0.98, 0.02, 0.0] -> [(0, 0.98), (0.99, 1.0), None]
- Parameters
vector_a (List[float]) – The primary split vector
vector_b (Optional[List[float]]) – An optional secondary split vector which constrains the primary split vector. Defaults to None.
- Returns
The split matrix consisting of book-ends of each split in order
- Return type
List[Tuple[float, float]]
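For the standard single-vector conversion, the book-ends are running sums of the split ratios, and a zero-width split maps to None. The sketch below covers only that case (it ignores the optional second vector), and split_vector_to_matrix is a hypothetical name rather than the library function.

```python
from itertools import accumulate
from typing import List, Optional, Tuple


def split_vector_to_matrix(vector: List[float]) -> List[Optional[Tuple[float, float]]]:
    # Right book-ends are the cumulative sums; left book-ends are the
    # previous right book-ends, starting from 0.
    ends = list(accumulate(vector))
    starts = [0.0] + ends[:-1]
    # A split with no width contributes no samples and becomes None.
    return [(s, e) if e > s else None for s, e in zip(starts, ends)]


matrix = split_vector_to_matrix([0.99, 0.01, 0.0])
# Roughly [(0, 0.99), (0.99, 1.0), None], up to floating-point rounding.
```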
- core.datasets.blended_megatron_dataset_config.parse_and_normalize_split(split: str) → List[float]
Parse the dataset split ratios from a string
- Parameters
split (str) – The train valid test split string e.g. “99,1,0”
- Returns
The train valid test split ratios e.g. [0.99, 0.01, 0.0]
- Return type
List[float]
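The mapping from a split string to ratios can be illustrated in a few lines. This is a minimal sketch assuming comma-separated, non-negative weights with a positive sum. parse_split is a hypothetical name, and any extra behavior of the library function (other separators, padding, validation) is not modeled.

```python
from typing import List


def parse_split(split: str) -> List[float]:
    # Parse the comma-separated weights, then divide by their sum.
    values = [float(part) for part in split.split(",")]
    total = sum(values)
    return [value / total for value in values]


ratios = parse_split("99,1,0")
# Approximately [0.99, 0.01, 0.0]
```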
- class core.datasets.blended_megatron_dataset_builder.BlendedMegatronDatasetBuilder(cls: Type[Union[megatron.core.datasets.megatron_dataset.MegatronDataset, megatron.core.datasets.megatron_dataset.MockDataset]], sizes: List[int], config: megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig)
Bases:
object
Builder class for the BlendedDataset and MegatronDataset classes
- Parameters
cls (Type[MegatronDataset]) – The class to instantiate, must inherit from MegatronDataset
sizes (List[int]) – The minimum number of total samples to draw from each split, varies with blend
config (BlendedMegatronDatasetConfig) – The config object which informs dataset creation
- build() → List[Optional[Union[megatron.core.datasets.blended_dataset.BlendedDataset, megatron.core.datasets.megatron_dataset.MegatronDataset, megatron.core.datasets.megatron_dataset.MockDataset]]]
Build all dataset splits according to the provided blend(s)
This method is distributed-aware and must be called on all ranks.
The dataset splits returned can vary according to the config. Supply config.blend and config.split to build BlendedDataset and/or MegatronDataset splits from the same distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset splits from separate distributions.
- Returns
A list containing a dataset instance (or None) per split
- Return type
List[Optional[TopLevelDataset]]
- static build_generic_dataset(cls: Union[Type[Union[megatron.core.datasets.blended_dataset.BlendedDataset, megatron.core.datasets.megatron_dataset.MegatronDataset, megatron.core.datasets.megatron_dataset.MockDataset, megatron.core.datasets.megatron_dataset.LowLevelDataset, torch.utils.data.Dataset]], Callable], is_built_on_rank: Callable, *args: Any) → Optional[Union[megatron.core.datasets.blended_dataset.BlendedDataset, megatron.core.datasets.megatron_dataset.MegatronDataset, megatron.core.datasets.megatron_dataset.MockDataset, megatron.core.datasets.megatron_dataset.LowLevelDataset, torch.utils.data.Dataset, Iterable]]
Build the DistributedDataset
Return None if and only if the underlying dataset class is not built on the current rank and torch.distributed is initialized.
- Parameters
cls (Union[Type[DistributedDataset], Callable]) – The DistributedDataset class to be built. In special cases, e.g. when we are building the low level dataset for a RawMegatronDataset instance, we can accept a Callable which returns an Iterable.
args (Tuple[Any]) – The positional arguments used to build the provided DistributedDataset class
- Raises
Exception – When the dataset constructor raises an OSError
- Returns
The DistributedDataset instantiation, the Iterable instantiation, or None
- Return type
Optional[Union[DistributedDataset, Iterable]]
- class core.datasets.megatron_tokenizer.MegatronTokenizer(*tokenizer_paths: str, **tokenizer_options: Any)
Bases:
abc.ABC
Abstract class for tokenizer
Absent a config or class-specific tracking of which objects are uniquely identifying, we must include all key word arguments as unique identifiers
- Parameters
tokenizer_paths (Tuple[str]) – All tokenizer source paths or prefixes
kwargs (Dict[str, Any]) – All tokenizer options
- property bos
The BOS token id
- Raises
NotImplementedError – Non-abstract, optional attribute
- property cls
The CLS token id
- Raises
NotImplementedError – Non-abstract, optional attribute
- detokenize(ids: numpy.ndarray) → str
Convert embedding ids to text
- Parameters
ids (numpy.ndarray) – The ids to convert
- Returns
The converted text
- Return type
str
- Raises
NotImplementedError – Non-abstract, optional method
- property eod
The EOD token id
- Raises
NotImplementedError – Non-abstract, optional attribute
- property eos
The EOS token id
- Raises
NotImplementedError – Non-abstract, optional attribute
- abstract property inv_vocab
Dictionary from vocab id token to text token
- property mask
The MASK token id
- Raises
NotImplementedError – Non-abstract, optional attribute
- property pad
The PAD token id
- Raises
NotImplementedError – Non-abstract, optional attribute
- property sep
The SEP token id
- Raises
NotImplementedError – Non-abstract, optional attribute
- abstract tokenize(text: str) → numpy.ndarray
Convert text to embedding ids
- Parameters
text (str) – The text to convert
- Returns
The converted embedding ids
- Return type
numpy.ndarray
- abstract property vocab
Dictionary from vocab text token to id token
- abstract property vocab_size
The vocabulary size
- class core.datasets.indexed_dataset.DType(value)
Bases:
enum.Enum
The NumPy data type Enum for writing/reading the IndexedDataset indices
- classmethod code_from_dtype(value: Type[numpy.number]) → int
Get the code from the dtype
- Parameters
value (Type[numpy.number]) – The dtype
- Returns
The code
- Return type
int
- classmethod dtype_from_code(value: int) → Type[numpy.number]
Get the dtype from the code
- Parameters
value (int) – The code
- Returns
The dtype
- Return type
Type[numpy.number]
- float32 = 7
- float64 = 6
- int16 = 3
- int32 = 4
- int64 = 5
- int8 = 2
- static optimal_dtype(cardinality: Optional[int]) → Type[numpy.number]
Get the dtype to use for an index of a certain cardinality
- Parameters
cardinality (Optional[int]) – The number of elements to be indexed
- Returns
The dtype to use for the index
- Return type
Type[numpy.number]
- static size(key: Union[int, Type[numpy.number]]) → int
Get the size of the dtype/code in bytes
- Parameters
key (Union[int, Type[numpy.number]]) – The dtype or code
- Raises
ValueError – If the key is neither dtype nor integer code
- Returns
The size of the dtype/code in bytes
- Return type
int
- uint16 = 8
- uint8 = 1
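The code-to-dtype mapping listed above can be mirrored with a standard-library Enum. The sketch below is illustrative only: ToyDType and its helpers are hypothetical, it works with dtype names rather than NumPy types, and it derives the byte size from the bit width in the name.

```python
from enum import Enum
from typing import Union


class ToyDType(Enum):
    # Same numeric codes as listed for DType above.
    uint8 = 1
    int8 = 2
    int16 = 3
    int32 = 4
    int64 = 5
    float64 = 6
    float32 = 7
    uint16 = 8

    @classmethod
    def name_from_code(cls, code: int) -> str:
        # Look up the member by its numeric code.
        return cls(code).name

    @classmethod
    def code_from_name(cls, name: str) -> int:
        # Look up the numeric code by the dtype name.
        return cls[name].value

    @staticmethod
    def size(key: Union[int, str]) -> int:
        # Accept either an integer code or a dtype name.
        if isinstance(key, int):
            name = ToyDType(key).name
        elif isinstance(key, str):
            name = ToyDType[key].name
        else:
            raise ValueError("key must be an integer code or a dtype name")
        # The trailing digits of the name give the width in bits.
        bits = int("".join(ch for ch in name if ch.isdigit()))
        return bits // 8


code = ToyDType.code_from_name("uint16")   # 8
name = ToyDType.name_from_code(4)          # "int32"
nbytes = ToyDType.size(5)                  # 8 bytes for int64
```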
- class core.datasets.indexed_dataset.IndexedDataset(*args: Any, **kwargs: Any)
Bases:
torch.utils.data.Dataset
The low-level interface dataset class
- Parameters
path_prefix (str) – The index (.idx) and data (.bin) prefix
multimodal (bool, optional) – Whether the dataset is multimodal. Defaults to False.
mmap (bool, optional) – Whether to mmap the .bin files. Defaults to True.
- property document_indices: numpy.ndarray
Get the document indices
- Returns
The document indices
- Return type
numpy.ndarray
- static exists(path_prefix: str) → bool
Return whether the IndexedDataset exists on disk at the prefix
- Parameters
path_prefix (str) – The prefix to the index (.idx) and data (.bin) files
- Returns
Whether the IndexedDataset exists on disk at the prefix
- Return type
bool
- get(idx: int, offset: int = 0, length: Optional[int] = None) → numpy.ndarray
Retrieve a single item from the dataset with the option to only return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
- Parameters
idx (Union[int, numpy.integer]) – The index into the dataset
offset (int) – The integer token offset in the sequence
length (int) – The number of tokens to grab from the sequence
- Returns
The sequence tokens and modes at the index
- Return type
Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]
- get_document_indices() → numpy.ndarray
Get the document indices
This method is slated for deprecation.
- Returns
The document indices
- Return type
numpy.ndarray
- initialize(path_prefix: str, multimodal: bool, mmap: bool) → None
Initialize the dataset
This method is called by IndexedDataset.__init__ during object creation and by IndexedDataset.__setstate__ during un-pickling
- Parameters
path_prefix (str) – The index (.idx) and data (.bin) prefix
multimodal (bool) – Whether the dataset is multimodal
mmap (bool) – Whether to mmap the .bin file
- property sequence_lengths: numpy.ndarray
Get the sequence lengths
- Returns
The sequence lengths
- Return type
numpy.ndarray
- property sequence_modes: numpy.ndarray
Get the sequence modes
- Returns
The sequence modes
- Return type
numpy.ndarray
- set_document_indices(document_indices: numpy.ndarray) → None
Set the document indices
This method is slated for deprecation.
- Parameters
document_indices (numpy.ndarray) – The document indices
- class core.datasets.indexed_dataset.IndexedDatasetBuilder(bin_path: str, dtype: Type[numpy.number] = numpy.int32, multimodal: bool = False)
Bases:
object
Builder class for the IndexedDataset class
- Parameters
bin_path (str) – The path to the data (.bin) file
dtype (Type[numpy.number], optional) – The dtype of the index file. Defaults to numpy.int32.
multimodal (bool, optional) – Whether the dataset is multimodal. Defaults to False.
- add_document(tensor: torch.Tensor, lengths: List[int], modes: Optional[List[int]] = None) → None
Add an entire document to the dataset
- Parameters
tensor (torch.Tensor) – The document to add
lengths (List[int]) – The lengths of each item in the document
modes (Optional[List[int]], optional) – The modes for each item in the document. Defaults to None.
- add_index(path_prefix: str) → None
Add an entire IndexedDataset to the dataset
- Parameters
path_prefix (str) – The index (.idx) and data (.bin) prefix
- add_item(tensor: torch.Tensor, mode: int = 0) → None
Add a single item to the dataset
- Parameters
tensor (torch.Tensor) – The item to add to the data file
mode (int, optional) – The mode for the item. Defaults to 0.
- end_document() → None
Finalize the document, for use with IndexedDatasetBuilder.add_item
- finalize(idx_path: str) → None
Clean up and write the index (.idx) file
- Parameters
idx_path (str) – The path to the index file
- core.datasets.indexed_dataset.get_bin_path(path_prefix: str) → str
Get the path to the data file from the prefix
- Parameters
path_prefix (str) – The prefix
- Returns
The path to the data file
- Return type
str
- core.datasets.indexed_dataset.get_idx_path(path_prefix: str) → str
Get the path to the index file from the prefix
- Parameters
path_prefix (str) – The prefix
- Returns
The path to the index file
- Return type
str
- class core.datasets.megatron_dataset.MegatronDataset(*args: Any, **kwargs: Any)
Bases:
abc.ABC, torch.utils.data.Dataset
The highest level wrapper class from which all dataset classes should inherit
- Parameters
dataset (LowLevelDataset) – The dataset around which to build the MegatronDataset
dataset_path (str) – The real path on disk to the dataset, for bookkeeping. TODO: subsume this argument by enforcing auto-bookkeeping in the dataset class type.
indices (numpy.ndarray) – The set of document indices to expose
num_samples (int) – The number of samples to draw from the indexed dataset
index_split (Split) – The indices Split
config (BlendedMegatronDatasetConfig) – The config
- static build_low_level_dataset(dataset_path: str, config: megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig) → Union[megatron.core.datasets.indexed_dataset.IndexedDataset, Iterable]
Build the low level dataset via a function to be called from within BlendedMegatronDatasetBuilder.build_generic_dataset
It may be that the low level dataset spans any subset of train/valid/test splits, which is why we define a static “build” function separately from the constructor in the mid level dataset class
- Parameters
dataset_path (str) – The real path on disk to the dataset
config (BlendedMegatronDatasetConfig) – The dataset config
- Returns
The low level dataset
- Return type
LowLevelDataset
- static numel_low_level_dataset(low_level_dataset: Union[megatron.core.datasets.indexed_dataset.IndexedDataset, Iterable]) → int
Return the number of elements in the underlying low level dataset for the purpose of segregating the train/valid/test split indices
It may be that the low level dataset can be split any number of ways, depending on the mid level dataset it supports, which is why we define the “number of elements” function separately from the __len__ function here in the mid level dataset class
- Parameters
low_level_dataset (LowLevelDataset) – The underlying low level dataset
- Returns
The number of elements in the underlying low level dataset
- Return type
int
- class core.datasets.megatron_dataset.MockDataset(*args: Any, **kwargs: Any)
Bases:
core.datasets.megatron_dataset.MegatronDataset
The highest level wrapper class from which all mock dataset classes should inherit
The MockDataset is a special, one-off class that should not serve as a precedent for developers seeking to extend the MegatronDataset. This class is incompatible with BlendedDataset
This class cannibalizes the constructor of the parent class. As such, we do not need to pass in some constructor parameters. They may be populated, but most are superfluous and can be None. Only num_samples, index_split, and config are required.
- Parameters
dataset (Optional[LowLevelDataset]) – The dataset around which to build the MegatronDataset
dataset_path (Optional[str]) – The real path on disk to the dataset, for bookkeeping. TODO: subsume this argument by enforcing auto-bookkeeping in the dataset class type.
indices (Optional[numpy.ndarray]) – The set of document indices to expose
num_samples (int) – The number of samples to draw from the indexed dataset
index_split (Split) – The indices Split
config (BlendedMegatronDatasetConfig) – The config
- class core.datasets.gpt_dataset.GPTDataset(*args: Any, **kwargs: Any)
Bases:
megatron.core.datasets.megatron_dataset.MegatronDataset
The base GPT dataset
- Parameters
indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset
dataset_path (str) – The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray) – The set of document indices to expose
num_samples (int) – The number of samples to draw from the indexed dataset
index_split (Split) – The indexed_indices Split
config (GPTDatasetConfig) – The config
- static build_low_level_dataset(dataset_path: str, config: core.datasets.gpt_dataset.GPTDatasetConfig) → megatron.core.datasets.indexed_dataset.IndexedDataset
Abstract method implementation
- Parameters
dataset_path (str) – The real path prefix to the IndexedDataset .bin and .idx files
config (BlendedMegatronDatasetConfig) – The dataset config
- Returns
The underlying IndexedDataset
- Return type
IndexedDataset
- static numel_low_level_dataset(low_level_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset) → int
Abstract method implementation
For GPT, the underlying IndexedDataset should be split by sequence, as opposed to, say, BERT, which should be split by document
- Parameters
low_level_dataset (IndexedDataset) – The underlying IndexedDataset
- Returns
The number of unique elements in the underlying IndexedDataset
- Return type
int
- class core.datasets.gpt_dataset.GPTDatasetConfig(*args: Any, **kwargs: Any)
Bases:
megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig
Configuration object for Megatron Core GPT datasets
- Parameters
reset_position_ids (bool) – Option to reset the position IDs in the dataset at an interval
reset_attention_mask (bool) – Option to reset the attention mask from the dataset
eod_mask_loss (bool) – Option to enable the EOD mask loss
vocab_size (int) – Size of vocabulary
- eod_mask_loss: bool = None
- reset_attention_mask: bool = None
- reset_position_ids: bool = None
- vocab_size: int = 9223372036854775807
- class core.datasets.gpt_dataset.MockGPTDataset(*args: Any, **kwargs: Any)
Bases: megatron.core.datasets.megatron_dataset.MockDataset
The mock GPT dataset
- class core.datasets.masked_dataset.MaskedWordPieceDataset(*args: Any, **kwargs: Any)
Bases:
megatron.core.datasets.megatron_dataset.MegatronDataset
The semi-abstract base class for masked WordPiece datasets
This implementation makes the rigid assumption that all inheritor datasets are built upon the IndexedDataset class. This assumption may be pushed down to the inheritors in future if necessary.
NB: WordPiece tokenization prepends a double hash “##” to all tokens/pieces in a word, save the first token/piece.
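The “##” convention noted above makes it possible to regroup pieces into whole words, which is the kind of grouping that whole-word masking relies on. The following is a minimal sketch with a hypothetical helper, merge_wordpieces, operating on text pieces rather than token ids. It is not taken from the dataset implementation.

```python
from typing import List


def merge_wordpieces(pieces: List[str]) -> List[str]:
    words: List[str] = []
    for piece in pieces:
        if piece.startswith("##") and words:
            # A continuation piece: strip the marker and extend the last word.
            words[-1] += piece[2:]
        else:
            # The first piece of a new word.
            words.append(piece)
    return words


merged = merge_wordpieces(["un", "##believ", "##able", "results"])
# merged == ["unbelievable", "results"]
```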
- Parameters
indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset
dataset_path (str) – The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray) – The set of document indices to expose
num_samples (int) – The number of samples to draw from the indexed dataset
index_split (Split) – The indexed_indices Split
config (MaskedWordPieceDatasetConfig) – The config
- static build_low_level_dataset(dataset_path: str, config: core.datasets.masked_dataset.MaskedWordPieceDatasetConfig) → megatron.core.datasets.indexed_dataset.IndexedDataset
- static numel_low_level_dataset(low_level_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset) → int
- class core.datasets.masked_dataset.MaskedWordPieceDatasetConfig(*args: Any, **kwargs: Any)
Bases:
megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig
Configuration object for Megatron Core Masked WordPiece datasets
- Parameters
masking_probability (float) – The probability we mask a candidate N-gram
short_sequence_probability (float) – The probability we return a sequence shorter than the target sequence length
masking_max_ngram (int) – The maximum length N-gram to consider masking or permuting
masking_do_full_word (bool) – Whether we mask the whole word or its component parts
masking_do_permutation (bool) – Whether we shuffle a subset of candidate N-grams in addition to masking
masking_use_longer_ngrams (bool) – Whether to favor longer N-grams over shorter N-grams
masking_use_geometric_distribution (bool) – Whether to draw the size of the N-gram from a geometric distribution according to SpanBERT https://arxiv.org/abs/1907.10529 (Section 3.1)
- masking_do_full_word: bool = None
- masking_do_permutation: bool = None
- masking_max_ngram: int = None
- masking_probability: float = None
- masking_use_geometric_distribution: bool = None
- masking_use_longer_ngrams: bool = None
- short_sequence_probability: float = None
- class core.datasets.bert_dataset.BERTMaskedWordPieceDataset(*args: Any, **kwargs: Any)
Bases:
megatron.core.datasets.masked_dataset.MaskedWordPieceDataset
The BERT dataset that assumes WordPiece tokenization
- Parameters
indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset
dataset_path (str) – The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray) – The set of document indices to expose
num_samples (int) – The number of samples to draw from the indexed dataset
index_split (Split) – The indexed_indices Split
config (BERTMaskedWordPieceDatasetConfig) – The config
- class core.datasets.bert_dataset.BERTMaskedWordPieceDatasetConfig(*args: Any, **kwargs: Any)
Bases:
megatron.core.datasets.masked_dataset.MaskedWordPieceDatasetConfig
Configuration object for Megatron Core BERT WordPiece datasets
- Parameters
classification_head (bool) – Option to perform the next sequence prediction during sampling
- classification_head: bool = None
- class core.datasets.t5_dataset.T5MaskedWordPieceDataset(*args: Any, **kwargs: Any)
Bases:
megatron.core.datasets.masked_dataset.MaskedWordPieceDataset
The T5 dataset that assumes WordPiece tokenization
- Parameters
indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset
dataset_path (str) – The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray) – The set of document indices to expose
num_samples (int) – The number of samples to draw from the indexed dataset
index_split (Split) – The indexed_indices Split
config (T5MaskedWordPieceDatasetConfig) – The config
- class core.datasets.t5_dataset.T5MaskedWordPieceDatasetConfig(*args: Any, **kwargs: Any)
Bases:
megatron.core.datasets.masked_dataset.MaskedWordPieceDatasetConfig
Configuration object for Megatron Core T5 WordPiece datasets
NB: As a temporary holdover from Megatron-LM, the T5 tokenizer has an attribute which defines a number of special sentinel tokens used during sampling. The assert in __post_init__ serves to preserve compatibility with Megatron-LM until the T5 tokenizer is in Megatron Core.
- Parameters
sequence_length_encoder (Optional[int]) – A sequence_length alias and the sequence length for the encoder
sequence_length_decoder (int) – The sequence length for the decoder
- sequence_length_decoder: int = None
- sequence_length_encoder: Optional[int] = None
- class core.datasets.blended_dataset.BlendedDataset(*args: Any, **kwargs: Any)
Bases:
torch.utils.data.Dataset
Conjugating class for a set of MegatronDataset instances
- Parameters
datasets (List[MegatronDataset]) – The MegatronDataset instances to blend
weights (List[float]) – The weights which determine the dataset blend ratios
size (int) – The number of samples to draw from the blend
config (BlendedMegatronDatasetConfig) – The config
- Raises
RuntimeError – When the dataset has fewer or more samples than ‘size’ post-initialization
- class core.datasets.utils.Split(value)
Bases:
enum.Enum
An enumeration.
- test = 2
- train = 0
- valid = 1
- core.datasets.utils.compile_helpers()
Compile C++ helper functions at runtime. Make sure this is invoked on a single process.
- core.datasets.utils.log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any)
If torch distributed is initialized, log only on rank
- Parameters
logger (logging.Logger) – The logger to write the logs
args (Tuple[Any]) – All logging.Logger.log positional arguments
rank (int, optional) – The rank to write on. Defaults to 0.
kwargs (Dict[str, Any]) – All logging.Logger.log keyword arguments
- core.datasets.utils.normalize(weights: List[float]) → List[float]
Do non-exponentiated normalization
- Parameters
weights (List[float]) – The weights
- Returns
The normalized weights
- Return type
List[float]