nemo_rl.data.collate_fn#

Module Contents#

Functions#

rl_collate_fn

Collate function for RL training.

eval_collate_fn

Collate function for evaluation.

preference_collate_fn

Collate function for preference data training.

Data#

API#

nemo_rl.data.collate_fn.TokenizerType#

None

nemo_rl.data.collate_fn.rl_collate_fn(
data_batch: list[nemo_rl.data.interfaces.DatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any][source]#

Collate function for RL training.
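The core idea of any collate function is turning a list of per-sample dicts into a single dict of per-field lists. The sketch below illustrates that idea in plain Python; it is not the actual `rl_collate_fn` implementation (which returns a `BatchedDataDict` and handles framework-specific fields), and the field names in the sample batch are illustrative.

```python
from typing import Any


def collate_sketch(data_batch: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Stack per-sample fields into per-field lists (list-of-dicts -> dict-of-lists)."""
    keys = data_batch[0].keys()
    return {key: [sample[key] for sample in data_batch] for key in keys}


batch = [
    {"message_log": [{"role": "user", "content": "Hello"}], "idx": 0},
    {"message_log": [{"role": "user", "content": "Hi"}], "idx": 1},
]
out = collate_sketch(batch)
print(out["idx"])  # [0, 1]
```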

nemo_rl.data.collate_fn.eval_collate_fn(
data_batch: list[nemo_rl.data.interfaces.DatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any][source]#

Collate function for evaluation.

Takes a list of data samples and combines them into a single batched dictionary for model evaluation.

Parameters:

data_batch – List of data samples with message_log, extra_env_info, and idx fields.

Returns:

BatchedDataDict with message_log, extra_env_info, and idx fields.

Examples:

>>> import torch
>>> from nemo_rl.data.collate_fn import eval_collate_fn
>>> from nemo_rl.data.interfaces import DatumSpec
>>> data_batch = [
...     DatumSpec(
...         message_log=[{"role": "user", "content": "Hello", "token_ids": torch.tensor([1, 2, 3])}],
...         extra_env_info={'ground_truth': '1'},
...         idx=0,
...     ),
...     DatumSpec(
...         message_log=[{"role": "assistant", "content": "Hi there", "token_ids": torch.tensor([4, 5, 6, 7])}],
...         extra_env_info={'ground_truth': '2'},
...         idx=1,
...     ),
... ]
>>> output = eval_collate_fn(data_batch)
>>> output['message_log'][0]
[{'role': 'user', 'content': 'Hello', 'token_ids': tensor([1, 2, 3])}]
>>> output['message_log'][1]
[{'role': 'assistant', 'content': 'Hi there', 'token_ids': tensor([4, 5, 6, 7])}]
>>> output['extra_env_info']
[{'ground_truth': '1'}, {'ground_truth': '2'}]
>>> output['idx']
[0, 1]
nemo_rl.data.collate_fn.preference_collate_fn(
data_batch: list[nemo_rl.data.interfaces.DPODatumSpec],
tokenizer: nemo_rl.data.collate_fn.TokenizerType,
make_sequence_length_divisible_by: int,
add_loss_mask: bool,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any][source]#

Collate function for preference data training.

This function separates the chosen and rejected responses to create two examples per prompt. The chosen and rejected examples are interleaved along the batch dimension, resulting in a batch size of 2 * len(data_batch).

Parameters:
  • data_batch – List of data samples with message_log_chosen, message_log_rejected, length_chosen, length_rejected, loss_multiplier, idx, and task_name fields.

  • tokenizer – Tokenizer used for text processing.

  • make_sequence_length_divisible_by – Pad sequences so the batch sequence length is a multiple of this value.

  • add_loss_mask – Whether to add a token_mask to the returned data.

Returns:

BatchedDataDict with input_ids, input_lengths, token_mask (optional), and sample_mask fields.
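The interleaving and divisible-length padding described above can be sketched as follows. This is an illustrative stand-in, not the actual implementation: the field names `chosen_ids`/`rejected_ids`, the chosen-at-even/rejected-at-odd ordering, and the zero pad token are assumptions for the example.

```python
from typing import Any


def interleave_preference_batch(
    data_batch: list[dict[str, Any]],
    make_sequence_length_divisible_by: int = 8,
) -> dict[str, Any]:
    """Interleave chosen/rejected token ids and pad to a divisible length."""
    input_ids: list[list[int]] = []
    for sample in data_batch:
        input_ids.append(sample["chosen_ids"])    # even index: chosen (assumed order)
        input_ids.append(sample["rejected_ids"])  # odd index: rejected
    # Pad every sequence to a common length that is a multiple of the divisor.
    max_len = max(len(ids) for ids in input_ids)
    divisor = make_sequence_length_divisible_by
    padded_len = ((max_len + divisor - 1) // divisor) * divisor
    input_lengths = [len(ids) for ids in input_ids]
    input_ids = [ids + [0] * (padded_len - len(ids)) for ids in input_ids]
    return {"input_ids": input_ids, "input_lengths": input_lengths}


batch = [{"chosen_ids": [1, 2, 3], "rejected_ids": [4, 5]}]
out = interleave_preference_batch(batch, make_sequence_length_divisible_by=4)
print(out["input_ids"])      # [[1, 2, 3, 0], [4, 5, 0, 0]]
print(out["input_lengths"])  # [3, 2]
```

Note how a single prompt yields two examples, so the returned batch size is 2 * len(data_batch), matching the behavior documented above.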