nemo_rl.data.datasets#
Module Contents#
Classes#
AllTaskProcessedDataset | Dataset for processing single or multi-task data with task-specific tokenization and processing.
Functions#
rl_collate_fn | Collate function for RL training.
eval_collate_fn | Collate function for evaluation.
preference_collate_fn | Collate function for preference data training.
dpo_collate_fn | Collate function for DPO training.
assert_no_double_bos | Assert that there are no double starting BOS tokens in the message.
Data#
API#
- nemo_rl.data.datasets.TokenizerType#
None
- class nemo_rl.data.datasets.AllTaskProcessedDataset(
- dataset: datasets.Dataset | Any,
- tokenizer: nemo_rl.data.datasets.TokenizerType,
- default_task_data_spec: nemo_rl.data.interfaces.TaskDataSpec,
- task_data_processors: dict[str, tuple[nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.data.interfaces.TaskDataProcessFnCallable]] | nemo_rl.data.interfaces.TaskDataProcessFnCallable,
- max_seq_length: Optional[int] = None,
Dataset for processing single or multi-task data with task-specific tokenization and processing.
- Parameters:
dataset – Input dataset containing raw data
tokenizer – Tokenizer for text processing
default_task_data_spec – Default task processing specifications. In the case of single-task, this is the spec used for processing all entries. In the case of multi-task, any values not specified in the task-specific specs will be taken from the default spec.
task_data_processors – Either a single TaskDataProcessFnCallable for single-task, or a dict mapping task names to (TaskDataSpec, TaskDataProcessFnCallable) for multi-task
max_seq_length – Maximum sequence length for tokenized outputs
Initialization
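The two accepted shapes of task_data_processors, and the way unset task-specific values fall back to the default spec, can be illustrated with a small standalone sketch. This is not the library's implementation: `Spec`, its field names, `fill_from_default`, and `resolve_processor` are hypothetical stand-ins, and treating `None` as "not specified" is an assumption.

```python
from dataclasses import dataclass, fields, replace
from typing import Callable, Optional, Union


@dataclass
class Spec:
    """Hypothetical stand-in for TaskDataSpec; field names are illustrative."""
    task_name: Optional[str] = None
    prompt: Optional[str] = None
    system_prompt: Optional[str] = None


Processor = Callable[[dict, Spec], dict]


def fill_from_default(spec: Spec, default: Spec) -> Spec:
    """Return a copy of `spec` whose unset (None) fields come from `default`."""
    updates = {
        f.name: getattr(default, f.name)
        for f in fields(spec)
        if getattr(spec, f.name) is None
    }
    return replace(spec, **updates)


def resolve_processor(
    task_name: str,
    default_spec: Spec,
    processors: Union[Processor, dict[str, tuple[Spec, Processor]]],
) -> tuple[Spec, Processor]:
    """Pick the (spec, processor) pair to use for one dataset entry."""
    if callable(processors):
        # Single-task: one processor, always paired with the default spec.
        return default_spec, processors
    # Multi-task: look up by task name, then fill gaps from the default spec.
    spec, fn = processors[task_name]
    return fill_from_default(spec, default_spec), fn
```

In the multi-task case a spec such as `Spec(task_name="math")` would inherit `prompt` and `system_prompt` from the default spec, while the single-task case ignores the task name entirely.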
- __len__() int #
- encode_single(
- text: Union[str, list[str]],
Takes either a single string or a list of strings that represent multiple turns for the same conversation.
Returns a single (concatenated) list of tokenized ids and the length of the tokenized ids.
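The described behavior (one string or several turns in, one concatenated ID list plus its length out) can be sketched as follows. `encode_turns` and `toy_tokenize` are hypothetical names, and the toy tokenizer simply maps each word to its character count; whether the real method adds special tokens between turns is not stated here, so this sketch adds none.

```python
from typing import Callable, Union


def toy_tokenize(text: str) -> list[int]:
    """Toy tokenizer (assumption): one 'token id' per word, equal to its length."""
    return [len(word) for word in text.split()]


def encode_turns(
    text: Union[str, list[str]],
    tokenize: Callable[[str], list[int]],
) -> tuple[list[int], int]:
    """Tokenize one string or several turns and concatenate the ids."""
    turns = [text] if isinstance(text, str) else text
    token_ids: list[int] = []
    for turn in turns:
        token_ids.extend(tokenize(turn))
    return token_ids, len(token_ids)
```

A single string and a one-element list produce the same result, which keeps single-turn and multi-turn callers on the same code path.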
- __getitem__(idx: int) nemo_rl.data.interfaces.DatumSpec #
Return a single prompt.
- nemo_rl.data.datasets.rl_collate_fn(
- data_batch: list[nemo_rl.data.interfaces.DatumSpec],
Collate function for RL training.
- nemo_rl.data.datasets.eval_collate_fn(
- data_batch: list[nemo_rl.data.interfaces.DatumSpec],
Collate function for evaluation.
Takes a list of data samples and combines them into a single batched dictionary for model evaluation.
- Parameters:
data_batch – List of data samples with message_log, extra_env_info, and idx fields.
- Returns:
BatchedDataDict with message_log, extra_env_info, and idx fields.
Examples:
>>> import torch
>>> from nemo_rl.data.datasets import eval_collate_fn
>>> from nemo_rl.data.interfaces import DatumSpec
>>> data_batch = [
...     DatumSpec(
...         message_log=[{"role": "user", "content": "Hello", "token_ids": torch.tensor([1, 2, 3])}],
...         extra_env_info={'ground_truth': '1'},
...         idx=0,
...     ),
...     DatumSpec(
...         message_log=[{"role": "assistant", "content": "Hi there", "token_ids": torch.tensor([4, 5, 6, 7])}],
...         extra_env_info={'ground_truth': '2'},
...         idx=1,
...     ),
... ]
>>> output = eval_collate_fn(data_batch)
>>> output['message_log'][0]
[{'role': 'user', 'content': 'Hello', 'token_ids': tensor([1, 2, 3])}]
>>> output['message_log'][1]
[{'role': 'assistant', 'content': 'Hi there', 'token_ids': tensor([4, 5, 6, 7])}]
>>> output['extra_env_info']
[{'ground_truth': '1'}, {'ground_truth': '2'}]
>>> output['idx']
[0, 1]
- nemo_rl.data.datasets.preference_collate_fn(
- data_batch: list[nemo_rl.data.interfaces.DPODatumSpec],
Collate function for preference data training.
This function separates the chosen and rejected responses to create two examples per prompt. The chosen and rejected examples are interleaved along the batch dimension, resulting in a batch size of 2 * len(data_batch).
- Parameters:
data_batch – List of data samples with message_log_chosen, message_log_rejected, length_chosen, length_rejected, loss_multiplier, idx, and task_name fields.
- Returns:
BatchedDataDict with message_log, length, loss_multiplier, task_name, and idx fields.
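The interleaving described above can be sketched with plain dicts and lists standing in for DPODatumSpec and BatchedDataDict. `interleave_preferences` is a hypothetical helper, not the library function; placing the chosen example before the rejected one and copying the shared fields (loss_multiplier, task_name, idx) onto both examples are assumptions.

```python
def interleave_preferences(data_batch: list[dict]) -> dict[str, list]:
    """Split each sample into a chosen and a rejected example, interleaved."""
    out: dict[str, list] = {
        "message_log": [],
        "length": [],
        "loss_multiplier": [],
        "task_name": [],
        "idx": [],
    }
    for datum in data_batch:
        # Assumed order: chosen first, then rejected, for every prompt.
        for side in ("chosen", "rejected"):
            out["message_log"].append(datum[f"message_log_{side}"])
            out["length"].append(datum[f"length_{side}"])
            # Fields shared by both responses are duplicated.
            out["loss_multiplier"].append(datum["loss_multiplier"])
            out["task_name"].append(datum["task_name"])
            out["idx"].append(datum["idx"])
    return out
```

With this layout, positions 2 * i and 2 * i + 1 of every output field belong to input sample i, so the output batch has 2 * len(data_batch) entries.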
- nemo_rl.data.datasets.dpo_collate_fn(
- data_batch: list[nemo_rl.data.interfaces.DPODatumSpec],
- tokenizer: nemo_rl.data.datasets.TokenizerType,
- make_sequence_length_divisible_by: int,
Collate function for DPO training.
- Parameters:
data_batch – List of data samples with message_log_chosen, message_log_rejected, length_chosen, length_rejected, loss_multiplier, idx, and task_name fields.
tokenizer – Tokenizer for text processing
make_sequence_length_divisible_by – Make the sequence length divisible by this value
- Returns:
BatchedDataDict with input_ids, input_lengths, token_mask, and sample_mask fields.
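The effect of make_sequence_length_divisible_by can be illustrated with a minimal sketch that rounds the longest sequence length up to the next multiple and pads every sequence to that length. `round_up_to_multiple`, `pad_batch`, the `pad_id` parameter, and right-side padding are assumptions made for illustration; the real function works on tensors and also builds token_mask and sample_mask.

```python
def round_up_to_multiple(length: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= `length`."""
    return ((length + multiple - 1) // multiple) * multiple


def pad_batch(
    sequences: list[list[int]],
    pad_id: int,
    multiple: int,
) -> tuple[list[list[int]], list[int]]:
    """Right-pad all sequences to one shared length divisible by `multiple`."""
    target = round_up_to_multiple(max(len(seq) for seq in sequences), multiple)
    input_ids = [seq + [pad_id] * (target - len(seq)) for seq in sequences]
    # Unpadded lengths are kept so padding can be masked out later.
    input_lengths = [len(seq) for seq in sequences]
    return input_ids, input_lengths
```

For example, a longest sequence of 5 tokens with a divisor of 4 would be padded to 8 tokens, while a longest sequence of exactly 8 tokens would need no extra padding.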
- nemo_rl.data.datasets.assert_no_double_bos(
- token_ids: torch.Tensor,
- tokenizer: nemo_rl.data.datasets.TokenizerType,
Assert that there are no double starting BOS tokens in the message.
- Parameters:
token_ids – Tensor of token IDs to check
tokenizer – Tokenizer used to identify the BOS token
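A check of this kind can be sketched without torch as follows. `check_no_double_bos` is a hypothetical helper that takes a plain sequence of ints and a BOS id directly, whereas the documented function takes a torch.Tensor and a tokenizer; skipping the check when no BOS id is defined is an assumption.

```python
from typing import Optional, Sequence


def check_no_double_bos(token_ids: Sequence[int], bos_token_id: Optional[int]) -> None:
    """Raise AssertionError if the sequence starts with two BOS tokens."""
    if bos_token_id is None or len(token_ids) < 2:
        # Nothing to check: no BOS token defined, or sequence too short.
        return
    starts_with_double_bos = (
        token_ids[0] == bos_token_id and token_ids[1] == bos_token_id
    )
    assert not starts_with_double_bos, (
        f"Found two leading BOS tokens (id={bos_token_id}) in the message"
    )
```

A doubled BOS token typically appears when a chat template already inserts BOS and the tokenizer is then asked to add special tokens again, so a check like this is most useful right after tokenization.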