nemo_rl.algorithms.utils

Module Contents

Functions

get_gdpo_reward_component_keys

Return batch keys that are reward components (reward1, reward2, …) in sorted order.

calculate_kl

Calculates a per-token estimate of the KL divergence between two sets of log probabilities.

calculate_baseline_and_std_per_prompt

Computes a baseline and standard deviation for each (prompt, response) pair in the batch.

surpress_user_warnings

masked_mean

Computes the masked mean of a microbatch, optionally using a global statistic as the normalization factor.

mask_out_neg_inf_logprobs

Mask out negative infinity log probabilities.

set_seed

Sets the seed for Python, NumPy, and PyTorch.

get_tokenizer

Gets the tokenizer and sets the pad token to the EOS token if it is not already set.

maybe_pad_last_batch

Pads the given batch so that its size is divisible by (mbs * dp_size).

print_performance_metrics

Print performance metrics for GRPO.

log_generation_metrics_to_wandb

Log generation metrics to wandb.

API

nemo_rl.algorithms.utils.get_gdpo_reward_component_keys(batch) → list

Return batch keys that are reward components (reward1, reward2, …) in sorted order.
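This page does not say how a key is recognized as a reward component or what "sorted order" means. The sketch below is a hypothetical stand-in, `reward_component_keys_sketch`, and rests on two assumptions. First, a component key is `reward` followed by an integer. Second, ordering is numeric, so `reward10` follows `reward2`. Plain lexicographic sorting would instead give `reward1, reward10, reward2`.

```python
import re

# Assumption: a reward-component key is "reward" followed by an integer suffix.
_COMPONENT_KEY = re.compile(r"^reward(\d+)$")


def reward_component_keys_sketch(batch: dict) -> list:
    """Hypothetical stand-in for get_gdpo_reward_component_keys."""
    keys = [k for k in batch if isinstance(k, str) and _COMPONENT_KEY.match(k)]
    # Sort by the numeric suffix so "reward10" comes after "reward2".
    return sorted(keys, key=lambda k: int(_COMPONENT_KEY.match(k).group(1)))


batch = {"reward2": [0.1], "reward10": [0.3], "reward1": [0.5], "total_reward": [0.9]}
print(reward_component_keys_sketch(batch))  # ['reward1', 'reward2', 'reward10']
```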

nemo_rl.algorithms.utils.calculate_kl(
logprobs: torch.Tensor,
logprobs_reference: torch.Tensor,
kl_type: str = 'k3',
input_clamp_value: float | None = 20.0,
output_clamp_value: float | None = 10.0,
) → torch.Tensor

Calculates a per-token estimate of the KL divergence between two sets of log probabilities.

From Schulman 2020, http://joschu.net/blog/kl-approx.html.

Parameters:
  • logprobs – torch.Tensor (b, s)

  • logprobs_reference – torch.Tensor (b, s)

  • kl_type – Type of KL approximation to use. Valid values: “k1”, “k2”, “k3”.

  • input_clamp_value – Optional value at which to clamp the log ratio (logr) before the estimator is applied, to prevent numerical instability. If None, no clamping is applied.

  • output_clamp_value – Optional value at which to clamp the resulting KL values, to prevent numerical instability. If None, no clamping is applied.

Returns:

Per-token KL penalty values (b, s)

Return type:

torch.Tensor
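Schulman (2020) describes three estimators of KL[π ‖ π_ref] computed from tokens x sampled from the policy π. Each is written in terms of r = π_ref(x) / π(x):

  • k1 = -log r
  • k2 = (log r)² / 2
  • k3 = (r - 1) - log r

The hypothetical `kl_estimate_sketch` below implements those formulas and adds the two clamps described above. It assumes logr = logprobs_reference - logprobs and symmetric clamping at ±value. Neither detail is confirmed by this page.

```python
from typing import Optional

import torch


def kl_estimate_sketch(
    logprobs: torch.Tensor,
    logprobs_reference: torch.Tensor,
    kl_type: str = "k3",
    input_clamp_value: Optional[float] = 20.0,
    output_clamp_value: Optional[float] = 10.0,
) -> torch.Tensor:
    """Per-token KL estimates following Schulman (2020); a sketch, not the library code."""
    # Assumed convention: log r = log pi_ref(x) - log pi(x), with x sampled from pi.
    logr = logprobs_reference - logprobs
    if input_clamp_value is not None:
        logr = logr.clamp(min=-input_clamp_value, max=input_clamp_value)
    if kl_type == "k1":
        kl = -logr
    elif kl_type == "k2":
        kl = 0.5 * logr.square()
    elif kl_type == "k3":
        # (r - 1) - log r; expm1 is accurate when r is close to 1. Always >= 0.
        kl = torch.expm1(logr) - logr
    else:
        raise ValueError(f"Unknown kl_type: {kl_type!r}")
    if output_clamp_value is not None:
        kl = kl.clamp(min=-output_clamp_value, max=output_clamp_value)
    return kl


lp = torch.log(torch.tensor([[0.5]]))   # policy assigns 0.5 to the sampled token
ref = torch.log(torch.tensor([[0.25]]))  # reference assigns 0.25
for t in ("k1", "k2", "k3"):
    print(t, kl_estimate_sketch(lp, ref, kl_type=t).item())
```

Of the three, k3 is unbiased and never negative. That is why it is a common default for a KL penalty.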

nemo_rl.algorithms.utils.calculate_baseline_and_std_per_prompt(
prompts: torch.Tensor,
rewards: torch.Tensor,
valid_mask: torch.Tensor,
leave_one_out_baseline: bool = True,
std_rewards: torch.Tensor | None = None,
) → tuple[torch.Tensor, torch.Tensor]

Computes a baseline and standard deviation for each (prompt, response) pair in the batch.

Baselines are computed per prompt: only responses to the same prompt contribute to a sample's baseline. Samples set to 0 in ‘valid_mask’ are excluded from the baseline calculation.

Parameters:
  • prompts – tensor (b, s). Tensor of prompts the model used. May be on any device.

  • rewards – tensor (b,). Float-valued rewards. May be on any device.

  • valid_mask – tensor (b,). Vector of 0/1, where 0 means ignore and 1 means keep.

  • leave_one_out_baseline – bool. Compute an unbiased baseline by leaving out the sample that the baseline is for (from RLOO, https://arxiv.org/abs/2402.14740).

  • std_rewards – tensor (b,). Optional separate reward tensor used only for the std calculation. Defaults to rewards. Useful for DAPO, which needs the std of the raw task metric for dynamic-sampling filtering while keeping the baseline on the shaped reward.

Returns:

Two tensors of shape (b,), the baselines and the standard deviations, on the same device as ‘rewards’.
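The grouping and leave-one-out logic can be sketched as follows. This is an illustration only. `baseline_and_std_sketch` is a hypothetical name, and three details are assumptions:

  • Rows with identical prompt tokens form one group.
  • The leave-one-out baseline for a sample is the mean of the other valid rewards in its group.
  • The std is the population std of the valid rewards in the group.

The real function may differ in these choices.

```python
from typing import Optional

import torch


def baseline_and_std_sketch(
    prompts: torch.Tensor,
    rewards: torch.Tensor,
    valid_mask: torch.Tensor,
    leave_one_out_baseline: bool = True,
    std_rewards: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative per-prompt baseline and std; not the NeMo RL implementation."""
    device = rewards.device
    rewards = rewards.float()
    std_src = rewards if std_rewards is None else std_rewards.to(device).float()
    valid = valid_mask.to(device).float()
    # Rows with identical prompt tokens share a group id.
    _, group_ids = torch.unique(prompts.to(device), dim=0, return_inverse=True)

    baseline = torch.zeros_like(rewards)
    std = torch.zeros_like(rewards)
    for g in group_ids.unique():
        idx = (group_ids == g).nonzero(as_tuple=True)[0]
        r, v = rewards[idx], valid[idx]
        n_valid = v.sum()
        total = (r * v).sum()
        if leave_one_out_baseline:
            # Exclude each sample's own reward (if it was valid) from its baseline.
            baseline[idx] = (total - r * v) / (n_valid - v).clamp(min=1)
        else:
            baseline[idx] = total / n_valid.clamp(min=1)
        # Population std of the valid std-source values in the group (assumption).
        s = std_src[idx]
        mean_s = (s * v).sum() / n_valid.clamp(min=1)
        var = (((s - mean_s) ** 2) * v).sum() / n_valid.clamp(min=1)
        std[idx] = var.sqrt()
    return baseline, std


prompts = torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [3, 4]])
rewards = torch.tensor([1.0, 0.0, 1.0, 2.0, 4.0])
baseline, std = baseline_and_std_sketch(prompts, rewards, torch.ones(5))
```

With leave-one-out, consider the first group with rewards [1, 0, 1]. Sample 0 gets the baseline (0 + 1) / 2 = 0.5 and sample 1 gets (1 + 1) / 2 = 1.0. Each sample is therefore compared only against its siblings, which is what keeps the advantage estimate unbiased.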

nemo_rl.algorithms.utils.surpress_user_warnings(f)

nemo_rl.algorithms.utils.masked_mean(
values: torch.Tensor,
mask: torch.Tensor,
dim: Optional[int] = None,
global_normalization_factor: Optional[torch.Tensor | float] = None,
)

Computes the masked mean of a microbatch, optionally using a global statistic as the normalization factor.
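A global normalization factor matters when a batch is split into microbatches. Suppose each microbatch divides its masked sum by the token count of the whole batch. Then the per-microbatch results add up exactly to the full-batch mean. Averaging local means does not, because microbatches with few valid tokens would be over-weighted.

The sketch below demonstrates this. `masked_mean_sketch` and its `eps` guard against division by zero are hypothetical additions, not the library's code.

```python
from typing import Optional, Union

import torch


def masked_mean_sketch(
    values: torch.Tensor,
    mask: torch.Tensor,
    dim: Optional[int] = None,
    global_normalization_factor: Optional[Union[torch.Tensor, float]] = None,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Hypothetical masked mean; not the NeMo RL implementation."""
    masked = values * mask
    numerator = masked.sum() if dim is None else masked.sum(dim=dim)
    if global_normalization_factor is None:
        # Local normalization: number of unmasked elements in this microbatch.
        denom = mask.sum() if dim is None else mask.sum(dim=dim)
    else:
        # Global normalization: e.g. the unmasked-token count of the whole batch.
        denom = global_normalization_factor
    return numerator / (denom + eps)


# Two microbatches with different numbers of valid tokens (3 in total).
v1, m1 = torch.tensor([[1.0, 2.0, 3.0]]), torch.tensor([[1.0, 1.0, 0.0]])
v2, m2 = torch.tensor([[4.0, 5.0, 6.0]]), torch.tensor([[1.0, 0.0, 0.0]])
global_count = (m1.sum() + m2.sum()).item()

summed = masked_mean_sketch(v1, m1, global_normalization_factor=global_count) + masked_mean_sketch(
    v2, m2, global_normalization_factor=global_count
)
print(summed.item())  # full-batch mean (1 + 2 + 4) / 3 ≈ 2.3333
```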

nemo_rl.algorithms.utils.mask_out_neg_inf_logprobs(
logprobs: torch.Tensor,
mask: torch.Tensor,
logprobs_name: str,
) → torch.Tensor

Mask out negative infinity log probabilities.

This handles a sampling-mask mismatch between generation and training:

  • During generation, vLLM samples token X from the top-k/top-p filtered distribution, so generation_logprobs[X] is always finite (e.g., -5.41).

  • During training, the policy computes log probabilities with the same top-k/top-p settings, but its distribution can differ slightly, so token X may fall outside the training policy’s top-k/top-p set. In that case curr_logprobs[X] = -inf and prev_logprobs[X] = -inf.

The function detects positions with -inf in any of the log probabilities (generation_logprobs is always finite for valid tokens).

Parameters:
  • logprobs – Log probabilities.

  • mask – Mask.

  • logprobs_name – Name of the logprobs tensor. Used for printing warning messages.

Returns:

Masked log probabilities.
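The failure this guards against is arithmetic. If both curr_logprobs[X] and prev_logprobs[X] are -inf, their difference (the log importance ratio) is -inf - (-inf) = NaN, which then poisons the loss.

The sketch below shows one way to neutralize such entries. It is only a guess at the behavior: the hypothetical `mask_out_neg_inf_sketch` replaces -inf with 0.0 and treats nonzero `mask` entries as valid positions. The real function may use a different fill value or also update the mask.

```python
import torch


def mask_out_neg_inf_sketch(
    logprobs: torch.Tensor, mask: torch.Tensor, logprobs_name: str
) -> torch.Tensor:
    """Hypothetical sketch: replace -inf log probabilities with a finite 0.0."""
    neg_inf = torch.isneginf(logprobs)
    # Only report -inf values at positions the mask marks as valid (assumed: nonzero = valid).
    n_bad = int((neg_inf & mask.bool()).sum())
    if n_bad > 0:
        print(f"Warning: {n_bad} -inf values in {logprobs_name}; masking them out")
    # A finite placeholder keeps later differences such as curr - prev from becoming NaN.
    return logprobs.masked_fill(neg_inf, 0.0)


curr = torch.tensor([[-0.5, float("-inf")]])
prev = torch.tensor([[-0.7, float("-inf")]])
mask = torch.ones_like(curr)
print(curr - prev)  # second entry is nan because -inf - (-inf) is undefined
clean_diff = mask_out_neg_inf_sketch(curr, mask, "curr_logprobs") - mask_out_neg_inf_sketch(
    prev, mask, "prev_logprobs"
)
```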

nemo_rl.algorithms.utils.set_seed(seed: int) → None

Sets the seed for Python, NumPy, and PyTorch.
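Seeding all three libraries typically comes down to one call per random number generator. `set_seed_sketch` is a hypothetical illustration; it does not cover extra determinism settings such as cuDNN flags, which the library may or may not set.

```python
import random

import numpy as np
import torch


def set_seed_sketch(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs; a sketch, not the library code."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU generator (and CUDA devices in current PyTorch)
    torch.cuda.manual_seed_all(seed)  # explicit CUDA seeding; safe when CUDA is unavailable


set_seed_sketch(123)
first = (random.random(), float(np.random.rand()), torch.rand(1).item())
set_seed_sketch(123)
second = (random.random(), float(np.random.rand()), torch.rand(1).item())
```

Re-seeding with the same value replays the same draws from every generator, so `first == second`.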

nemo_rl.algorithms.utils.get_tokenizer(
tokenizer_config: nemo_rl.models.policy.TokenizerConfig,
get_processor: bool = False,
) → transformers.PreTrainedTokenizerBase

Gets the tokenizer and sets the pad token to the EOS token if it is not already set.

This function initializes a tokenizer from the Hugging Face transformers library and configures it with appropriate chat templates and padding tokens.

Parameters:
  • tokenizer_config – A dictionary containing the tokenizer configuration.

    Required keys:
      - name: The name or path of the pretrained tokenizer.

    Optional keys:
      - chat_template: The chat template to use. Can be:
          - None: Uses a passthrough template that just returns the message content.
          - “default”: Uses the tokenizer’s default template.
          - A custom Jinja2 template string.
        If not specified, the tokenizer’s default template is used.

  • get_processor – Whether to return a processor (via AutoProcessor) instead of a tokenizer.

Returns:

The configured tokenizer instance (or processor, if get_processor is True).

Return type:

PreTrainedTokenizerBase
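Hugging Face chat templates are Jinja2 templates rendered over a `messages` list. The snippet below renders two templates with plain `jinja2` to show what each `chat_template` option produces.

The first, `PASSTHROUGH`, is an assumed passthrough template. It is consistent with the passthrough example below, but it is not necessarily the exact string NeMo RL uses. The second, `CUSTOM`, is the custom template from the examples below. Transformers renders templates in its own sandboxed Jinja2 environment, but for templates this simple the output is the same.

```python
from jinja2 import Template

# Assumed shape of a passthrough template: concatenate message contents, no role markers.
PASSTHROUGH = "{% for message in messages %}{{ message['content'] }}{% endfor %}"
# The custom template used in the examples below.
CUSTOM = "{% for message in messages %}{{ ' START: ' + message['content'] + ' END.' }}{% endfor %}"

messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Hello!"},
]

print(Template(PASSTHROUGH).render(messages=messages))  # You are a helpful AI assistant.Hello!
print(Template(CUSTOM).render(messages=messages))
```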

Examples:

>>> from transformers import AutoTokenizer
>>> from nemo_rl.algorithms.utils import get_tokenizer
>>> # not specifying a chat template uses the tokenizer's default
>>> config = {"name": "meta-llama/Llama-3.2-1B-Instruct"}
>>> tokenizer = get_tokenizer(config)
No chat template provided, using tokenizer's default
>>> messages = [
...     {"role": "system", "content": "You are a helpful AI assistant."},
...     {"role": "user", "content": "Hello!"}
... ]
>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
>>> assert formatted == AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").apply_chat_template(messages, tokenize=False)

>>> # Using a passthrough template
>>> config = {
...     "name": "meta-llama/Llama-3.2-1B-Instruct",
...     "chat_template": None
... }
>>> tokenizer = get_tokenizer(config)
Using passthrough chat template
>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
>>> assert formatted == "".join(msg["content"] for msg in messages)

>>> # Using a custom template
>>> config = {
...     "name": "meta-llama/Llama-3.2-1B-Instruct",
...     "chat_template": "{% for message in messages %}{{ ' START: ' + message['content'] + ' END.' }}{% endfor %}"
... }
>>> tokenizer = get_tokenizer(config)
Using custom chat template
>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
>>> assert formatted == " START: You are a helpful AI assistant. END. START: Hello! END."

>>> # Requesting a processor (for multimodal models like Qwen-VL)
>>> config = {"name": "Qwen/Qwen2.5-VL-3B-Instruct"}
>>> processor = get_tokenizer(config, get_processor=True)
No chat template provided, using tokenizer's default
>>> messages = [
...     {"role": "system", "content": "You are a helpful AI assistant."},
...     {"role": "user", "content": "Hello!"}
... ]
>>> formatted = processor.tokenizer.apply_chat_template(messages, tokenize=False)
>>> assert formatted == AutoTokenizer.from_pretrained(
...     "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True
... ).apply_chat_template(messages, tokenize=False)
>>> assert processor.pad_token_id == processor.tokenizer.pad_token_id

nemo_rl.algorithms.utils.maybe_pad_last_batch(batch: dict, dp_size: int, mbs: int) → dict

Pads the given batch so that its size is divisible by (mbs * dp_size).

Parameters:
  • batch (dict) – The batch to pad.

  • dp_size (int) – Data parallel size.

  • mbs (int) – Micro batch size.

Returns:

The padded batch.

Return type:

dict
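The target size is the smallest multiple of mbs * dp_size that is at least the batch size. For example, with 10 samples, dp_size = 4, and mbs = 2, the multiple is 8 and the batch grows to 16, so 6 filler rows are added.

The sketch below covers only that arithmetic and is not the library's padding scheme. `pad_last_batch_sketch` is hypothetical, and three details are assumptions:

  • Filler rows copy the last real row.
  • Every value is a tensor with the batch on dim 0.
  • A key named `sample_mask` exists and is zeroed for filler rows so they do not affect the loss.

```python
import math

import torch


def pad_last_batch_sketch(batch: dict, dp_size: int, mbs: int) -> dict:
    """Pad every tensor in `batch` along dim 0 to a multiple of mbs * dp_size."""
    multiple = mbs * dp_size
    n = next(iter(batch.values())).shape[0]
    target = math.ceil(n / multiple) * multiple
    pad = target - n
    if pad == 0:
        return batch
    padded = {}
    for key, value in batch.items():
        # Assumption: filler rows repeat the last real row.
        filler = value[-1:].expand(pad, *value.shape[1:])
        if key == "sample_mask":
            # Hypothetical key: zero the mask so padded rows do not contribute to the loss.
            filler = torch.zeros_like(filler)
        padded[key] = torch.cat([value, filler], dim=0)
    return padded


batch = {"input_ids": torch.arange(20).reshape(10, 2), "sample_mask": torch.ones(10)}
out = pad_last_batch_sketch(batch, dp_size=4, mbs=2)
print(out["input_ids"].shape)  # torch.Size([16, 2])
```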

nemo_rl.algorithms.utils.print_performance_metrics(
train_results: dict[str, float],
metrics: dict[str, Any],
timing_metrics: dict[str, float],
master_config: dict,
) → dict[str, float]

Print performance metrics for GRPO.

nemo_rl.algorithms.utils.log_generation_metrics_to_wandb(
generation_logger_metrics: dict[str, dict[int, list[Any]]],
step: int,
timeline_interval: float,
logger: nemo_rl.utils.logger.Logger,
) → None

Log generation metrics to wandb.

Parameters:
  • generation_logger_metrics – Dictionary of generation logger metrics

  • step – Global step value

  • timeline_interval – Interval between timeline points (in seconds)

  • logger – Logger instance