nemo_rl.data.collate_fn
Module Contents
Functions
rl_collate_fn | Collate function for RL training.
eval_collate_fn | Collate function for evaluation.
preference_collate_fn | Collate function for preference data training.
Data
API
- nemo_rl.data.collate_fn.TokenizerType
None
- nemo_rl.data.collate_fn.rl_collate_fn(data_batch: list[nemo_rl.data.interfaces.DatumSpec])
Collate function for RL training.
- nemo_rl.data.collate_fn.eval_collate_fn(data_batch: list[nemo_rl.data.interfaces.DatumSpec])
Collate function for evaluation.
Takes a list of data samples and combines them into a single batched dictionary for model evaluation.
- Parameters:
data_batch – List of data samples with message_log, extra_env_info, and idx fields.
- Returns:
BatchedDataDict with message_log, extra_env_info, and idx fields.
Examples:
>>> import torch
>>> from nemo_rl.data.collate_fn import eval_collate_fn
>>> from nemo_rl.data.interfaces import DatumSpec
>>> data_batch = [
...     DatumSpec(
...         message_log=[{"role": "user", "content": "Hello", "token_ids": torch.tensor([1, 2, 3])}],
...         extra_env_info={'ground_truth': '1'},
...         idx=0,
...     ),
...     DatumSpec(
...         message_log=[{"role": "assistant", "content": "Hi there", "token_ids": torch.tensor([4, 5, 6, 7])}],
...         extra_env_info={'ground_truth': '2'},
...         idx=1,
...     ),
... ]
>>> output = eval_collate_fn(data_batch)
>>> output['message_log'][0]
[{'role': 'user', 'content': 'Hello', 'token_ids': tensor([1, 2, 3])}]
>>> output['message_log'][1]
[{'role': 'assistant', 'content': 'Hi there', 'token_ids': tensor([4, 5, 6, 7])}]
>>> output['extra_env_info']
[{'ground_truth': '1'}, {'ground_truth': '2'}]
>>> output['idx']
[0, 1]
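The batching in the doctest amounts to transposing a list of per-sample dicts into a dict of per-field lists. The sketch below illustrates only that idea with plain Python dicts. `collate_list_of_dicts` is a hypothetical helper, not part of nemo_rl; it returns an ordinary dict rather than the BatchedDataDict that eval_collate_fn produces, and it assumes every sample carries the same keys.

```python
from typing import Any


def collate_list_of_dicts(data_batch: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Hypothetical helper (not part of nemo_rl): per-sample dicts -> per-field lists.

    Assumes every sample has the same keys as the first one.
    """
    if not data_batch:
        return {}
    keys = data_batch[0].keys()
    # For each field, gather that field from every sample, preserving batch order.
    return {k: [sample[k] for sample in data_batch] for k in keys}


batch = [
    {"message_log": [{"role": "user", "content": "Hello"}], "extra_env_info": {"ground_truth": "1"}, "idx": 0},
    {"message_log": [{"role": "assistant", "content": "Hi there"}], "extra_env_info": {"ground_truth": "2"}, "idx": 1},
]
out = collate_list_of_dicts(batch)
print(out["idx"])             # [0, 1]
print(out["extra_env_info"])  # [{'ground_truth': '1'}, {'ground_truth': '2'}]
```

Keeping each field as a Python list (rather than stacking into a tensor) suits fields like message_log and extra_env_info, whose entries vary in shape and type from sample to sample.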
- nemo_rl.data.collate_fn.preference_collate_fn(
    data_batch: list[nemo_rl.data.interfaces.DPODatumSpec],
    tokenizer: nemo_rl.data.collate_fn.TokenizerType,
    make_sequence_length_divisible_by: int,
    add_loss_mask: bool,
)
Collate function for preference data training.
This function separates the chosen and rejected responses to create two examples per prompt. The chosen and rejected examples are interleaved along the batch dimension, resulting in a batch size of 2 * len(data_batch).
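The interleaving described above can be sketched in plain Python. This is a hypothetical illustration, not the library's code: `interleave_chosen_rejected` is an invented helper, strings stand in for the real message logs, and placing each chosen response at the even index, directly before its rejected partner, is an assumption about the ordering.

```python
from typing import Any


def interleave_chosen_rejected(data_batch: list[dict[str, Any]]) -> list[Any]:
    """Hypothetical helper (not part of nemo_rl).

    Produces [chosen_0, rejected_0, chosen_1, rejected_1, ...], i.e. 2 * len(data_batch) items.
    Assumed ordering: chosen at index 2*i, rejected at index 2*i + 1.
    """
    examples: list[Any] = []
    for datum in data_batch:
        examples.append(datum["message_log_chosen"])
        examples.append(datum["message_log_rejected"])
    return examples


batch = [
    {"message_log_chosen": "good answer A", "message_log_rejected": "bad answer A"},
    {"message_log_chosen": "good answer B", "message_log_rejected": "bad answer B"},
]
flat = interleave_chosen_rejected(batch)
print(flat)  # ['good answer A', 'bad answer A', 'good answer B', 'bad answer B']
```

Under this assumed layout, the chosen examples are recovered with `flat[0::2]` and the rejected ones with `flat[1::2]`, so each preference pair stays adjacent in the batch.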
- Parameters:
data_batch – List of data samples with message_log_chosen, message_log_rejected, length_chosen, length_rejected, loss_multiplier, idx, and task_name fields.
tokenizer – Tokenizer used for text processing.
make_sequence_length_divisible_by – Pad the batch so that its sequence length is divisible by this value.
add_loss_mask – Whether to add a token_mask to the returned data.
- Returns:
BatchedDataDict with input_ids, input_lengths, token_mask (optional), and sample_mask fields.
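One plausible reading of make_sequence_length_divisible_by is to right-pad every sequence to the longest length in the batch, rounded up to the next multiple of the given value, while recording the unpadded lengths separately (as input_lengths). The sketch below assumes right-padding with a caller-supplied pad_id and uses plain lists instead of tensors. `round_up_to_multiple` and `pad_batch` are hypothetical helpers, not nemo_rl functions.

```python
def round_up_to_multiple(length: int, multiple: int) -> int:
    """Hypothetical helper: smallest value >= length that is divisible by multiple."""
    return ((length + multiple - 1) // multiple) * multiple


def pad_batch(
    token_ids: list[list[int]], pad_id: int, divisible_by: int
) -> tuple[list[list[int]], list[int]]:
    """Hypothetical helper (not part of nemo_rl): right-pad to a shared, rounded-up length.

    Assumes a non-empty batch and right-padding with pad_id.
    Returns (input_ids, input_lengths); input_lengths holds each unpadded length.
    """
    input_lengths = [len(seq) for seq in token_ids]
    target = round_up_to_multiple(max(input_lengths), divisible_by)
    input_ids = [seq + [pad_id] * (target - len(seq)) for seq in token_ids]
    return input_ids, input_lengths


ids, lengths = pad_batch([[1, 2, 3], [4, 5, 6, 7, 8]], pad_id=0, divisible_by=4)
print(lengths)  # [3, 5]
print(ids)      # [[1, 2, 3, 0, 0, 0, 0, 0], [4, 5, 6, 7, 8, 0, 0, 0]]
```

Here the longest sequence has 5 tokens, so the shared length is rounded up to 8, the next multiple of 4. Rounding to a fixed multiple is a common way to keep tensor shapes friendly to hardware or parallelism constraints.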