core.datasets.gpt_dataset#

Module Contents#

Classes#

GPTDatasetConfig

Configuration object for Megatron Core GPT datasets

GPTDataset

The base GPT dataset

MockGPTLowLevelDataset

The mock GPT low level dataset

MockGPTDataset

The mock GPT dataset

Functions#

_build_document_index

Build an array with length = num epochs * num documents

_build_shuffle_index

Build the range [0, size) and shuffle

_get_ltor_masks_and_position_ids

Build masks and position IDs for a left-to-right model.

Data#

API#

core.datasets.gpt_dataset.logger#

getLogger(...)

class core.datasets.gpt_dataset.GPTDatasetConfig#

Bases: megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig

Configuration object for Megatron Core GPT datasets

reset_position_ids: Optional[bool]#

None

Option to reset the position IDs in the dataset at an interval

reset_attention_mask: Optional[bool]#

None

Option to reset the attention mask from the dataset

eod_mask_loss: Optional[bool]#

None

Option to enable the EOD mask loss

create_attention_mask: bool#

True

Option to enable attention mask generation. Can be disabled if the attention kernel generates masks by itself.

drop_last_partial_validation_sequence: bool#

True

Option to drop the last partial validation sequence

add_extra_token_to_sequence: bool#

True

Option to draw sequences with one extra token to ensure the sample input tokens and sample output tokens are both of the desired sequence length

object_storage_cache_path: Optional[str]#

None

Path for caching indices for S3 or MSC dataloading.

__post_init__() None#

Perform assertions and set fields after initialization
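
For orientation, here is a minimal sketch of constructing a GPTDatasetConfig. The GPT-specific fields are the ones documented above; the base-class fields used here (random_seed, sequence_length, blend, split, tokenizer) are assumptions about BlendedMegatronDatasetConfig and may differ between Megatron Core versions, and the dataset path and tokenizer are placeholders.

from megatron.core.datasets.gpt_dataset import GPTDatasetConfig

config = GPTDatasetConfig(
    random_seed=1234,                    # assumed base-class field
    sequence_length=4096,                # assumed base-class field
    blend=(["/data/my-corpus_text_document"], None),  # assumed base-class field; placeholder path
    split="949,50,1",                    # assumed base-class field
    tokenizer=my_tokenizer,              # placeholder: a tokenizer built elsewhere
    reset_position_ids=False,
    reset_attention_mask=False,
    eod_mask_loss=False,
    create_attention_mask=True,
    drop_last_partial_validation_sequence=True,
    add_extra_token_to_sequence=True,
)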

class core.datasets.gpt_dataset.GPTDataset(
indexed_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset,
dataset_path: Optional[str],
indexed_indices: numpy.ndarray,
num_samples: Optional[int],
index_split: megatron.core.datasets.utils.Split,
config: core.datasets.gpt_dataset.GPTDatasetConfig,
)#

Bases: megatron.core.datasets.megatron_dataset.MegatronDataset

The base GPT dataset

Parameters:
  • indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the GPTDataset

  • dataset_path (Optional[str]) – The real path on disk to the dataset, for bookkeeping

  • indexed_indices (numpy.ndarray) – The set of document indices to expose

  • num_samples (Optional[int]) – The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.

  • index_split (Split) – The indexed_indices Split

  • config (GPTDatasetConfig) – The config

Initialization
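
In practice a GPTDataset is rarely constructed by hand; it is usually produced by a dataset builder. The sketch below assumes the BlendedMegatronDatasetBuilder interface (dataset class, per-split sample counts, an is_built_on_rank callable, and a config), which may differ between Megatron Core versions.

from megatron.core.datasets.blended_megatron_dataset_builder import (
    BlendedMegatronDatasetBuilder,
)
from megatron.core.datasets.gpt_dataset import GPTDataset

# Build train/valid/test GPTDatasets from a GPTDatasetConfig such as the one
# sketched earlier; the builder arguments are assumptions about its signature.
train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
    GPTDataset,              # dataset class to instantiate for each split
    [10000, 1000, 1000],     # requested number of train/valid/test samples
    lambda: True,            # is_built_on_rank: build on every rank in this sketch
    config,
).build()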

static numel_low_level_dataset(
low_level_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset,
) int#

Abstract method implementation

For GPT, the underlying IndexedDataset should be split by sequence, as opposed to, say, BERT, which should be split by document

Parameters:

low_level_dataset (IndexedDataset) – The underlying IndexedDataset

Returns:

The number of unique elements in the underlying IndexedDataset

Return type:

int
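
As a rough illustration of what "split by sequence" means here, the element count corresponds to the number of sequences in the IndexedDataset. The sequence_lengths attribute used below is an assumption about the IndexedDataset API, so this is a sketch rather than the actual implementation.

def numel_low_level_dataset(low_level_dataset):
    # One element per sequence in the underlying IndexedDataset.
    return low_level_dataset.sequence_lengths.shape[0]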

static build_low_level_dataset(
dataset_path: str,
config: core.datasets.gpt_dataset.GPTDatasetConfig,
) megatron.core.datasets.indexed_dataset.IndexedDataset#

Abstract method implementation

Parameters:
  • dataset_path (str) – The real path prefix to the IndexedDataset .bin and .idx files

  • config (GPTDatasetConfig) – The config

Returns:

The underlying IndexedDataset

Return type:

IndexedDataset

__len__() int#

Abstract method implementation

Returns:

The length of the dataset

Return type:

int

__getitem__(
idx: Optional[int],
) Dict[str, torch.Tensor]#

Abstract method implementation

Parameters:

idx (Optional[int]) – The index into the dataset

Returns:

The sample information wrapped in a dictionary

Return type:

Dict[str, torch.Tensor]
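
The sketch below shows how a returned sample might be consumed, assuming a dataset built as in the earlier builder sketch. The dictionary keys used here (tokens, labels, loss_mask, position_ids, and attention_mask when create_attention_mask=True) are assumptions about the current implementation.

sample = train_ds[0]                     # Dict[str, torch.Tensor]
tokens = sample["tokens"]                # input token IDs, length sequence_length
labels = sample["labels"]                # targets: the input tokens shifted by one
loss_mask = sample["loss_mask"]          # 1.0 where the loss should be computed
position_ids = sample["position_ids"]    # position IDs, possibly reset at EOD tokens
if "attention_mask" in sample:           # present only when create_attention_mask=True
    attention_mask = sample["attention_mask"]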

_query_document_sample_shuffle_indices(
idx: int,
) Tuple[numpy.ndarray, numpy.ndarray]#

Get the text (token ids) and document ids for a given index

Parameters:

idx (int) – The index into the dataset

Returns:

The text ids and document ids

Return type:

Tuple[numpy.ndarray, numpy.ndarray]

_build_document_sample_shuffle_indices() Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]#

Build the document index, the sample index, and the shuffle index

  • The document index: a 1-D ordered array of document ids

  • The sample index: a 2-D array of document indices and offsets that mark the start of every sample

  • The shuffle index: a 1-D random permutation of the index range of the sample index

Returns:

The document index, the sample index, and the shuffle index

Return type:

Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
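
The following is a conceptual sketch, not the actual implementation, of how a sample index is resolved through the three indices described above when a sample is fetched.

def lookup(idx, document_index, sample_index, shuffle_index):
    # 1. The shuffle index maps the requested sample to a row of the sample index.
    j = shuffle_index[idx]
    # 2. The sample index stores, for each sample, the document-index position and
    #    token offset where the sample starts; the next row marks where it ends.
    doc_pos_start, offset_start = sample_index[j]
    doc_pos_end, offset_end = sample_index[j + 1]
    # 3. The document index maps those positions to document ids in the IndexedDataset.
    document_ids = document_index[doc_pos_start : doc_pos_end + 1]
    return document_ids, offset_start, offset_end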

_get_num_tokens_per_epoch() int#

Calculate the number of tokens in a single epoch

Returns:

The number of tokens in a single epoch

Return type:

int

_get_num_epochs(num_tokens_per_epoch: int) int#

Calculate the number of epochs

Parameters:

num_tokens_per_epoch (int) – The number of tokens in a single epoch

Returns:

The number of epochs

Return type:

int
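
A hedged sketch of the epoch calculation: whole epochs of tokens are added until there are enough tokens to draw num_samples sequences of sequence_length, leaving room for the optional extra token. This mirrors the documented behaviour rather than the literal code.

def get_num_epochs(num_tokens_per_epoch, num_samples, sequence_length, add_extra_token=1):
    num_epochs = 0
    num_tokens = 0
    while True:
        num_epochs += 1
        num_tokens += num_tokens_per_epoch
        # Enough whole sequences, keeping one token in reserve when requested?
        if (num_tokens - add_extra_token) // sequence_length >= num_samples:
            return num_epochs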

core.datasets.gpt_dataset._build_document_index(
documents: numpy.ndarray,
num_epochs: int,
numpy_random_state: numpy.random.RandomState,
separate_final_epoch: bool,
) numpy.ndarray#

Build an array with length = num epochs * num documents

Parameters:
  • documents (numpy.ndarray) – The subset of exposed document indices

  • num_epochs (int) – The number of epochs

  • numpy_random_state (numpy.random.RandomState) – The NumPy random state

  • separate_final_epoch (bool) – Whether to exclude the last epoch from the global shuffle

Returns:

The document index

Return type:

numpy.ndarray
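
A hedged sketch of the behaviour described above: the exposed document ids are repeated once per epoch and shuffled, optionally keeping the final epoch out of the global shuffle so its documents stay grouped at the end of the index.

import numpy as np

def build_document_index(documents, num_epochs, numpy_random_state, separate_final_epoch):
    if not separate_final_epoch or num_epochs == 1:
        document_index = np.tile(documents, num_epochs).astype(np.int32)
        numpy_random_state.shuffle(document_index)
        return document_index
    # Shuffle the first num_epochs - 1 epochs and the final epoch independently.
    head = np.tile(documents, num_epochs - 1).astype(np.int32)
    tail = documents.astype(np.int32)
    numpy_random_state.shuffle(head)
    numpy_random_state.shuffle(tail)
    return np.concatenate([head, tail])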

core.datasets.gpt_dataset._build_shuffle_index(
num_samples: int,
total_size: int,
numpy_random_state: numpy.random.RandomState,
) numpy.ndarray#

Build the range [0, size) and shuffle

Parameters:
  • num_samples (int) – The size of the first shuffle range [0, num_samples)

  • total_size (int) – The size of the entire index. If larger than ‘num_samples’, it defines the second shuffle range [num_samples, total_size)

  • numpy_random_state (numpy.random.RandomState) – The NumPy random state

Returns:

The shuffle index

Return type:

numpy.ndarray
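
A hedged sketch of the two-range shuffle described above.

import numpy as np

def build_shuffle_index(num_samples, total_size, numpy_random_state):
    # Shuffle the primary range [0, num_samples).
    shuffle_first = np.arange(num_samples, dtype=np.int64)
    numpy_random_state.shuffle(shuffle_first)
    if total_size == num_samples:
        return shuffle_first
    # Shuffle the second range [num_samples, total_size) and append it.
    shuffle_second = np.arange(num_samples, total_size, dtype=np.int64)
    numpy_random_state.shuffle(shuffle_second)
    return np.concatenate((shuffle_first, shuffle_second))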

core.datasets.gpt_dataset._get_ltor_masks_and_position_ids(
data: torch.Tensor,
eod_token: int,
reset_position_ids: bool,
reset_attention_mask: bool,
eod_mask_loss: bool,
create_attention_mask: bool,
)#

Build masks and position IDs for a left-to-right model.

Parameters:
  • data (torch.Tensor) – The data tensor that holds the tokens from the dataset

  • eod_token (int) – ID of the token that is considered the EOD

  • reset_position_ids (bool) – Switch to reset the document position IDs

  • reset_attention_mask (bool) – Switch to reset the attention mask

  • eod_mask_loss (bool) – Switch to enable the EOD mask loss

  • create_attention_mask (bool) – Switch to enable attention mask generation. Can be disabled if the attention kernel generates masks by itself.

Returns:

torch.Tensor: The attention mask needed for attention computation

torch.Tensor: The loss mask used during training

torch.Tensor: The position IDs of the tokens

Return type:

torch.Tensor
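
A hedged usage sketch of the helper above. The assumptions here are that it accepts a single 1-D token sequence and returns (attention_mask, loss_mask, position_ids) in that order; the token values and EOD id are placeholders.

import torch

from megatron.core.datasets.gpt_dataset import _get_ltor_masks_and_position_ids

tokens = torch.randint(0, 32000, (4096,))   # placeholder sequence of token IDs
eod_token = 0                                # placeholder EOD token ID

attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
    tokens,
    eod_token,
    reset_position_ids=False,
    reset_attention_mask=False,
    eod_mask_loss=False,
    create_attention_mask=True,   # when False, the attention mask may not be produced
)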

class core.datasets.gpt_dataset.MockGPTLowLevelDataset(
tokenizer: megatron.core.tokenizers.MegatronTokenizerBase,
)#

The mock GPT low level dataset

This class is meant to generate tokenized data in the classic “Megatron-LM” GPT style. Notably, we add the end-of-document token to each element indexed in __getitem__

Parameters:
  • tokenizer (MegatronTokenizerBase) – The tokenizer whose special token information we use to augment the mock data

Initialization

seed: int#

0

The hard-coded random seed used to seed the NumPy RNG

size: int#

100000

The hard-coded number of samples to generate

max_sequence_length: int#

4096

The hard-coded max sequence length to generate

__len__() int#
__getitem__(idx: int) numpy.number#
get(
idx: int,
offset: int = 0,
length: Optional[int] = None,
) numpy.ndarray#

This function is an abstraction over __getitem__ with support for slicing

Parameters:
  • idx (int) – The index into the dataset

  • offset (int) – The integer token offset in the sequence

  • length (Optional[int]) – The number of tokens to grab from the sequence

Returns:

The sequence tokens at the index

Return type:

numpy.ndarray
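
A hedged usage sketch of the mock low-level dataset; the tokenizer is a placeholder for any MegatronTokenizerBase-compatible tokenizer built elsewhere.

from megatron.core.datasets.gpt_dataset import MockGPTLowLevelDataset

dataset = MockGPTLowLevelDataset(my_tokenizer)   # placeholder tokenizer
print(len(dataset))                              # 100000, the hard-coded size
full_sequence = dataset.get(0)                   # all mock tokens for index 0
window = dataset.get(0, offset=10, length=32)    # 32 tokens starting at offset 10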

class core.datasets.gpt_dataset.MockGPTDataset(
dataset: core.datasets.gpt_dataset.MockGPTLowLevelDataset,
dataset_path: Optional[str],
indices: numpy.ndarray,
num_samples: int,
index_split: megatron.core.datasets.utils.Split,
config: core.datasets.gpt_dataset.GPTDatasetConfig,
)#

Bases: core.datasets.gpt_dataset.GPTDataset

The mock GPT dataset

Parameters:
  • dataset (MockGPTLowLevelDataset) – The MockGPTLowLevelDataset around which to build the MockGPTDataset

  • dataset_path (Optional[str]) – This argument is of no consequence for the MockGPTDataset

  • indices (numpy.ndarray) – The set of the dataset indices to expose

  • num_samples (int) – The number of samples to draw from the dataset

  • index_split (Split) – The indices Split

  • config (GPTDatasetConfig) – The config

Initialization

static numel_low_level_dataset(
low_level_dataset: core.datasets.gpt_dataset.MockGPTLowLevelDataset,
) int#

Abstract method implementation

Parameters:

low_level_dataset (MockGPTLowLevelDataset) – The underlying MockGPTLowLevelDataset

Returns:

The number of unique elements in the underlying MockGPTLowLevelDataset

Return type:

int

static build_low_level_dataset(
dataset_path: Optional[str],
config: core.datasets.gpt_dataset.GPTDatasetConfig,
) core.datasets.gpt_dataset.MockGPTLowLevelDataset#

Abstract method implementation

Parameters:
  • dataset_path (Optional[str]) – This argument is of no consequence for the MockGPTLowLevelDataset

  • config (GPTDatasetConfig) – The config

Returns:

The underlying MockGPTLowLevelDataset

Return type:

MockGPTLowLevelDataset
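
To obtain mock data end to end, the dataset class passed to the builder can be swapped for MockGPTDataset, reusing the earlier builder sketch; whether the config additionally needs blend=None or a dedicated mock flag depends on the Megatron Core version and is left open here.

from megatron.core.datasets.blended_megatron_dataset_builder import (
    BlendedMegatronDatasetBuilder,
)
from megatron.core.datasets.gpt_dataset import MockGPTDataset

train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
    MockGPTDataset,       # mock dataset class; no .bin/.idx files are read
    [1000, 100, 100],     # requested train/valid/test sample counts
    lambda: True,         # is_built_on_rank
    config,               # a GPTDatasetConfig, as in the sketch near the top
).build()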