Source code for nemo_rl.data.hf_datasets.oasst
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gzip
import json
import os
import random
import requests
from nemo_rl.data.interfaces import TaskDataSpec
SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
[docs]
def parse_conversations(tree_obj, first=False):
    """Recursively collect every conversation thread rooted at ``tree_obj``.

    Args:
        tree_obj (dict): current conversation node. Either a wrapper that holds
            the message under a "prompt" key, or the message itself (a dict
            with "text" and "role" keys and, optionally, a "replies" list of
            child nodes).
        first (bool): when True, a system turn containing ``SYSTEM_PROMPT`` is
            placed before this node's own turn. Recursive calls never set it,
            so the system turn appears only at the root.

    Returns:
        list[list[dict]]: one list of ``{"content", "role"}`` turns per path
        from this node down to a leaf. ``[[]]`` is returned when ``tree_obj``
        is neither of the recognised node forms.

    Raises:
        ValueError: if the node's role is neither "prompter" nor "assistant".
    """
    if "prompt" in tree_obj:
        prompt_obj = tree_obj["prompt"]
    elif "text" in tree_obj and "role" in tree_obj:
        prompt_obj = tree_obj
    else:
        # Unrecognised node: contribute a single empty thread so the parent's
        # conversation simply ends at the parent.
        return [[]]

    turns = []
    if first:
        turns.append({"content": SYSTEM_PROMPT, "role": "system"})

    if prompt_obj["role"] == "prompter":
        role = "user"
    elif prompt_obj["role"] == "assistant":
        role = "assistant"
    else:
        raise ValueError(f"unknown role {prompt_obj['role']}")
    turns.append({"content": prompt_obj["text"], "role": role})

    # A missing or null "replies" entry is treated as "no replies" (a leaf)
    # instead of raising KeyError/TypeError.
    replies = prompt_obj.get("replies") or []
    sub_threads = []
    for reply in replies:
        sub_threads.extend(parse_conversations(reply))

    if not sub_threads:
        return [turns]
    # Each conversation gets its own copy of the shared prefix so that turn
    # dicts are never aliased between sibling conversations.
    return [copy.deepcopy(turns) + sub_thread for sub_thread in sub_threads]
[docs]
def get_data_records(objs):
    """Flatten conversation trees into single-conversation records.

    Every element of ``objs`` is expanded with
    ``parse_conversations(obj, first=True)``, and each resulting thread that
    has more than two turns is wrapped as ``{"messages": thread}``.

    Args:
        objs: iterable of conversation tree objects.

    Returns:
        list[dict]: one ``{"messages": [...]}`` record per kept thread, in the
        order the threads were produced.
    """
    ## TODO: old format was multi-conversation per example, but ours is single conversation
    ## is this just because of the input data format?

    # The system prompt is always the first turn, so a thread of length <= 2
    # holds at most one real turn; such single-turn threads are dropped.
    return [
        {"messages": thread}
        for tree in objs
        for thread in parse_conversations(tree, first=True)
        if len(thread) > 2
    ]
[docs]
def download_and_process_oasst(output_directory=".", seed=42, split_ratio=0.95):
    """Download the OASST tree dump (if not cached) and build train/validation records.

    The gzipped JSON-lines file is stored in ``output_directory`` and reused on
    later calls. The trees are shuffled with ``seed``, split by ``split_ratio``
    and each part is converted with ``get_data_records``.

    Args:
        output_directory (str): directory holding the cached download; created
            if it does not exist.
        seed (int): seed used for shuffling the trees. NOTE: this reseeds the
            module-level ``random`` generator, as the original code did.
        split_ratio (float): fraction of the trees assigned to the train split.

    Returns:
        dict: ``{"train": [...], "validation": [...]}`` lists of records.

    Raises:
        requests.HTTPError: if the download returns an HTTP error status.
    """
    os.makedirs(output_directory, exist_ok=True)
    filename = os.path.join(output_directory, "2023-04-12_oasst_all.trees.jsonl.gz")

    # only download if doesn't exist
    if not os.path.isfile(filename):
        url = "https://huggingface.co/datasets/OpenAssistant/oasst1/resolve/main/2023-04-12_oasst_all.trees.jsonl.gz"
        response = requests.get(url, timeout=60)
        # Fail loudly on 4xx/5xx instead of caching an error page as the
        # dataset file, which would break every later run at gzip time.
        response.raise_for_status()

        # Write to a temporary name and rename atomically so an interrupted
        # write never leaves a truncated file that looks like a valid cache.
        tmp_filename = f"{filename}.{os.getpid()}.part"
        try:
            with open(tmp_filename, mode="wb") as fw:
                fw.write(response.content)
            os.replace(tmp_filename, filename)
        finally:
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)

    # Stream the file line by line; each line is one JSON-encoded tree.
    with gzip.open(filename) as f:
        all_objs = [json.loads(line.decode("utf-8")) for line in f]

    random.seed(seed)
    random.shuffle(all_objs)

    train_num = int(len(all_objs) * split_ratio)
    return {
        "train": get_data_records(all_objs[:train_num]),
        "validation": get_data_records(all_objs[train_num:]),
    }
[docs]
class OasstDataset:
    """OASST conversations prepared as train/validation record lists."""

    def __init__(self, output_dir: str = ".") -> None:
        """Load (downloading into ``output_dir`` if needed) and process the dataset.

        Args:
            output_dir: directory passed to ``download_and_process_oasst`` as
                the location of the cached download.
        """
        # Task specification carrying only the task name.
        spec = TaskDataSpec(task_name="OASST")
        # Dict with "train" and "validation" lists of {"messages": [...]} records.
        self.formatted_ds = download_and_process_oasst(output_directory=output_dir)
        self.task_spec = spec