nemo_rl.data.datasets#
Module Contents#
Classes#
AllTaskProcessedDataset – Dataset for processing single or multi-task data with task-specific tokenization and processing.
Functions#
rl_collate_fn – Collate function for RL training.
eval_collate_fn – Collate function for evaluation.
dpo_collate_fn – Collate function for DPO training.
Data#
API#
- nemo_rl.data.datasets.TokenizerType#
None
- class nemo_rl.data.datasets.AllTaskProcessedDataset(
- dataset: datasets.Dataset | Any,
- tokenizer: nemo_rl.data.datasets.TokenizerType,
- default_task_data_spec: nemo_rl.data.interfaces.TaskDataSpec,
- task_data_processors: dict[str, tuple[nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.data.interfaces.TaskDataProcessFnCallable]] | nemo_rl.data.interfaces.TaskDataProcessFnCallable,
- max_seq_length: Optional[int] = None,
- )#
Dataset for processing single or multi-task data with task-specific tokenization and processing.
- Parameters:
dataset – Input dataset containing raw data
tokenizer – Tokenizer for text processing
default_task_data_spec – Default task processing specifications. In the case of single-task, this is the spec used for processing all entries. In the case of multi-task, any values not specified in the task-specific specs will be taken from the default spec.
task_data_processors – Either a single TaskDataProcessFnCallable for single-task processing, or a dict mapping task names to (TaskDataSpec, TaskDataProcessFnCallable) tuples for multi-task processing (see the sketch below)
max_seq_length – Maximum sequence length for tokenized outputs
Initialization
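Below is a minimal construction sketch for both modes. The raw dataset contents, the task name "math", and the processor callable are hypothetical placeholders, not part of the documented API:

```python
# Minimal sketch; placeholder names below (raw_dataset contents, "math",
# my_math_processor) are illustrative assumptions, not documented API.
from datasets import Dataset
from transformers import AutoTokenizer

from nemo_rl.data.datasets import AllTaskProcessedDataset
from nemo_rl.data.interfaces import TaskDataSpec

tokenizer = AutoTokenizer.from_pretrained("gpt2")
raw_dataset = Dataset.from_list([{"prompt": "What is 2 + 2?", "task_name": "math"}])
default_spec = TaskDataSpec(task_name="default")

# my_math_processor: a TaskDataProcessFnCallable defined elsewhere that turns
# one raw entry into a DatumSpec (hypothetical here).

# Single-task: one processor handles every entry.
single_task_ds = AllTaskProcessedDataset(
    dataset=raw_dataset,
    tokenizer=tokenizer,
    default_task_data_spec=default_spec,
    task_data_processors=my_math_processor,
    max_seq_length=1024,
)

# Multi-task: route each task name to its own (TaskDataSpec, processor) pair;
# any field left unset on a task spec falls back to default_spec.
multi_task_ds = AllTaskProcessedDataset(
    dataset=raw_dataset,
    tokenizer=tokenizer,
    default_task_data_spec=default_spec,
    task_data_processors={
        "math": (TaskDataSpec(task_name="math"), my_math_processor),
    },
    max_seq_length=1024,
)
```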
- encode_single(
- text: Union[str, list[str]],
- )#
Takes either a single string or a list of strings representing multiple turns of the same conversation.
Returns a single (concatenated) list of token ids together with its length.
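For illustration, a hedged sketch of the two input forms, assuming `ds` is an AllTaskProcessedDataset instance and the (ids, length) tuple return described above:

```python
# Hypothetical usage; `ds` is assumed to be an AllTaskProcessedDataset.
ids, length = ds.encode_single("Hello world")          # single string
ids, length = ds.encode_single(["Hello", "Hi there"])  # turns, concatenated in order
assert length == len(ids)  # the returned length matches the concatenated ids
```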
- __getitem__(idx: int) -> nemo_rl.data.interfaces.DatumSpec#
Return a single prompt.
- nemo_rl.data.datasets.rl_collate_fn(
- data_batch: list[nemo_rl.data.interfaces.DatumSpec],
- )#
Collate function for RL training.
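A typical use is as the collate_fn of a PyTorch DataLoader; a minimal sketch, assuming `ds` is the AllTaskProcessedDataset built above:

```python
# Sketch: `ds` is assumed to be an AllTaskProcessedDataset yielding DatumSpec entries.
from torch.utils.data import DataLoader

from nemo_rl.data.datasets import rl_collate_fn

loader = DataLoader(ds, batch_size=8, shuffle=True, collate_fn=rl_collate_fn)
for batch in loader:
    # `batch` is the collated batch dict, with the per-sample DatumSpec
    # fields gathered along the batch dimension
    ...
```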
- nemo_rl.data.datasets.eval_collate_fn(
- data_batch: list[nemo_rl.data.interfaces.DatumSpec],
- )#
Collate function for evaluation.
Takes a list of data samples and combines them into a single batched dictionary for model evaluation.
- Parameters:
data_batch – List of data samples with message_log, extra_env_info, and idx fields.
- Returns:
BatchedDataDict with message_log, extra_env_info, and idx fields.
Examples:
```python
>>> import torch
>>> from nemo_rl.data.datasets import eval_collate_fn
>>> from nemo_rl.data.interfaces import DatumSpec
>>> data_batch = [
...     DatumSpec(
...         message_log=[{"role": "user", "content": "Hello", "token_ids": torch.tensor([1, 2, 3])}],
...         extra_env_info={'ground_truth': '1'},
...         idx=0,
...     ),
...     DatumSpec(
...         message_log=[{"role": "assistant", "content": "Hi there", "token_ids": torch.tensor([4, 5, 6, 7])}],
...         extra_env_info={'ground_truth': '2'},
...         idx=1,
...     ),
... ]
>>> output = eval_collate_fn(data_batch)
>>> output['message_log'][0]
[{'role': 'user', 'content': 'Hello', 'token_ids': tensor([1, 2, 3])}]
>>> output['message_log'][1]
[{'role': 'assistant', 'content': 'Hi there', 'token_ids': tensor([4, 5, 6, 7])}]
>>> output['extra_env_info']
[{'ground_truth': '1'}, {'ground_truth': '2'}]
>>> output['idx']
[0, 1]
```
- nemo_rl.data.datasets.dpo_collate_fn(
- data_batch: list[nemo_rl.data.interfaces.DPODatumSpec],
- tokenizer: nemo_rl.data.datasets.TokenizerType,
- make_sequence_length_divisible_by: int,
- )#
Collate function for DPO training.
This function separates the chosen and rejected responses to create two examples per prompt. The chosen and rejected examples are interleaved along the batch dimension, resulting in a batch size of 2 * len(data_batch).
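A hedged sketch of that interleaving, assuming `data_batch` is a list of DPODatumSpec entries and `tokenizer` is the same tokenizer used for processing (both defined elsewhere):

```python
# Sketch: `data_batch` and `tokenizer` are assumed to exist as described above.
from nemo_rl.data.datasets import dpo_collate_fn

batch = dpo_collate_fn(
    data_batch,                           # N DPODatumSpec entries
    tokenizer=tokenizer,
    make_sequence_length_divisible_by=8,  # pad so sequence length % 8 == 0
)
# Rows are interleaved chosen/rejected pairs:
#   [chosen_0, rejected_0, chosen_1, rejected_1, ...]
# so the effective batch size is 2 * len(data_batch).
```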