nemo_rl.algorithms.loss_functions#

Module Contents#

Classes#

LossType

ClippedPGLossConfig

ClippedPGLossDataDict

Required keys for the Clipped Policy Gradient loss function.

ClippedPGLossFn

Generalized Clipped Policy Gradient loss function w/ KL regularization.

NLLLoss

Negative Log Likelihood Loss function.

DPOLossConfig

DPOLossDataDict

Required keys for the DPO loss function.

DPOLossFn

Direct Preference Optimization (DPO) loss function.

API#

class nemo_rl.algorithms.loss_functions.LossType(*args, **kwds)[source]#

Bases: enum.Enum

TOKEN_LEVEL#

‘token_level’

SEQUENCE_LEVEL#

‘sequence_level’
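The two loss types differ only in how per-token losses are normalized. A minimal sketch of the distinction (illustrative only, not the library's implementation; the masking convention is an assumption):

```python
import torch

per_token_loss = torch.tensor([[0.5, 0.2, 0.1],
                               [0.4, 0.3, 0.0]])
token_mask = torch.tensor([[1.0, 1.0, 0.0],
                           [1.0, 1.0, 1.0]])

# token_level: average over all valid tokens in the batch.
token_level = (per_token_loss * token_mask).sum() / token_mask.sum()

# sequence_level: average within each sequence first, then across sequences.
per_sequence = (per_token_loss * token_mask).sum(-1) / token_mask.sum(-1)
sequence_level = per_sequence.mean()
```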

class nemo_rl.algorithms.loss_functions.ClippedPGLossConfig[source]#

Bases: typing.TypedDict

reference_policy_kl_penalty: float#

None

ratio_clip_min: float#

None

ratio_clip_max: float#

None

ratio_clip_c: float#

None

use_on_policy_kl_approximation: bool#

None

use_importance_sampling_correction: bool#

None

token_level_loss: bool#

None
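A minimal sketch of constructing this config as a plain dict (the TypedDict form); the values are illustrative, not recommendations, and optional keys such as disable_ppo_ratio mentioned below may also apply:

```python
from nemo_rl.algorithms.loss_functions import ClippedPGLossConfig, ClippedPGLossFn

cfg: ClippedPGLossConfig = {
    "reference_policy_kl_penalty": 0.01,      # β in the formulas below
    "ratio_clip_min": 0.2,                    # lower clip bound (ε)
    "ratio_clip_max": 0.2,                    # upper clip bound (ε)
    "ratio_clip_c": 3.0,                      # dual-clip constant, must be > 1
    "use_on_policy_kl_approximation": False,
    "use_importance_sampling_correction": False,
    "token_level_loss": True,
}
loss_fn = ClippedPGLossFn(cfg)
```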

class nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict[source]#

Bases: typing.TypedDict

Required keys for the Clipped Policy Gradient loss function.

Initialization

Initialize self. See help(type(self)) for accurate signature.

input_ids: torch.Tensor#

None

advantages: torch.Tensor#

None

prev_logprobs: torch.Tensor#

None

generation_logprobs: torch.Tensor#

None

reference_policy_logprobs: torch.Tensor#

None

token_mask: torch.Tensor#

None

sample_mask: torch.Tensor#

None

__extra__: Any#

None
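The required keys are per-token tensors over the sequence dimension plus a per-sample mask. A hedged sketch of the expected contents (the shapes are assumptions for illustration; in training these fields are carried in a BatchedDataDict):

```python
import torch

batch_size, seq_len, vocab = 2, 8, 32000  # illustrative sizes

batch = {
    "input_ids": torch.randint(0, vocab, (batch_size, seq_len)),
    "advantages": torch.randn(batch_size, seq_len),
    "prev_logprobs": torch.randn(batch_size, seq_len),
    "generation_logprobs": torch.randn(batch_size, seq_len),
    "reference_policy_logprobs": torch.randn(batch_size, seq_len),
    "token_mask": torch.ones(batch_size, seq_len),
    "sample_mask": torch.ones(batch_size),
}
```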

class nemo_rl.algorithms.loss_functions.ClippedPGLossFn(
cfg: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig,
)[source]#

Bases: nemo_rl.algorithms.interfaces.LossFunction

Generalized Clipped Policy Gradient loss function w/ KL regularization.

This implements:

  • PPO (Clipped) - https://arxiv.org/abs/1707.06347

  • GRPO - https://arxiv.org/abs/2402.03300

  • REINFORCE/RLOO (set disable_ppo_ratio = True; ratio_clip_min/ratio_clip_max are ignored) - https://arxiv.org/abs/2402.14740

Formula: L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref)

where:

  • r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio

  • A_t is the advantage estimate

  • ε is the clip parameter (ratio_clip_min/ratio_clip_max)

    • As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.)

      • ratio_clip_min: minimum value for the clip parameter

      • ratio_clip_max: maximum value for the clip parameter

  • β is the KL penalty coefficient (reference_policy_kl_penalty)

  • KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.)

For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref)

Also supports “Dual-Clipping” from https://arxiv.org/pdf/1912.09729, which imposes an additional upper bound on the probability ratio when advantages are negative: when r_t(θ) * A_t << 0, the term is clipped to c * A_t, preventing excessive policy updates. When A_t < 0, the loss function is modified to: L(θ) = E_t [ max(min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t), c * A_t) ] - β * KL(π_θ || π_ref)

where:

  • c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is usually set to 3 in practice.

Due to potential numerical instability, we cast the logits to float32 before computing the loss.
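A minimal sketch of the clipped surrogate with dual-clipping as described above (illustrative only, not the library's implementation; sign conventions, the KL penalty, and masking/normalization are handled by the actual loss function):

```python
import torch

def clipped_pg_surrogate(curr_logprobs, prev_logprobs, advantages,
                         clip_min=0.2, clip_max=0.2, clip_c=3.0):
    # r_t(θ) = π_θ / π_θ_old, computed from log-probabilities.
    ratios = torch.exp(curr_logprobs - prev_logprobs)
    unclipped = ratios * advantages
    clipped = torch.clamp(ratios, 1.0 - clip_min, 1.0 + clip_max) * advantages
    surrogate = torch.minimum(unclipped, clipped)

    # Dual-clipping: when A_t < 0, bound the term from below by c * A_t so
    # that very large ratios cannot dominate the update.
    dual_clipped = torch.maximum(surrogate, clip_c * advantages)
    surrogate = torch.where(advantages < 0, dual_clipped, surrogate)

    # Negate the (to-be-maximized) surrogate to obtain a per-token loss.
    return -surrogate
```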

Initialization

__call__(
next_token_logits: torch.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict],
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
) → Tuple[torch.Tensor, dict][source]#

Clipped Policy Gradient RL loss function.

class nemo_rl.algorithms.loss_functions.NLLLoss[source]#

Bases: nemo_rl.algorithms.interfaces.LossFunction

Negative Log Likelihood Loss function.

loss_type#

None

__call__(
next_token_logits: torch.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
global_valid_seqs: torch.Tensor | None,
global_valid_toks: torch.Tensor,
dpo_loss: bool = False,
dpo_average_log_probs: bool = False,
) → Tuple[torch.Tensor, dict][source]#
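Conceptually, this computes a masked next-token negative log likelihood from the logits. A short sketch of that quantity (shapes and key names follow the data dicts above; this is not the library code):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab = 2, 6, 100  # illustrative sizes
logits = torch.randn(batch_size, seq_len, vocab)
input_ids = torch.randint(0, vocab, (batch_size, seq_len))
token_mask = torch.ones(batch_size, seq_len)

# Logits at position t score the token at position t + 1.
logprobs = F.log_softmax(logits[:, :-1].float(), dim=-1)
targets = input_ids[:, 1:]
token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

mask = token_mask[:, 1:]
nll = -(token_logprobs * mask).sum() / mask.sum()
```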
class nemo_rl.algorithms.loss_functions.DPOLossConfig[source]#

Bases: typing.TypedDict

reference_policy_kl_penalty: float#

None

preference_loss_weight: float#

1.0

sft_loss_weight: float#

0.0

preference_average_log_probs: bool#

False

sft_average_log_probs: bool#

False
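A minimal sketch of a DPO config dict using the defaults shown above (the β value is illustrative):

```python
from nemo_rl.algorithms.loss_functions import DPOLossConfig, DPOLossFn

cfg: DPOLossConfig = {
    "reference_policy_kl_penalty": 0.1,   # β, illustrative
    "preference_loss_weight": 1.0,        # w_p
    "sft_loss_weight": 0.0,               # w_s
    "preference_average_log_probs": False,
    "sft_average_log_probs": False,
}
dpo_loss_fn = DPOLossFn(cfg)
```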

class nemo_rl.algorithms.loss_functions.DPOLossDataDict[source]#

Bases: typing.TypedDict

Required keys for the DPO loss function.

Initialization

Initialize self. See help(type(self)) for accurate signature.

input_ids: torch.Tensor#

None

reference_policy_logprobs: torch.Tensor#

None

token_mask: torch.Tensor#

None

sample_mask: torch.Tensor#

None

class nemo_rl.algorithms.loss_functions.DPOLossFn(cfg: nemo_rl.algorithms.loss_functions.DPOLossConfig)[source]#

Bases: nemo_rl.algorithms.interfaces.LossFunction

Direct Preference Optimization (DPO) loss function.

This loss function implements the DPO algorithm as described in: “Direct Preference Optimization: Your Language Model is Secretly a Reward Model” (https://arxiv.org/abs/2305.18290)

The loss combines two main components:

  1. Preference Loss: Optimizes the model to prefer chosen responses over rejected ones

  2. SFT Loss (optional): Auxiliary supervised fine-tuning loss on chosen responses

The total loss is computed as: L(θ) = w_p * L_pref(θ) + w_s * L_sft(θ)

where:

  • w_p is the preference_loss_weight

  • w_s is the sft_loss_weight

  • L_pref(θ) is the preference loss term

  • L_sft(θ) is the supervised fine-tuning loss term

The preference loss term is computed as: L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))]

where:

  • σ is the sigmoid function

  • β is the reference_policy_kl_penalty

  • r_chosen and r_rejected are the rewards for chosen and rejected responses

  • The rewards are computed as the sum of log probability differences between the current policy and reference policy

If preference_average_log_probs is True, the rewards are averaged over tokens: r = (1/n) * Σ_t (log π_θ(a_t|s_t) - log π_ref(a_t|s_t))

Otherwise, the rewards are summed over tokens.

The SFT loss term is a standard negative log likelihood loss on the chosen responses. If sft_average_log_probs is True, the loss is averaged over tokens.
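A minimal sketch of the preference-loss term defined above, starting from per-sequence (summed or averaged) log probabilities; this is illustrative and omits the SFT term, masking, and the library's batching conventions:

```python
import torch
import torch.nn.functional as F

def dpo_preference_loss(policy_chosen_logps, policy_rejected_logps,
                        ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Rewards as log-probability differences against the reference policy.
    r_chosen = policy_chosen_logps - ref_chosen_logps
    r_rejected = policy_rejected_logps - ref_rejected_logps

    # -log σ(β (r_chosen - r_rejected)), averaged over the batch.
    loss = -F.logsigmoid(beta * (r_chosen - r_rejected)).mean()
    accuracy = (r_chosen > r_rejected).float().mean()
    return loss, accuracy
```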

Parameters:

cfg (DPOLossConfig) –

Configuration dictionary containing:

  • reference_policy_kl_penalty (float): Strength of the KL penalty term (β)

  • preference_loss_weight (float): Weight for the preference loss term (w_p)

  • sft_loss_weight (float): Weight for the SFT loss term (w_s)

  • preference_average_log_probs (bool): Whether to average log probs across tokens in preference loss

  • sft_average_log_probs (bool): Whether to average log probs across tokens in SFT loss

Returns:

A tuple containing:

  • The total loss value

  • A dictionary with metrics including:

    • loss: Total loss value

    • sft_loss: SFT loss component

    • preference_loss: Preference loss component

    • accuracy: Fraction of examples where the chosen response has a higher reward

Return type:

Tuple[torch.Tensor, dict]

Initialization

split_output_tensor(tensor: torch.Tensor)[source]#
preference_loss(
next_token_logits: torch.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict],
global_valid_seqs: torch.Tensor,
) → torch.Tensor[source]#
__call__(
next_token_logits: torch.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict],
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor | None,
) → Tuple[torch.Tensor, dict][source]#