nemo_rl.data.llm_message_utils

Module Contents

Functions

message_log_to_flat_messages

Converts a message log (sequence of message turns) into a flattened representation.

get_keys_from_message_log

Return a new LLMMessageLogType containing only the specified keys from each message.

add_loss_mask_to_message_log

Add token-level loss masks to each message in a message log.

_pad_tensor

Pad a tensor to the specified length.

_validate_tensor_consistency

Validate that all tensors have consistent dtypes and devices.

batched_message_log_to_flat_message

Process and pad a batch of message logs for model input.

message_log_shape

Get the shape of the tensors in the message log.

get_first_index_that_differs

Get the first index that differs between two strings.

get_formatted_message_log

Format and tokenize chat messages using the specified template.

remap_dataset_keys

Remap dataset keys as per mapping.

API

nemo_rl.data.llm_message_utils.message_log_to_flat_messages(
message_log: nemo_rl.data.interfaces.LLMMessageLogType,
) → nemo_rl.data.interfaces.FlatMessagesType

Converts a message log (sequence of message turns) into a flattened representation.

This function takes a message log (a list of message dicts with ‘role’, ‘content’, ‘token_ids’, etc.) and converts it to a single flat dictionary: for each key, tensor values are concatenated in message order and string values are collected into a list.

Parameters:

message_log – List of message dictionaries with ‘role’, ‘content’, and potentially ‘token_ids’

Returns:

Dictionary mapping keys to concatenated tensors and string lists

Return type:

FlatMessagesType

Examples:

>>> import torch
>>> from nemo_rl.data.llm_message_utils import message_log_to_flat_messages
>>> # Create a simple message log with two messages
>>> message_log = [
...     {'role': 'user', 'content': 'Hello', 'token_ids': torch.tensor([1, 2, 3])},
...     {'role': 'assistant', 'content': 'Hi there', 'token_ids': torch.tensor([4, 5, 6, 7])}
... ]
>>> flat_msgs = message_log_to_flat_messages(message_log)
>>> flat_msgs['role']
['user', 'assistant']
>>> flat_msgs['content']
['Hello', 'Hi there']
>>> flat_msgs['token_ids']
tensor([1, 2, 3, 4, 5, 6, 7])
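The flattening rule can be sketched in a few lines. flatten_message_log below is a hypothetical stand-in, not the library's code; it assumes every value is either a 1-D tensor or a string, and the real function may handle other value types and edge cases differently.

```python
import torch
from typing import Any, Dict, List


def flatten_message_log(message_log: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical sketch of message-log flattening."""
    collected: Dict[str, List[Any]] = {}
    # Gather every value for each key, preserving message order.
    for message in message_log:
        for key, value in message.items():
            collected.setdefault(key, []).append(value)
    # Concatenate tensor-valued keys; leave other keys (e.g. strings) as lists.
    return {
        key: torch.cat(values) if all(isinstance(v, torch.Tensor) for v in values) else values
        for key, values in collected.items()
    }


message_log = [
    {"role": "user", "content": "Hello", "token_ids": torch.tensor([1, 2, 3])},
    {"role": "assistant", "content": "Hi there", "token_ids": torch.tensor([4, 5, 6, 7])},
]
flat = flatten_message_log(message_log)
print(flat["role"])       # ['user', 'assistant']
print(flat["token_ids"])  # tensor([1, 2, 3, 4, 5, 6, 7])
```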

nemo_rl.data.llm_message_utils.get_keys_from_message_log(
message_log: nemo_rl.data.interfaces.LLMMessageLogType,
keys: List[str],
) → nemo_rl.data.interfaces.LLMMessageLogType

Return a new LLMMessageLogType containing only the specified keys from each message.

Parameters:
  • message_log – Original message log to extract keys from

  • keys – List of keys to keep in each message

Returns:

A new list of message dicts, each containing only the specified keys

Return type:

LLMMessageLogType
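A minimal sketch of the key filtering, assuming a key that is absent from a message is skipped rather than raising an error (the documented function may behave differently); keep_keys is a hypothetical name.

```python
from typing import Any, Dict, List


def keep_keys(message_log: List[Dict[str, Any]], keys: List[str]) -> List[Dict[str, Any]]:
    # Build new dicts so the original message log is left untouched.
    # Assumption: keys missing from a message are silently skipped.
    return [{k: message[k] for k in keys if k in message} for message in message_log]


log = [
    {"role": "user", "content": "Hi", "token_ids": [1, 2]},
    {"role": "assistant", "content": "Hello"},
]
print(keep_keys(log, ["role", "token_ids"]))
```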

nemo_rl.data.llm_message_utils.add_loss_mask_to_message_log(
message_log: nemo_rl.data.interfaces.LLMMessageLogType,
roles_to_train_on: List[str] = ['assistant'],
only_unmask_final: bool = False,
) → None

Add token-level loss masks to each message in a message log. The masks are written into the message dicts in place; nothing is returned.

Parameters:
  • message_log (LLMMessageLogType) – List of message dictionaries containing token IDs and metadata

  • roles_to_train_on (List[str]) – List of strings indicating which speakers to unmask. Default: [“assistant”]

  • only_unmask_final (bool) – If True, only unmask the final message in the log. Default: False
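The masking rule can be sketched as follows. This assumes the mask is stored under a token_loss_mask key as a tensor of 0s and 1s shaped like token_ids, and reads only_unmask_final as restricting unmasking to the final message (still subject to roles_to_train_on). add_loss_masks is a hypothetical name, and the library's exact behavior may differ.

```python
import torch
from typing import Any, Dict, List, Sequence


def add_loss_masks(
    message_log: List[Dict[str, Any]],
    roles_to_train_on: Sequence[str] = ("assistant",),  # tuple avoids a mutable default
    only_unmask_final: bool = False,
) -> None:
    roles = {r.lower() for r in roles_to_train_on}
    last = len(message_log) - 1
    for i, message in enumerate(message_log):
        trainable = message["role"].lower() in roles
        if only_unmask_final:
            # Assumed reading: only the final message may be unmasked.
            trainable = trainable and i == last
        ids = message["token_ids"]
        # Assumed key name: 1 = token counts toward the loss, 0 = masked out.
        message["token_loss_mask"] = torch.ones_like(ids) if trainable else torch.zeros_like(ids)


log = [
    {"role": "user", "token_ids": torch.tensor([1, 2])},
    {"role": "assistant", "token_ids": torch.tensor([3, 4, 5])},
    {"role": "user", "token_ids": torch.tensor([6])},
    {"role": "assistant", "token_ids": torch.tensor([7, 8])},
]
add_loss_masks(log)
print([m["token_loss_mask"].tolist() for m in log])
```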

nemo_rl.data.llm_message_utils._pad_tensor(
tensor: torch.Tensor,
max_len: int,
pad_side: str,
pad_value: int = 0,
) → torch.Tensor

Pad a tensor to the specified length.

Parameters:
  • tensor – Tensor to pad

  • max_len – Length to pad to

  • pad_side – Side to pad on: ‘left’ or ‘right’

  • pad_value – Value to use for padding

Returns:

Padded tensor

Return type:

torch.Tensor
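A sketch of padding with torch.cat. pad_to_length is a hypothetical name, and treating dim 0 as the sequence dimension is an assumption.

```python
import torch


def pad_to_length(tensor: torch.Tensor, max_len: int, pad_side: str, pad_value: int = 0) -> torch.Tensor:
    pad_len = max_len - tensor.shape[0]
    if pad_len < 0:
        raise ValueError(f"tensor length {tensor.shape[0]} exceeds max_len {max_len}")
    # Padding block matches the tensor's trailing dims, dtype, and device.
    pad = torch.full(
        (pad_len, *tensor.shape[1:]), pad_value, dtype=tensor.dtype, device=tensor.device
    )
    if pad_side == "left":
        return torch.cat([pad, tensor])
    if pad_side == "right":
        return torch.cat([tensor, pad])
    raise ValueError(f"pad_side must be 'left' or 'right', got {pad_side!r}")


print(pad_to_length(torch.tensor([1, 2, 3]), 5, "right"))  # tensor([1, 2, 3, 0, 0])
```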

nemo_rl.data.llm_message_utils._validate_tensor_consistency(
tensors: List[torch.Tensor],
) → None

Validate that all tensors have consistent dtypes and devices.

Parameters:

tensors – List of tensors to validate

Raises:

RuntimeError – If tensors have different dtypes or devices
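A sketch of the consistency check, comparing every tensor against the first one; check_consistency is a hypothetical name and the exact error messages are assumptions.

```python
import torch
from typing import List


def check_consistency(tensors: List[torch.Tensor]) -> None:
    if not tensors:
        return  # nothing to compare
    ref = tensors[0]
    for t in tensors[1:]:
        if t.dtype != ref.dtype:
            raise RuntimeError(f"Inconsistent dtypes: {ref.dtype} vs {t.dtype}")
        if t.device != ref.device:
            raise RuntimeError(f"Inconsistent devices: {ref.device} vs {t.device}")


check_consistency([torch.tensor([1]), torch.tensor([2, 3])])  # passes: both int64 on CPU
```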

nemo_rl.data.llm_message_utils.batched_message_log_to_flat_message(
message_log_batch: List[nemo_rl.data.interfaces.LLMMessageLogType],
pad_value_dict: Dict[str, int] = None,
make_sequence_length_divisible_by: int = 1,
) → tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.FlatMessagesType], torch.Tensor]

Process and pad a batch of message logs for model input.

The function:

  1. Converts each message log in the batch to a flat representation using message_log_to_flat_messages

  2. Pads all resulting tensors to the same length so they can be stacked into a batch

  3. Returns a BatchedDataDict together with a tensor of unpadded sequence lengths

Padding is always applied to the right side of sequences.

Parameters:
  • message_log_batch – List of LLMMessageLogType (each a conversation with multiple turns)

  • pad_value_dict – Dictionary mapping keys to padding values; keys not listed are padded with 0

  • make_sequence_length_divisible_by – Pad the sequence dimension up to the next multiple of this value. Default: 1

Returns:

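A tuple (batched, input_lengths): batched is a BatchedDataDict containing the padded, stacked tensors (string keys become per-conversation lists); input_lengths is a torch.Tensor of shape [batch_size] holding the pre-padding lengths.

Return type:

tuple[BatchedDataDict[FlatMessagesType], torch.Tensor]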

Raises:

RuntimeError – If tensors have different dtypes or devices
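The padding arithmetic described above can be sketched for a single key of 1-D token tensors. right_pad_batch and padded_length are hypothetical helpers, and rounding the longest length up to the next multiple of make_sequence_length_divisible_by is an assumed reading of that parameter.

```python
import torch
from typing import List, Tuple


def padded_length(lengths: List[int], divisible_by: int = 1) -> int:
    longest = max(lengths)
    # Ceiling division, then scale back up to a multiple of divisible_by.
    return -(-longest // divisible_by) * divisible_by


def right_pad_batch(
    seqs: List[torch.Tensor], pad_value: int = 0, divisible_by: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Pre-padding lengths, as in the documented int32 input_lengths tensor.
    lengths = torch.tensor([len(s) for s in seqs], dtype=torch.int32)
    target = padded_length(lengths.tolist(), divisible_by)
    batch = torch.full((len(seqs), target), pad_value, dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        batch[i, : len(s)] = s  # right padding: real tokens first
    return batch, lengths


seqs = [torch.tensor([1, 2, 3, 4, 5, 6, 7]), torch.tensor([1, 8, 9, 10, 11, 12, 13, 14, 15])]
batch, lengths = right_pad_batch(seqs)
print(batch[0].tolist())  # [1, 2, 3, 4, 5, 6, 7, 0, 0]
print(lengths)            # tensor([7, 9], dtype=torch.int32)
```

With divisible_by=8 the same batch would be padded to length 16 instead of 9.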

Examples:

>>> import torch
>>> from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message
>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> # Create a batch of two message logs with different lengths
>>> message_log_batch = [
...     # First conversation
...     [
...         {'role': 'user', 'content': 'What is 2+2?', 'token_ids': torch.tensor([1, 2, 3, 4, 5])},
...         {'role': 'assistant', 'content': '4', 'token_ids': torch.tensor([6, 7])}
...     ],
...     # Second conversation
...     [
...         {'role': 'user', 'content': 'Solve x+10=15', 'token_ids': torch.tensor([1, 8, 9, 10, 11, 12])},
...         {'role': 'assistant', 'content': 'x=5', 'token_ids': torch.tensor([13, 14, 15])}
...     ]
... ]
>>> pad_value_dict = {'token_ids': 0}
>>> batched_flat, input_lengths = batched_message_log_to_flat_message(message_log_batch, pad_value_dict)
>>> batched_flat['token_ids'][0].tolist()
[1, 2, 3, 4, 5, 6, 7, 0, 0]
>>> batched_flat['token_ids'][1].tolist()
[1, 8, 9, 10, 11, 12, 13, 14, 15]
>>> batched_flat['content'][0]
['What is 2+2?', '4']
>>> batched_flat['content'][1]
['Solve x+10=15', 'x=5']
>>> batched_flat['role']
[['user', 'assistant'], ['user', 'assistant']]
>>> input_lengths
tensor([7, 9], dtype=torch.int32)

nemo_rl.data.llm_message_utils.message_log_shape(
message_log: nemo_rl.data.interfaces.LLMMessageLogType,
) → List[Dict[str, List[int]]]

Get the shape of the tensors in the message log.

This utility function examines each message in the message log and reports the shape of each tensor value, recursing into list values.

Parameters:

message_log – The message log to analyze

Returns:

List of dictionaries containing tensor shapes for each key in messages
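A sketch of the shape report, assuming tensor values map to their shape as a list of ints, list values are processed element by element, and other values such as strings are skipped; log_shape is a hypothetical name.

```python
import torch
from typing import Any, Dict, List


def _shape_of(value: Any) -> Any:
    if isinstance(value, torch.Tensor):
        return list(value.shape)
    if isinstance(value, list):
        return [_shape_of(v) for v in value]  # recurse into list items
    return None  # non-tensor leaf, e.g. a string inside a list


def log_shape(message_log: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    shapes = []
    for message in message_log:
        # Only tensor- and list-valued keys are reported; 'role'/'content' strings are skipped.
        shapes.append(
            {k: _shape_of(v) for k, v in message.items() if isinstance(v, (torch.Tensor, list))}
        )
    return shapes


log = [
    {"role": "user", "token_ids": torch.zeros(3, dtype=torch.long)},
    {"role": "assistant", "token_ids": torch.zeros(4, dtype=torch.long),
     "logprobs": [torch.zeros(2), torch.zeros(2, 5)]},
]
print(log_shape(log))
```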

nemo_rl.data.llm_message_utils.get_first_index_that_differs(str1, str2)

Get the first index that differs between two strings.
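A sketch of the comparison, assuming that when one string is a prefix of the other the function returns the length of the shorter string; first_diff_index is a hypothetical name.

```python
def first_diff_index(str1: str, str2: str) -> int:
    for i, (a, b) in enumerate(zip(str1, str2)):
        if a != b:
            return i
    # No mismatch in the shared prefix: the strings diverge where the shorter one ends.
    return min(len(str1), len(str2))


print(first_diff_index("hello", "help"))  # 3
```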

nemo_rl.data.llm_message_utils.get_formatted_message_log(
message_log: nemo_rl.data.interfaces.LLMMessageLogType,
tokenizer,
task_data_spec: nemo_rl.data.interfaces.TaskDataSpec,
add_bos_token: bool = True,
add_eos_token: bool = True,
add_generation_prompt: bool = False,
) → nemo_rl.data.interfaces.LLMMessageLogType

Format and tokenize chat messages using the specified template.

Parameters:
  • message_log – List of message dicts with ‘role’ and ‘content’ keys

  • tokenizer – Tokenizer for converting text to token IDs

  • task_data_spec – Task spec for this dataset.

  • add_bos_token – Whether to add bos token to first message if it is not already present. Default: True

  • add_eos_token – Whether to add eos token to last message if it is not already present. Default: True

  • add_generation_prompt – Whether to include assistant’s generation prompt in user messages. Default: False

Returns:

The message log with updated ‘token_ids’ and ‘content’ fields.
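One plausible way to assign each message its share of a chat-template rendering is to render successively longer prefixes of the conversation and keep only the newly added text each time. This is an assumption about the approach, not a description of how get_formatted_message_log works; split_rendered_chat and toy_render are hypothetical, and in practice a tokenizer's chat template would take toy_render's place and each segment would then be tokenized.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]


def toy_render(messages: List[Message]) -> str:
    # Hypothetical chat template, used only for this illustration.
    return "".join(f"<|{m['role']}|>{m['content']}<|end|>" for m in messages)


def split_rendered_chat(
    messages: List[Message], render: Callable[[List[Message]], str]
) -> List[str]:
    segments: List[str] = []
    prev = ""
    for i in range(len(messages)):
        current = render(messages[: i + 1])
        # The approach only works if adding a message appends text without rewriting earlier text.
        if not current.startswith(prev):
            raise ValueError(f"rendering message {i} changed earlier text")
        segments.append(current[len(prev):])
        prev = current
    return segments


msgs = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
print(split_rendered_chat(msgs, toy_render))
```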

nemo_rl.data.llm_message_utils.remap_dataset_keys(
dataset: datasets.Dataset,
mapping_dict: Dict[str, str],
) → datasets.Dataset

Remap dataset keys as per mapping.

Parameters:
  • dataset – The input dataset to remap keys in

  • mapping_dict – A dictionary mapping input keys to output keys

Returns:

A new dataset with remapped keys

Return type:

Dataset
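A pure-Python sketch of the remapping on a list of row dicts, assuming keys absent from mapping_dict are kept unchanged; remap_records is a hypothetical name. With Hugging Face datasets, Dataset.rename_columns performs a comparable column rename, though whether remap_dataset_keys uses it is not stated here.

```python
from typing import Any, Dict, List


def remap_records(records: List[Dict[str, Any]], mapping: Dict[str, str]) -> List[Dict[str, Any]]:
    # Build new dicts so the input rows are not modified.
    # Assumption: keys not in the mapping keep their original names.
    return [{mapping.get(key, key): value for key, value in row.items()} for row in records]


rows = [{"question": "2+2?", "answer": "4"}]
print(remap_records(rows, {"question": "problem"}))  # [{'problem': '2+2?', 'answer': '4'}]
```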