nemo_rl.algorithms.utils
Module Contents
Functions
| Function | Description |
| --- | --- |
| calculate_kl_penalty_joschu2020 | Calculates a per-token estimate of the KL divergence between two log_probs. |
| calculate_baseline_and_std_per_prompt | Computes a baseline for each (prompt, response) pair in the batch. |
| masked_mean | Computes the mean of a microbatch, using a global statistic as the normalization factor. |
| set_seed | Sets the seed for Python, NumPy, and PyTorch. |
| get_tokenizer | Get the tokenizer and set the pad token to the eos token if it is not already set. |
API
- nemo_rl.algorithms.utils.calculate_kl_penalty_joschu2020(
- logprobs_policy: torch.Tensor,
- logprobs_reference: torch.Tensor,
- )
Calculates a per-token estimate of the KL divergence between two log_probs. From Schulman (2020), "Approximating KL Divergence" (http://joschu.net/blog/kl-approx.html); the estimate is always positive.
- Parameters:
logprobs_policy – torch.Tensor of shape (b, s). Per-token log probabilities under the policy.
logprobs_reference – torch.Tensor of shape (b, s). Per-token log probabilities under the reference model.
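A minimal sketch of the estimator this function is named for (the "k3" estimator from Schulman's post); the actual implementation may differ in details:

```python
import torch

def kl_penalty_k3(
    logprobs_policy: torch.Tensor,     # (b, s) log-probs under the policy
    logprobs_reference: torch.Tensor,  # (b, s) log-probs under the reference
) -> torch.Tensor:
    # k3 estimator: KL(policy || reference) is estimated per token as
    # exp(r) - 1 - r, with r = logprobs_reference - logprobs_policy.
    # Non-negative for every sample, since exp(r) >= 1 + r.
    r = logprobs_reference - logprobs_policy
    return torch.exp(r) - 1 - r
```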
- nemo_rl.algorithms.utils.calculate_baseline_and_std_per_prompt(
- prompts: torch.Tensor,
- rewards: torch.Tensor,
- valid_mask: torch.Tensor,
- leave_one_out_baseline: bool = True,
- )
Computes a baseline for each (prompt, response) pair in the batch. The same baseline is calculated for each prompt. Samples set to 0 in 'valid_mask' are not included in the baseline calculation.
- Parameters:
prompts – torch.Tensor of shape (b, s). Prompts the model used. May be on any device.
rewards – torch.Tensor of shape (b,). Float-valued rewards. May be on any device.
valid_mask – torch.Tensor of shape (b,). Vector of 0s and 1s, where 0 means ignore the sample and 1 means keep it.
leave_one_out_baseline – bool. If True, compute an unbiased baseline by leaving out the sample that the baseline is for, as in RLOO (https://arxiv.org/abs/2402.14740).
- Returns:
torch.Tensor of shape (b,) containing the baselines, on the same device as 'rewards'.
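A minimal sketch of the leave-one-out idea for a single group of responses to one prompt. This is illustrative only and assumes all samples are valid; the library function additionally handles 'valid_mask' and grouping rows by prompt:

```python
import torch

def rloo_baseline(rewards: torch.Tensor) -> torch.Tensor:
    # For n rewards sampled for the same prompt, the leave-one-out
    # baseline for sample i is the mean of the other n - 1 rewards:
    #   b_i = (sum_j r_j - r_i) / (n - 1)
    n = rewards.numel()
    return (rewards.sum() - rewards) / (n - 1)

rewards = torch.tensor([1.0, 0.0, 0.5, 0.5])
advantages = rewards - rloo_baseline(rewards)  # zero-mean across the group
```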
- nemo_rl.algorithms.utils.masked_mean(
- values,
- mask,
- dim: Optional[int] = None,
- global_normalization_factor: Optional[torch.Tensor | float] = None,
- )
Computes the mean of a microbatch, using a global statistic as the normalization factor.
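A sketch of the usual semantics, assuming that a provided global factor replaces the local mask sum in the denominator (which keeps per-token means consistent when a batch is split into microbatches); names and details here are illustrative:

```python
import torch
from typing import Optional, Union

def masked_mean(
    values: torch.Tensor,
    mask: torch.Tensor,
    dim: Optional[int] = None,
    global_normalization_factor: Optional[Union[torch.Tensor, float]] = None,
) -> torch.Tensor:
    # Sum only the masked-in values.
    masked_sum = (values * mask).sum() if dim is None else (values * mask).sum(dim=dim)
    if global_normalization_factor is not None:
        # Divide by the global statistic, e.g. the number of valid tokens
        # across the whole batch rather than this microbatch alone.
        return masked_sum / global_normalization_factor
    # Fall back to an ordinary masked mean over the local mask.
    return masked_sum / (mask.sum() if dim is None else mask.sum(dim=dim))
```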
- nemo_rl.algorithms.utils.get_tokenizer(
- tokenizer_config: nemo_rl.models.policy.TokenizerConfig,
- )
Get the tokenizer and set the pad token to the eos token if it is not already set.
This function initializes a tokenizer from the Hugging Face transformers library and configures it with appropriate chat templates and padding tokens.
- Parameters:
tokenizer_config – A dictionary containing the tokenizer configuration. Required keys:
  - name: The name or path of the pretrained tokenizer.
  Optional keys:
  - chat_template: The chat template to use. Can be:
    - None: uses a passthrough template that just returns the message content.
    - "default": uses the tokenizer's default template.
    - A custom jinja2 template string.
  If chat_template is not specified, the tokenizer's default template is used.
- Returns:
The configured tokenizer instance
- Return type:
AutoTokenizer
Examples
>>> from transformers import AutoTokenizer
>>> from nemo_rl.algorithms.utils import get_tokenizer
>>> # not specifying a chat template uses the tokenizer's default
>>> config = {"name": "meta-llama/Llama-3.2-1B-Instruct"}
>>> tokenizer = get_tokenizer(config)
No chat template provided, using tokenizer's default
>>> messages = [
...     {"role": "system", "content": "You are a helpful AI assistant."},
...     {"role": "user", "content": "Hello!"}
... ]
>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
>>> assert formatted == AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").apply_chat_template(messages, tokenize=False)

>>> # Using a passthrough template
>>> config = {
...     "name": "meta-llama/Llama-3.2-1B-Instruct",
...     "chat_template": None
... }
>>> tokenizer = get_tokenizer(config)
Using passthrough chat template
>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
>>> assert formatted == "".join(msg["content"] for msg in messages)

>>> # Using a custom template
>>> config = {
...     "name": "meta-llama/Llama-3.2-1B-Instruct",
...     "chat_template": "{% for message in messages %}{{ ' START: ' + message['content'] + ' END.' }}{% endfor %}"
... }
>>> tokenizer = get_tokenizer(config)
Using custom chat template
>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
>>> assert formatted == " START: You are a helpful AI assistant. END. START: Hello! END."