nemo_rl.algorithms.utils#

Module Contents#

Functions#

calculate_kl_penalty_joschu2020

Calculates a per-token estimate of the KL divergence between two log_probs.

calculate_baseline_and_std_per_prompt

Computes a baseline for each (prompt, response) pair in the batch.

surpress_user_warnings

masked_mean

Computes the mean of a microbatch, using a global statistic as the normalization factor.

set_seed

Sets the seed for Python, NumPy, and PyTorch.

get_tokenizer

Get the tokenizer and set pad token to eos token if it is not already set.

API#

nemo_rl.algorithms.utils.calculate_kl_penalty_joschu2020(
logprobs_policy: torch.Tensor,
logprobs_reference: torch.Tensor,
)[source]#

Calculates a per-token estimate of the KL divergence between two log_probs.

Uses the estimator from Schulman (2020), which is always positive.

logprobs_policy: torch.Tensor (b, s)
logprobs_reference: torch.Tensor (b, s)
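
Below is a minimal sketch of the k3 estimator from Schulman (2020), assuming the function estimates KL(policy ‖ reference) from tokens sampled by the policy. The helper name is illustrative, not the module's actual implementation.

```python
import torch

def kl_penalty_k3_sketch(
    logprobs_policy: torch.Tensor,     # (b, s) log-probs under the current policy
    logprobs_reference: torch.Tensor,  # (b, s) log-probs under the reference model
) -> torch.Tensor:
    # Per-token log ratio of reference to policy probabilities.
    log_ratio = logprobs_reference - logprobs_policy
    # Schulman's k3 estimator: exp(r) - r - 1 >= 0, with equality when r == 0.
    return torch.exp(log_ratio) - log_ratio - 1
```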

nemo_rl.algorithms.utils.calculate_baseline_and_std_per_prompt(
prompts: torch.Tensor,
rewards: torch.Tensor,
valid_mask: torch.Tensor,
leave_one_out_baseline: bool = True,
)[source]#

Computes a baseline for each (prompt, response) pair in the batch.

All responses to the same prompt share the same baseline. Samples whose 'valid_mask' entry is 0 are excluded from the baseline calculation.

prompts: tensor (b, s) Tensor of prompts the model used. May be on any device.
rewards: tensor (b,) Float-valued rewards. May be on any device.
valid_mask: tensor (b,) Vector of 0/1, where 0 means ignore and 1 means keep.
leave_one_out_baseline: bool Compute an unbiased baseline by leaving out the sample that the baseline is for (RLOO, https://arxiv.org/abs/2402.14740).

Returns: tensor (b,) of baselines on the same device as ‘rewards’
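
As an illustration of the leave-one-out case, the helper below is a sketch under the documented shapes, not the module's exact implementation, and it omits the std computation implied by the function name.

```python
import torch

def rloo_baseline_sketch(
    prompts: torch.Tensor,     # (b, s) token ids identifying each prompt
    rewards: torch.Tensor,     # (b,)  float-valued rewards
    valid_mask: torch.Tensor,  # (b,)  1 = keep, 0 = ignore
) -> torch.Tensor:
    baselines = torch.zeros_like(rewards)
    for i in range(rewards.shape[0]):
        # Other valid samples that share the same prompt, excluding sample i itself.
        same_prompt = (prompts == prompts[i]).all(dim=-1)
        others = same_prompt & valid_mask.bool()
        others[i] = False
        if others.any():
            # Leave-one-out mean of the remaining rewards (RLOO baseline).
            baselines[i] = rewards[others].mean()
    return baselines
```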

nemo_rl.algorithms.utils.surpress_user_warnings(f)[source]#
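
No docstring is given; judging by the name, this appears to be a decorator that silences UserWarnings while the wrapped function runs. A hedged sketch of such a decorator (not necessarily the module's implementation):

```python
import functools
import warnings

def surpress_user_warnings_sketch(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Temporarily filter out UserWarning for the duration of the call.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            return f(*args, **kwargs)
    return wrapper
```
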
nemo_rl.algorithms.utils.masked_mean(
values,
mask,
dim: Optional[int] = None,
global_normalization_factor: Optional[torch.Tensor | float] = None,
)[source]#

Computes the mean of a microbatch, using a global statistic as the normalization factor.
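
A minimal sketch of what such a masked mean could look like; passing a global normalization factor keeps per-microbatch means consistent with the full-batch statistic (illustrative only, not the module's exact implementation):

```python
from typing import Optional, Union
import torch

def masked_mean_sketch(
    values: torch.Tensor,
    mask: torch.Tensor,
    dim: Optional[int] = None,
    global_normalization_factor: Optional[Union[torch.Tensor, float]] = None,
) -> torch.Tensor:
    masked_sum = (values * mask).sum() if dim is None else (values * mask).sum(dim=dim)
    if global_normalization_factor is not None:
        # e.g. the valid-token count of the whole batch, so microbatch means combine correctly.
        denominator = global_normalization_factor
    else:
        denominator = mask.sum() if dim is None else mask.sum(dim=dim)
    return masked_sum / denominator
```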

nemo_rl.algorithms.utils.set_seed(seed: int)[source]#

Sets the seed for Python, NumPy, and PyTorch.
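
A typical implementation seeds all three RNG sources; this is a sketch, and the actual function may seed additional backends:

```python
import random
import numpy as np
import torch

def set_seed_sketch(seed: int) -> None:
    random.seed(seed)                  # Python's built-in RNG
    np.random.seed(seed)               # NumPy
    torch.manual_seed(seed)            # PyTorch CPU
    torch.cuda.manual_seed_all(seed)   # all visible GPUs (no-op without CUDA)
```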

nemo_rl.algorithms.utils.get_tokenizer(
tokenizer_config: nemo_rl.models.policy.TokenizerConfig,
) → transformers.AutoTokenizer[source]#

Get the tokenizer and set pad token to eos token if it is not already set.

This function initializes a tokenizer from the Hugging Face transformers library and configures it with appropriate chat templates and padding tokens.

Parameters:

tokenizer_config – A dictionary containing tokenizer configuration.

Required keys:
- name: The name or path of the pretrained tokenizer.

Optional keys:
- chat_template: The chat template to use. Can be:
  - None: uses a passthrough template that just returns message content.
  - "default": uses the tokenizer's default template.
  - A custom jinja2 template string.
  If not specified, the tokenizer's default template will be used.

Returns:

The configured tokenizer instance

Return type:

AutoTokenizer

Examples

>>> from transformers import AutoTokenizer
>>> from nemo_rl.algorithms.utils import get_tokenizer
>>> # not specifying a chat template uses the tokenizer's default
>>> config = {"name": "meta-llama/Llama-3.2-1B-Instruct"}
>>> tokenizer = get_tokenizer(config)
No chat template provided, using tokenizer's default
>>> messages = [
...     {"role": "system", "content": "You are a helpful AI assistant."},
...     {"role": "user", "content": "Hello!"}
... ]
>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
>>> assert formatted == AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").apply_chat_template(messages, tokenize=False)

>>> # Using a passthrough template
>>> config = {
...     "name": "meta-llama/Llama-3.2-1B-Instruct",
...     "chat_template": None
... }
>>> tokenizer = get_tokenizer(config)
Using passthrough chat template
>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
>>> assert formatted == "".join(msg["content"] for msg in messages)

>>> # Using a custom template
>>> config = {
...     "name": "meta-llama/Llama-3.2-1B-Instruct",
...     "chat_template": "{% for message in messages %}{{ ' START: ' + message['content'] + ' END.' }}{% endfor %}"
... }
>>> tokenizer = get_tokenizer(config)
Using custom chat template
>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
>>> assert formatted == " START: You are a helpful AI assistant. END. START: Hello! END."