# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import load_dataset
def make_squad_dataset(
    tokenizer,
    seq_length=None,
    limit_dataset_samples=None,
    start_of_turn_token=None,
    fp8=False,
    split="train",
    dataset_name="rajpurkar/squad",
):
    """
    Load and preprocess a SQuAD-style QA dataset for model fine-tuning.

    This function loads the requested split of the dataset and formats each example in one of two ways:

    - As a plain "context question answer" prompt-completion pair.
    - As a user/assistant chat via ``tokenizer.apply_chat_template``, when ``tokenizer.chat_template`` is set.

    It then builds ``input_ids``, ``labels`` and ``loss_mask`` for every example and optionally pads/truncates them to a fixed length.

    Args:
        tokenizer: A Hugging Face tokenizer. It is called directly to tokenize
            text. The optional attributes ``eos_token_id``, ``bos_token_id``
            and ``chat_template`` (together with ``apply_chat_template``) are
            read when present.
        seq_length (int, optional): If a positive int, every list in an
            example is truncated to, or padded up to, exactly this length.
            ``input_ids`` and ``labels`` are padded with ``eos_token_id``
            (0 when the tokenizer has none); ``loss_mask`` is padded with 0.
        limit_dataset_samples (int, optional): If set, load only the first
            ``limit_dataset_samples`` examples of the split.
        start_of_turn_token (str or None): Only used with a chat template.
            The text of the marker that opens each chat turn; the loss mask
            starts right after its second occurrence. When not a str, every
            token of the chat-formatted example contributes to the loss.
        fp8 (bool): Currently unused; kept for interface compatibility.
        split (str): Which split of the dataset to load (e.g. ``"train"``,
            ``"validation"``).
        dataset_name (str): Hugging Face dataset identifier
            (default ``"rajpurkar/squad"``). Examples must provide
            ``context``, ``question`` and ``answers["text"]``. The columns
            ``id``, ``title``, ``context``, ``question`` and ``answers`` are
            removed from the output.

    Returns:
        A Hugging Face ``Dataset`` where each example is a dict with keys:
        - ``input_ids``: token IDs of the prompt followed by the answer.
        - ``labels``: ``input_ids`` shifted left by one, with the EOS id
          appended (or the last input id if the tokenizer has no EOS id).
        - ``loss_mask``: 0/1 flags, aligned with ``input_ids``, marking the
          positions that contribute to the loss (the answer/response part).

    Raises:
        IndexError: If an example has an empty ``answers["text"]`` list
            (as for unanswerable questions in SQuAD v2-style data).
        ValueError: With a chat template and ``start_of_turn_token``, if the
            marker token does not occur at least twice in the formatted
            example (raised by ``list.index``).
    """
    eos_id = getattr(tokenizer, "eos_token_id", None)
    bos_id = getattr(tokenizer, "bos_token_id", None)
    chat_template = getattr(tokenizer, "chat_template", None)
    # Fill value for input_ids/labels padding; never None, so padded lists
    # contain only integer ids even when the tokenizer has no EOS id.
    pad_id = eos_id if eos_id is not None else 0

    def make_labels(input_ids):
        """Return next-token targets: input_ids shifted left plus a final id."""
        # Compare against None instead of relying on truthiness: an EOS id of
        # 0 is valid and must not be replaced by the last input token.
        last = eos_id if eos_id is not None else input_ids[-1]
        return input_ids[1:] + [last]

    def pad_to_seq_length(sample):
        """Truncate/pad every list in `sample` to exactly `seq_length` items."""
        fixed = {}
        for key, values in sample.items():
            # Pad loss_mask with 0 so padded positions are never flagged for the loss.
            fill = 0 if key == "loss_mask" else pad_id
            values = values[:seq_length]
            fixed[key] = values + [fill] * (seq_length - len(values))
        return fixed

    def formatting_prompts_func(example):
        """Plain prompt-completion format: "<context> <question> " + answer."""
        context_ids = tokenizer(f"{example['context']} {example['question']} ")["input_ids"]
        answer_ids = tokenizer(example["answers"]["text"][0].strip())["input_ids"]
        # The two strings are tokenized separately. Before concatenating them,
        # drop a trailing EOS from the context and a leading BOS from the answer.
        if len(context_ids) > 0 and context_ids[-1] == eos_id:
            context_ids = context_ids[:-1]
        if len(answer_ids) > 0 and answer_ids[0] == bos_id:
            answer_ids = answer_ids[1:]
        input_ids = context_ids + answer_ids
        return dict(
            input_ids=input_ids,
            labels=make_labels(input_ids),
            # NOTE(review): loss_mask is aligned with input_ids, while labels
            # are shifted by one. If the loss applies loss_mask position-wise
            # to `labels`, the target at the last context position (the first
            # answer token) is excluded. Confirm against the loss function.
            loss_mask=[0] * len(context_ids) + [1] * len(answer_ids),
        )

    def formatting_prompts_func_with_chat_template(example, start_of_turn_token_id=None):
        """Chat format: user turn = context + question, assistant turn = answer."""
        messages = [
            {"role": "user", "content": f"{example['context']} {example['question']}"},
            {"role": "assistant", "content": example["answers"]["text"][0].strip()},
        ]
        input_ids = tokenizer.apply_chat_template(messages)
        response_start = 0
        if start_of_turn_token_id is not None:
            # The loss starts right after the second start-of-turn marker.
            # NOTE(review): this assumes the template emits the user turn first,
            # with no extra system turn, so that the second marker opens the
            # assistant turn.
            first = input_ids.index(start_of_turn_token_id)
            response_start = input_ids.index(start_of_turn_token_id, first + 1) + 1
        loss_mask = [0] * response_start + [1] * (len(input_ids) - response_start)
        return dict(
            input_ids=input_ids,
            labels=make_labels(input_ids),
            loss_mask=loss_mask,
        )

    if limit_dataset_samples is not None:
        assert isinstance(limit_dataset_samples, int), "Expected limit_dataset_samples to be an int"
        # `datasets` split-slicing syntax, e.g. "train[:100]".
        split = f"{split}[:{limit_dataset_samples}]"
    dataset = load_dataset(dataset_name, split=split)

    # Resolve the turn marker to a token id once, not once per example.
    start_of_turn_token_id = None
    if chat_template is not None and isinstance(start_of_turn_token, str):
        start_of_turn_token_id = tokenizer(start_of_turn_token, add_special_tokens=False)["input_ids"][0]

    # A non-positive seq_length never padded anything before, so leave samples as-is.
    fix_length = isinstance(seq_length, int) and seq_length > 0

    def fmt_fn(example):
        if chat_template is not None:
            sample = formatting_prompts_func_with_chat_template(example, start_of_turn_token_id)
        else:
            sample = formatting_prompts_func(example)
        return pad_to_seq_length(sample) if fix_length else sample

    return dataset.map(
        fmt_fn,
        batched=False,
        remove_columns=["id", "title", "context", "question", "answers"],
    )