core.datasets.t5_dataset#
Module Contents#
Classes#
T5MaskedWordPieceDatasetConfig: Configuration object for Megatron Core T5 WordPiece datasets
T5MaskedWordPieceDataset: The T5 dataset that assumes WordPiece tokenization
API#
- class core.datasets.t5_dataset.T5MaskedWordPieceDatasetConfig#
Bases: megatron.core.datasets.masked_dataset.MaskedWordPieceDatasetConfig

Configuration object for Megatron Core T5 WordPiece datasets
NB: This is a temporary holdover from Megatron-LM. The T5 tokenizer has an attribute that defines a number of special sentinel tokens used during sampling. The assert in __post_init__ serves to preserve compatibility with Megatron-LM until the T5 tokenizer is in Megatron Core.
- sequence_length_encoder: Optional[int]#
‘field(…)’
A sequence_length alias and the sequence length for the encoder
- sequence_length_decoder: int#
None
The sequence length for the decoder
- __post_init__() → None#
Perform the asserts and set the derived fields after initialization
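The alias-plus-assert pattern described above can be sketched with a minimal dataclass. This is illustrative only: the field names mirror the documented attributes, but the real config inherits many more fields from MaskedWordPieceDatasetConfig, and the exact assert is an assumption.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class T5ConfigSketch:
    # Hypothetical minimal subset of the documented config fields
    sequence_length: int
    sequence_length_decoder: int
    sequence_length_encoder: Optional[int] = None

    def __post_init__(self):
        # sequence_length_encoder is documented as an alias of sequence_length:
        # fill it in when unset, then assert the two agree.
        if self.sequence_length_encoder is None:
            self.sequence_length_encoder = self.sequence_length
        assert self.sequence_length_encoder == self.sequence_length
```

With this sketch, `T5ConfigSketch(sequence_length=512, sequence_length_decoder=128)` yields an encoder length of 512 without specifying it twice.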
- class core.datasets.t5_dataset.T5MaskedWordPieceDataset(
- indexed_dataset: megatron.core.datasets.indexed_dataset.IndexedDataset,
- dataset_path: str,
- indexed_indices: numpy.ndarray,
- num_samples: Optional[int],
- index_split: megatron.core.datasets.utils.Split,
- config: core.datasets.t5_dataset.T5MaskedWordPieceDatasetConfig,
Bases: megatron.core.datasets.masked_dataset.MaskedWordPieceDataset

The T5 dataset that assumes WordPiece tokenization
- Parameters:
indexed_dataset (IndexedDataset) – The IndexedDataset around which to build the MegatronDataset
dataset_path (str) – The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray) – The set of document indices to expose
num_samples (Optional[int]) – The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
index_split (Split) – The indexed_indices Split
config (T5MaskedWordPieceDatasetConfig) – The config
Initialization
- static _key_config_attributes() → List[str]#
Inherited method implementation
- Returns:
The key config attributes
- Return type:
List[str]
- static _build_b1ss_attention_mask(
- source_block: torch.tensor,
- target_block: torch.tensor,
- make_history_mask: bool = False,
Build an attention mask of shape (bs, 1, q_len, kv_len) from source_block and target_block
- Parameters:
source_block (torch.tensor) – A 2-D array of tokens (bs, q_len)
target_block (torch.tensor) – A 2-D array of tokens (bs, kv_len)
make_history_mask (bool) – Whether to turn the mask into a causal mask
- Returns:
The 4-D attention mask (bs, 1, q_len, kv_len)
- Return type:
torch.tensor
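The b1ss construction can be sketched as follows. The actual method operates on torch tensors; this sketch uses NumPy for self-containment, and the `pad_id` parameter is an assumption (the real implementation derives padding validity from the dataset's tokenizer).

```python
import numpy as np

def build_b1ss_attention_mask(source_block, target_block,
                              pad_id=0, make_history_mask=False):
    """Sketch of a (bs, 1, q_len, kv_len) mask from two token blocks."""
    src_valid = source_block != pad_id        # (bs, q_len): True at real tokens
    tgt_valid = target_block != pad_id        # (bs, kv_len)
    # Per-batch outer product: position (i, j) is attendable only when both
    # the query token i and the key/value token j are non-padding.
    mask = src_valid[:, :, None] & tgt_valid[:, None, :]   # (bs, q_len, kv_len)
    if make_history_mask:
        # Lower-triangular constraint: query i may only attend keys j <= i.
        q_len, kv_len = source_block.shape[1], target_block.shape[1]
        causal = np.tril(np.ones((q_len, kv_len), dtype=bool))
        mask = mask & causal[None, :, :]
    # Insert the singleton head dimension to get the b1ss layout.
    return mask[:, None, :, :]                # (bs, 1, q_len, kv_len)
```

For a batch with one padded query position, the corresponding row of the mask comes out all False, and `make_history_mask=True` additionally zeroes the upper triangle.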
- static config_attention_mask(
- encoder_tokens: torch.tensor,
- decoder_tokens: torch.tensor,
- encoder_mask: torch.tensor,
- decoder_mask: torch.tensor,
- use_local: bool = False,
- test_te_version: str = None,
Configure the encoder_mask, decoder_mask, and encoder_decoder_mask conditioned on the transformer implementation (e.g. TE vs. local), the TE version, and the TE backend
- Parameters:
encoder_tokens (torch.tensor) – A 2-D array of tokens (bs, kv_len)
decoder_tokens (torch.tensor) – A 2-D array of tokens (bs, q_len)
encoder_mask (torch.tensor) – A 2-D mask over the encoder tokens (bs, kv_len)
decoder_mask (torch.tensor) – A 2-D mask over the decoder tokens (bs, q_len)
use_local (bool) – Whether the current T5 model uses the local (vs. TE) transformer implementation
test_te_version (str) – The Transformer Engine version to test against. Defaults to None.
- Returns:
The configured encoder_mask, decoder_mask, and encoder_decoder_mask:
torch.tensor – configured encoder attention mask
torch.tensor – configured decoder attention mask
torch.tensor – configured encoder-decoder attention mask
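The dispatch pattern this method describes can be sketched as below. Everything here is hypothetical: the branch conditions, the version cutoff, and the mask substitutions are placeholders standing in for Megatron Core's actual TE-version-specific logic, which this page does not spell out.

```python
def configure_masks(encoder_mask, decoder_mask, encoder_decoder_mask,
                    use_local=False, test_te_version=None):
    """Hypothetical sketch of implementation-conditioned mask configuration."""
    if use_local:
        # Local transformer implementation: pass the masks through unchanged
        # (placeholder behavior, not Megatron Core's actual local path).
        return encoder_mask, decoder_mask, encoder_decoder_mask
    # TE path: the mask convention can differ across TE versions, so parse the
    # version string (if given) into a comparable (major, minor) tuple.
    if test_te_version is not None:
        major_minor = tuple(int(x) for x in test_te_version.split(".")[:2])
        if major_minor < (1, 7):
            # Placeholder: pretend older TE derives the cross-attention mask
            # itself, so we drop the explicit encoder-decoder mask.
            encoder_decoder_mask = None
    return encoder_mask, decoder_mask, encoder_decoder_mask
```

The `test_te_version` parameter mirrors the documented signature: it lets tests exercise a version branch without installing that TE version.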
- __getitem__(
- idx: int,
Abstract method implementation
- Parameters:
idx (int) – The index into the dataset
- Returns:
The sample data including encoder input, decoder input/output, and masks.
- Return type:
Dict[str, Union[int, numpy.ndarray]]
- _get_token_mask(numpy_random_state: numpy.random.RandomState) → int#
Abstract method implementation
100% of the time, replace the token id with the mask token id.
- Parameters:
numpy_random_state (RandomState) – The NumPy random state
- Returns:
The mask token id
- Return type:
int
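The always-mask behavior is easiest to see next to the classic BERT-style 80/10/10 scheme it departs from. The contrast below is illustrative: the function names, the `mask_id` value, and the BERT-style percentages are assumptions, not Megatron Core code.

```python
import numpy as np

def t5_get_token_mask(numpy_random_state, mask_id):
    # T5-style: 100% of the time, return the mask token id. The random state
    # is accepted for interface parity but deliberately unused.
    return mask_id

def bert_get_token_mask(numpy_random_state, mask_id, vocab_size, original_id):
    # BERT-style contrast (illustrative): 80% mask token, 10% random token,
    # 10% keep the original token unchanged.
    p = numpy_random_state.random_sample()
    if p < 0.8:
        return mask_id
    if p < 0.9:
        return int(numpy_random_state.randint(vocab_size))
    return original_id
```

Because the T5 variant is deterministic, every selected position in a span corruption scheme becomes the same sentinel/mask token, which is what the sentinel-token machinery mentioned in the config's NB relies on.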