# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
from datasets import Dataset
from torch.nn import functional as F
from tqdm import tqdm
logger = logging.getLogger(__name__)
CROSS_ENTROPY_IGNORE_IDX = -100
PACK_TYPE = dict[str, torch.Tensor | list[int]]
# based on https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/datasets/_packed.py#L17
class PackedSequence:
"""
Implements sequence packing for an input dataset.
Args:
dataset: The dataset to pack (e.g., the 'train', 'val' or 'test' split)
split (str): Name of the split being packed ('train', 'val' or 'test'); used for progress reporting
packed_sequence_size (int): Number of tokens in a pack
split_across_pack (bool): If True and the last sample in a pack does not fit in
``packed_sequence_size``, split the sample across this pack and the next one;
if False, move it entirely to the beginning of the next pack. Default: False
max_packs (int): Maximum number of packs. Default: None
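
Example (illustrative sketch; ``tokenized_ds`` below is a hypothetical pre-tokenized
dataset whose samples contain ``input_ids`` and ``labels``)::

    from datasets import Dataset

    tokenized_ds = Dataset.from_dict({
        "input_ids": [[1, 2, 3], [4, 5], [6]],
        "labels": [[1, 2, 3], [4, 5], [6]],
    })
    packer = PackedSequence(tokenized_ds, split="train", packed_sequence_size=8)
    packed_ds = packer.pack()
    # Every packed sample is padded to exactly 8 tokens.
    assert all(len(ids) == 8 for ids in packed_ds["input_ids"])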
"""
def __init__(self, dataset, split, packed_sequence_size, split_across_pack=False, max_packs=None):
"""
Packed Sequence constructor.
Given the dataset and the rest of the arguments, it will create (using the ``.pack()`` method)
another dataset containing packed sequences.
Args:
dataset: The dataset to pack (e.g., the 'train', 'val' or 'test' split)
split (str): Name of the split being packed ('train', 'val' or 'test'); used for progress reporting
packed_sequence_size (int): Number of tokens in a pack
split_across_pack (bool): If True and the last sample in a pack does not fit in
``packed_sequence_size``, split the sample across this pack and the next one;
if False, move it entirely to the beginning of the next pack. Default: False
max_packs (int): Maximum number of packs. Default: None
"""
self.dataset = dataset
self.split = split
self.padding_idx = 0 # Padding value to pack a sequence to self.packed_sequence_size
self.contains_loss_mask = False
self.packed_sequence_size = packed_sequence_size
self.split_across_pack = split_across_pack
self.max_packs = max_packs
self.packs: list[PACK_TYPE] = []
def pack(self):
"""
Pack the dataset to defined length.
In particular, it iterates through the dataset, using a buffer to hold samples until
``packed_sequence_size`` tokens are accumulated, then appends the buffer to ``self.packs``
as a single "packed" sample. This continues until ``max_packs`` is reached or the dataset
is exhausted.
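
Example (illustrative sketch; ``packer`` is the hypothetical, not-yet-packed instance
from the class-level example)::

    packed_ds = packer.pack()
    sample = packed_ds[0]
    # Each pack is exactly packed_sequence_size tokens long, and the per-sequence
    # lengths (including the trailing padding "sequence") sum to that size.
    assert len(sample["input_ids"]) == packer.packed_sequence_size
    assert sum(sample["seq_lens"]) == packer.packed_sequence_size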
"""
# Only show progress bar on rank 0
rank = (
torch.distributed.get_rank()
if torch.distributed.is_available() and torch.distributed.is_initialized()
else 0
)
if "loss_mask" in self.dataset[0]:
self.contains_loss_mask = True
# Buffer to hold samples until they are long enough to be added to self.packs
current_pack = {
"input_ids": [],
"labels": [],
"position_ids": [],
"seq_lens": [],
}
if self.contains_loss_mask:
current_pack["loss_mask"] = []
self.previous_sample_boundary: int = 0
if rank == 0:
pbar = tqdm(total=len(self.dataset), desc=f"Packing {self.split} dataset", dynamic_ncols=True)
for sample in self.dataset:
input_ids, labels = sample["input_ids"], sample["labels"]
if self.contains_loss_mask:
loss_mask = sample["loss_mask"]
# If the dataset outputs samples that are larger than the specified
# packed_sequence_size and we're unable to split them, the user needs to modify
# one of the two parameters
seq_len = len(input_ids)
if seq_len > self.packed_sequence_size and not self.split_across_pack:
raise ValueError(
f"Dataset sample is too long ({seq_len} > {self.packed_sequence_size}). "
"Please set `split_across_pack=True` or increase `packed_sequence_size`.",
)
# Update the current pack
# "position_ids" is the pos ids, "seq_lens" is the len of each seq within the pack
current_pack["input_ids"] += input_ids
current_pack["labels"] += labels
current_pack["position_ids"] += [x % self.packed_sequence_size for x in range(seq_len)]
current_pack["seq_lens"] += [seq_len]
if self.contains_loss_mask:
current_pack["loss_mask"] += loss_mask
# If the current pack is over the packed_sequence_size, add it to self.packs and
# retain any truncated or bumped samples for next pack
while len(current_pack["input_ids"]) > self.packed_sequence_size and not self._should_stop_packing():
current_pack = self._split_and_add_pack(current_pack)
if rank == 0:
pbar.update()
# Keep track of previous sample boundary
self.previous_sample_boundary = len(current_pack["input_ids"])
if self._should_stop_packing():
break
# Handle the last pack if there's leftover and we haven't filled up the max packs
if len(current_pack["input_ids"]) > 0 and (self.max_packs is None or len(self.packs) < self.max_packs):
# No need to handle splitting at this point so we can just add the current pack
self._add_pack(current_pack)
# After packing all samples, convert self.packs to a Dataset object
self.packed_dataset = Dataset.from_dict({key: [pack[key] for pack in self.packs]
for key in self.packs[0].keys()})
logger.info(f">>>>> Total number of packs created: {len(self.packs)} <<<<<")
return self.packed_dataset
def _should_stop_packing(self) -> bool:
"""
If max packs is set, stop packing when we reach that number.
"""
if self.max_packs is not None and len(self.packs) == self.max_packs:
return True
return False
def _split_and_add_pack(self, current_pack: PACK_TYPE) -> PACK_TYPE:
"""
Splits the current pack at the boundary, processes it, adds it to ``self.packs``,
and returns the start of the next pack.
TODO(@akoumparouli): refactor.
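
Example (illustrative walk-through, assuming ``split_across_pack=True`` and
``packed_sequence_size=4``): if ``current_pack`` holds 6 tokens with
``seq_lens == [3, 3]``, the first 4 tokens are emitted as a finished pack with
``seq_lens == [3, 1]`` (the trailing ``1`` keeps ``sum(seq_lens)`` equal to
``packed_sequence_size``), and the remaining 2 tokens are returned as the start
of the next pack with ``seq_lens == [2]``. With ``split_across_pack=False``, the
split instead happens at ``previous_sample_boundary`` and the last sample is
moved entirely into the returned pack.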
"""
if self.split_across_pack:
boundary = self.packed_sequence_size
# The last elem in ``seq_lens`` ensures that ``sum(seq_lens) == self.packed_sequence_size``
leftover_seq_len = self.packed_sequence_size - sum(current_pack["seq_lens"][:-1])
seq_len_padding = [leftover_seq_len] if leftover_seq_len > 0 else []
else:
boundary = self.previous_sample_boundary
# If we aren't splitting across packs, we leave out the last sample b/c
# it will go into the next pack
seq_len_padding = []
pack = {
"input_ids": current_pack["input_ids"][:boundary],
"labels": current_pack["labels"][:boundary],
"position_ids": current_pack["position_ids"][:boundary],
"seq_lens": current_pack["seq_lens"][:-1] + seq_len_padding,
}
if self.contains_loss_mask:
pack["loss_mask"] = current_pack["loss_mask"][:boundary]
# Process and add the pack
self._add_pack(pack)
# Return the length of the first sample in next pack if we are splitting across packs,
# otherwise return the length of the last sample in the current pack
next_seq_len = (
len(current_pack["input_ids"][boundary:]) if self.split_across_pack else current_pack["seq_lens"][-1]
)
output_dict = {
"input_ids": current_pack["input_ids"][boundary:],
"labels": current_pack["labels"][boundary:],
"position_ids": current_pack["position_ids"][boundary:],
"seq_lens": [next_seq_len],
}
if self.contains_loss_mask:
output_dict["loss_mask"] = current_pack["loss_mask"][boundary:]
return output_dict
def _add_pack(self, pack: PACK_TYPE) -> None:
"""
Processes, pads and adds a pack to ``self.packs``.
"""
pack = self._convert_to_tensors(pack)
pack = self._pad_pack(pack, padding_idx=self.padding_idx)
self.packs.append(pack)
def _convert_to_tensors(self, pack: PACK_TYPE) -> PACK_TYPE:
"""
Converts a pack from a dict of lists into a dict of tensors.
"""
tensor_pack = {
"input_ids": torch.tensor(pack["input_ids"], dtype=torch.long),
"labels": torch.tensor(pack["labels"], dtype=torch.long),
"position_ids": torch.tensor(pack["position_ids"], dtype=torch.long),
"seq_lens": torch.tensor(pack["seq_lens"], dtype=torch.long),
}
if self.contains_loss_mask:
tensor_pack["loss_mask"] = torch.tensor(pack["loss_mask"], dtype=torch.long)
return tensor_pack
def _pad_pack(self, pack: PACK_TYPE, padding_idx: int) -> PACK_TYPE:
"""
Pads a pack to ``self.packed_sequence_size``.
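
Example (illustrative walk-through, assuming ``packed_sequence_size=8`` and a pack
with 6 tokens, ``seq_lens == [3, 3]`` and ``position_ids == [0, 1, 2, 0, 1, 2]``):
two padding tokens are appended to ``input_ids``, ``labels`` is padded with
``CROSS_ENTROPY_IGNORE_IDX``, ``seq_lens`` becomes ``[3, 3, 2]`` so it still sums
to 8, and ``position_ids`` continues from its last value to give
``[0, 1, 2, 0, 1, 2, 3, 4]``.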
"""
# Pad tokens
num_padding_tokens = self.packed_sequence_size - len(pack["input_ids"])
padded_tokens = F.pad(
pack["input_ids"],
(0, num_padding_tokens),
value=padding_idx,
)
# Pad labels
padded_labels = F.pad(
pack["labels"],
(0, self.packed_sequence_size - len(pack["labels"])),
value=CROSS_ENTROPY_IGNORE_IDX,
)
# Pad loss_mask
if self.contains_loss_mask:
padded_loss_mask = F.pad(
pack["loss_mask"],
(0, self.packed_sequence_size - len(pack["loss_mask"])),
value=0,
)
# Add padding tokens as a last seq len to ensure sum is packed_sequence_size
padded_seq_lens = (
torch.cat([pack["seq_lens"], torch.tensor([num_padding_tokens])])
if num_padding_tokens > 0
else pack["seq_lens"]
)
# Pad position_ids continuing the sequence from last value
# in position_ids
# e.g. [0 1 2] -> [0 1 2 3 4 5] for self.packed_sequence_size = 6
num_range = torch.arange(
pack["position_ids"][-1] + 1,
pack["position_ids"][-1] + self.packed_sequence_size - len(pack["position_ids"]) + 1,
)
# Clamp to packed_sequence_size - 1 to avoid out of bounds error
clamped_num_range = torch.clamp(num_range, 0, self.packed_sequence_size - 1)
padded_position_ids = torch.cat([pack["position_ids"], clamped_num_range])
padded_pack = {
"input_ids": padded_tokens,
"labels": padded_labels,
"position_ids": padded_position_ids,
"seq_lens": padded_seq_lens,
}
if self.contains_loss_mask:
padded_pack["loss_mask"] = padded_loss_mask
return padded_pack
def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
"""
Creates a block causal mask for the specified sequence lengths.
In particular, given a batch of seq_lens tensors defining the lengths of samples in each pack,
it constructs a 2D block causal mask for each pack in the batch. For example, if
a single sample's seq_lens is [3, 2, 1], the mask would be::
mask = [
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1],
]
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
Tensor: Block causal mask of shape (batch_size, 1, packed_sequence_size, packed_sequence_size).
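
Example (illustrative sketch)::

    import torch

    seq_lens = [torch.tensor([3, 2, 1]), torch.tensor([4, 2])]
    mask = create_block_causal_mask(seq_lens)
    assert mask.shape == (2, 1, 6, 6)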
"""
batch_block_attn_masks = []
batch_size = len(seq_lens)
for sample_idx in range(batch_size):
block_attn_masks = [
torch.tril(
torch.ones(
seq_len,
seq_len,
dtype=torch.bool,
),
)
for seq_len in seq_lens[sample_idx]
]
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
# Transformers expects the attn_mask to be 4d [bs, 1, packed_sequence_size, packed_sequence_size], hence adding
# singleton (size 1) dimension at position 1.
return torch.stack(batch_block_attn_masks).unsqueeze(1)
def packed_block_causal_mask(seq_lens: list[torch.Tensor]):
"""
Create a 2D block causal document mask for a batch of packed sequences.
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
Tensor: Block causal mask of shape (batch_size, 1, packed_sequence_size, packed_sequence_size).
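
Example (illustrative sketch; equivalent to calling ``create_block_causal_mask`` directly)::

    import torch

    mask = packed_block_causal_mask([torch.tensor([3, 3])])
    assert mask.shape == (1, 1, 6, 6)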
"""
return create_block_causal_mask(seq_lens=seq_lens)