Source code for nemo_automodel.datasets.llm.hellaswag
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import load_dataset
from nemo_automodel.datasets.utils import SFTSingleTurnPreprocessor
class HellaSwag:
    """A dataset wrapper for the HellaSwag benchmark, tailored for single-turn supervised fine-tuning (SFT).

    This class loads and preprocesses the HellaSwag dataset using a tokenizer and a custom preprocessing
    pipeline for language-model fine-tuning. Each example consists of a context and multiple candidate
    endings, where the goal is to choose the most plausible continuation.

    Attributes:
        dataset (Dataset): The processed dataset, ready for model training or evaluation.
    """
    def __init__(self, path_or_dataset, tokenizer, split="train", num_samples_limit=None, trust_remote_code=True):
        """Initialize the HellaSwag dataset wrapper.

        Args:
            path_or_dataset (str or Dataset): Path to the dataset or a HuggingFace Dataset object.
            tokenizer (PreTrainedTokenizer): The tokenizer used to process text.
            split (str, optional): Dataset split to use (e.g., 'train', 'validation'). Defaults to 'train'.
            num_samples_limit (int, optional): Maximum number of samples to load. Defaults to None.
            trust_remote_code (bool, optional): Whether to trust remote code. Defaults to True.

        Notes:
            If num_samples_limit is an integer, it limits the dataset size using HuggingFace split slicing.
        """
        if isinstance(num_samples_limit, int):
            split = f"{split}[:{num_samples_limit}]"
        raw_datasets = load_dataset(path_or_dataset, split=split, trust_remote_code=trust_remote_code)
        processor = SFTSingleTurnPreprocessor(tokenizer)
        self.dataset = processor.process(raw_datasets, self)
    def get_context(self, examples):
        """Extracts the context part of each example.

        Args:
            examples (dict): A dictionary containing example data with a "ctx" key.

        Returns:
            list[str]: List of context strings.
        """
        return examples["ctx"]
    def get_target(self, examples):
        """Extracts the correct ending based on the label.

        Args:
            examples (dict): A dictionary with "endings" (lists of candidate strings) and "label"
                (index of the correct ending).

        Returns:
            list[str]: The gold target strings selected by the label indices.
        """
        return [
            endings[int(lbl)]
            for endings, lbl in zip(examples["endings"], examples["label"], strict=False)
        ]
    def __getitem__(self, index):
        """Get a processed example by index.

        Args:
            index (int): Index of the example.

        Returns:
            dict: A tokenized and preprocessed example, with the "attention_mask"
                field removed if present.
        """
        ans = self.dataset[index]
        ans.pop("attention_mask", None)
        return ans
    def __len__(self):
        """Get the number of examples in the dataset.

        Returns:
            int: Length of the processed dataset.
        """
        return len(self.dataset)
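
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# the HuggingFace Hub dataset id "Rowan/hellaswag" and a transformers tokenizer;
# both are assumptions here, so adjust them to your environment.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical tokenizer choice
    ds = HellaSwag("Rowan/hellaswag", tokenizer, split="validation", num_samples_limit=100)
    print(len(ds))       # number of processed examples (at most 100 here)
    print(ds[0].keys())  # tokenized fields; "attention_mask" has been removed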