nemo_rl.algorithms.loss_functions#

Module Contents#

Classes#

ClippedPGLossConfig

ClippedPGLossDataDict

Required keys for the Clipped Policy Gradient loss function.

ClippedPGLossFn

Generalized Clipped Policy Gradient loss function w/ KL regularization.

NLLLoss

Negative Log Likelihood Loss function.

DPOLossConfig

DPOLossDataDict

Required keys for the DPO loss function.

DPOLossFn

Direct Preference Optimization (DPO) loss function.

SequencePackingLossWrapper

Data#

API#

nemo_rl.algorithms.loss_functions.Tensor#

‘TypeVar(…)’

class nemo_rl.algorithms.loss_functions.ClippedPGLossConfig[source]#

Bases: typing.TypedDict

reference_policy_kl_penalty: float#

None

ratio_clip_min: float#

None

ratio_clip_max: float#

None

ratio_clip_c: float#

None

use_on_policy_kl_approximation: bool#

None

use_importance_sampling_correction: bool#

None

token_level_loss: bool#

None
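
A minimal sketch of populating this config as a plain dict, assuming the keys listed above are the full set (values are illustrative, not recommendations):

    from nemo_rl.algorithms.loss_functions import ClippedPGLossConfig

    # Illustrative values only; tune for your setup.
    loss_cfg: ClippedPGLossConfig = {
        "reference_policy_kl_penalty": 0.01,
        "ratio_clip_min": 0.2,
        "ratio_clip_max": 0.2,  # equal to ratio_clip_min for standard PPO/GRPO clipping
        "ratio_clip_c": 3.0,    # dual-clip bound; must be > 1
        "use_on_policy_kl_approximation": False,
        "use_importance_sampling_correction": False,
        "token_level_loss": True,
    }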

class nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict[source]#

Bases: typing.TypedDict

Required keys for the Clipped Policy Gradient loss function.


input_ids: torch.Tensor#

None

advantages: torch.Tensor#

None

prev_logprobs: torch.Tensor#

None

generation_logprobs: torch.Tensor#

None

reference_policy_logprobs: torch.Tensor#

None

token_mask: torch.Tensor#

None

sample_mask: torch.Tensor#

None

__extra__: Any#

None
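
A sketch of assembling the required keys into a BatchedDataDict. The shapes are assumptions for illustration: per-token tensors are [batch, seq_len] and sample_mask is [batch]; BatchedDataDict is assumed to accept a plain mapping.

    import torch
    from nemo_rl.distributed.batched_data_dict import BatchedDataDict

    batch, seq_len = 2, 8
    data = BatchedDataDict(
        {
            "input_ids": torch.randint(0, 1000, (batch, seq_len)),
            "advantages": torch.randn(batch, seq_len),
            "prev_logprobs": torch.randn(batch, seq_len),
            "generation_logprobs": torch.randn(batch, seq_len),
            "reference_policy_logprobs": torch.randn(batch, seq_len),
            "token_mask": torch.ones(batch, seq_len),
            "sample_mask": torch.ones(batch),
        }
    )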

class nemo_rl.algorithms.loss_functions.ClippedPGLossFn(
cfg: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig,
)[source]#

Bases: nemo_rl.algorithms.interfaces.LossFunction

Generalized Clipped Policy Gradient loss function w/ KL regularization.

This implements:

  • PPO (Clipped) - https://arxiv.org/abs/1707.06347

  • GRPO - https://arxiv.org/abs/2402.03300

  • REINFORCE/RLOO (set disable_ppo_ratio = True, which ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740

Formula: L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref)

where:

  • r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio

  • A_t is the advantage estimate

  • ε is the clip parameter (ratio_clip_min/ratio_clip_max)

    • As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.)

      • ratio_clip_min: minimum value for the clip parameter

      • ratio_clip_max: maximum value for the clip parameter

  • β is the KL penalty coefficient (reference_policy_kl_penalty)

  • KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.)

For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref)

Also supports “Dual-Clipping” from https://arxiv.org/pdf/1912.09729, which imposes an additional upper bound on the probability ratio when advantages are negative, preventing excessive policy updates: when r_t(θ) * A_t << 0, the term is clipped to c * A_t. When A_t < 0, the loss function becomes: L(θ) = E_t [ max(min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t), c * A_t) ] - β * KL(π_θ || π_ref)

where:

  • c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is usually set to 3 empirically.

Due to potential numerical instability, we cast the logits to float32 before computing the loss.
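
As a reference for the formulas above, here is a self-contained sketch of the clipped objective with dual-clipping in plain PyTorch. It is an illustration, not this class's implementation: masking, the KL penalty, importance-sampling correction, and distributed reductions are omitted.

    import torch

    def clipped_pg_loss(curr_logprobs, prev_logprobs, advantages,
                        clip_min=0.2, clip_max=0.2, clip_c=3.0):
        # r_t(θ) = π_θ / π_θ_old, computed from logprobs for stability.
        ratio = torch.exp(curr_logprobs - prev_logprobs)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1.0 - clip_min, 1.0 + clip_max) * advantages
        objective = torch.minimum(unclipped, clipped)
        # Dual-clip: additional bound c * A_t applied only when A_t < 0.
        dual_clipped = torch.maximum(objective, clip_c * advantages)
        objective = torch.where(advantages < 0, dual_clipped, objective)
        # The objective is maximized, so the loss is its negation.
        return -objective.mean()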

Initialization

__call__(
next_token_logits: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict],
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) tuple[torch.Tensor, dict][source]#

Clipped Policy Gradient RL loss function.
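
A hedged usage sketch, continuing the config and data-dict examples above; the way global_valid_seqs and global_valid_toks are computed here (counts of valid sequences and tokens) is an assumption for illustration.

    import torch
    from nemo_rl.algorithms.loss_functions import ClippedPGLossFn

    loss_fn = ClippedPGLossFn(loss_cfg)  # loss_cfg from the config example above
    vocab_size = 32000
    next_token_logits = torch.randn(batch, seq_len, vocab_size)  # policy model output
    loss, metrics = loss_fn(
        next_token_logits,
        data,  # BatchedDataDict from the data-dict example above
        global_valid_seqs=data["sample_mask"].sum(),
        global_valid_toks=(data["token_mask"] * data["sample_mask"].unsqueeze(-1)).sum(),
    )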

class nemo_rl.algorithms.loss_functions.NLLLoss[source]#

Bases: nemo_rl.algorithms.interfaces.LossFunction

Negative Log Likelihood Loss function.

loss_type#

None

__call__(
next_token_logits: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor | None,
global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
dpo_loss: bool = False,
dpo_average_log_probs: bool = False,
) tuple[torch.Tensor, dict[str, Any]][source]#
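
For reference, a minimal sketch of masked next-token negative log likelihood in plain PyTorch. This is an illustration, not this class's implementation, which additionally handles vocab/context parallelism and the DPO-specific flags.

    import torch
    import torch.nn.functional as F

    def masked_nll(next_token_logits, input_ids, token_mask):
        # next_token_logits[:, t] predicts input_ids[:, t + 1].
        logits = next_token_logits[:, :-1].float()
        targets = input_ids[:, 1:]
        mask = token_mask[:, 1:]
        logprobs = F.log_softmax(logits, dim=-1)
        token_nll = -logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        return (token_nll * mask).sum() / mask.sum().clamp(min=1)
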
class nemo_rl.algorithms.loss_functions.DPOLossConfig[source]#

Bases: typing.TypedDict

reference_policy_kl_penalty: float#

None

preference_loss_weight: float#

None

sft_loss_weight: float#

None

preference_average_log_probs: bool#

None

sft_average_log_probs: bool#

None
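
A minimal sketch of populating this config as a plain dict, assuming the keys listed above are the full set (values are illustrative only):

    from nemo_rl.algorithms.loss_functions import DPOLossConfig

    # Illustrative values only.
    dpo_cfg: DPOLossConfig = {
        "reference_policy_kl_penalty": 0.1,  # β in the preference loss
        "preference_loss_weight": 1.0,
        "sft_loss_weight": 0.0,
        "preference_average_log_probs": False,
        "sft_average_log_probs": False,
    }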

class nemo_rl.algorithms.loss_functions.DPOLossDataDict[source]#

Bases: typing.TypedDict

Required keys for the DPO loss function.


input_ids: torch.Tensor#

None

reference_policy_logprobs: torch.Tensor#

None

token_mask: torch.Tensor#

None

sample_mask: torch.Tensor#

None
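
A sketch of assembling the required keys into a BatchedDataDict. The shapes and the convention that chosen and rejected responses are packed into one batch are assumptions for illustration.

    import torch
    from nemo_rl.distributed.batched_data_dict import BatchedDataDict

    num_pairs, seq_len = 2, 8
    batch = 2 * num_pairs  # chosen and rejected responses packed together (assumed layout)
    dpo_data = BatchedDataDict(
        {
            "input_ids": torch.randint(0, 1000, (batch, seq_len)),
            "reference_policy_logprobs": torch.randn(batch, seq_len),
            "token_mask": torch.ones(batch, seq_len),
            "sample_mask": torch.ones(batch),
        }
    )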

class nemo_rl.algorithms.loss_functions.DPOLossFn(cfg: nemo_rl.algorithms.loss_functions.DPOLossConfig)[source]#

Bases: nemo_rl.algorithms.interfaces.LossFunction

Direct Preference Optimization (DPO) loss function.

This loss function implements the DPO algorithm as described in: “Direct Preference Optimization: Your Language Model is Secretly a Reward Model” (https://arxiv.org/abs/2305.18290)

The loss combines two main components:

  1. Preference Loss: Optimizes the model to prefer chosen responses over rejected ones

  2. SFT Loss (optional): Auxiliary supervised fine-tuning loss on chosen responses

The total loss is computed as: L(θ) = w_p * L_pref(θ) + w_s * L_sft(θ)

where:

  • w_p is the preference_loss_weight

  • w_s is the sft_loss_weight

  • L_pref(θ) is the preference loss term

  • L_sft(θ) is the supervised fine-tuning loss term

The preference loss term is computed as: L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))]

where:

  • σ is the sigmoid function

  • β is the reference_policy_kl_penalty

  • r_chosen and r_rejected are the rewards for chosen and rejected responses

  • The rewards are computed as the sum of log probability differences between the current policy and reference policy

If preference_average_log_probs is True, the rewards are averaged over tokens: r = (1/n) * Σ_t (log π_θ(a_t|s_t) - log π_ref(a_t|s_t))

Otherwise, the rewards are summed over tokens.

The SFT loss term is a standard negative log likelihood loss on the chosen responses. If sft_average_log_probs is True, the loss is averaged over tokens.
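
A self-contained sketch of the preference term above in plain PyTorch (an illustration, not this class's implementation; token masking and distributed reductions are omitted):

    import torch
    import torch.nn.functional as F

    def dpo_preference_loss(policy_lp_chosen, ref_lp_chosen,
                            policy_lp_rejected, ref_lp_rejected,
                            beta=0.1, average_log_probs=False):
        # Per-response reward: sum (or mean) of per-token log-prob differences
        # between the current policy and the reference policy.
        def reward(policy_lp, ref_lp):
            diff = policy_lp - ref_lp
            return diff.mean(dim=-1) if average_log_probs else diff.sum(dim=-1)

        r_chosen = reward(policy_lp_chosen, ref_lp_chosen)
        r_rejected = reward(policy_lp_rejected, ref_lp_rejected)
        # L_pref = -E[log σ(β * (r_chosen - r_rejected))]
        return -F.logsigmoid(beta * (r_chosen - r_rejected)).mean()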

Parameters:

cfg (DPOLossConfig) –

Configuration dictionary containing:

  • reference_policy_kl_penalty (float): Strength of the KL penalty term (β)

  • preference_loss_weight (float): Weight for the preference loss term (w_p)

  • sft_loss_weight (float): Weight for the SFT loss term (w_s)

  • preference_average_log_probs (bool): Whether to average log probs across tokens in preference loss

  • sft_average_log_probs (bool): Whether to average log probs across tokens in SFT loss

Returns:

A tuple containing:

  • The total loss value

  • A dictionary with metrics including:

    • loss: Total loss value

    • sft_loss: SFT loss component

    • preference_loss: Preference loss component

    • accuracy: Fraction of examples where chosen response has higher reward

Return type:

tuple[torch.Tensor, dict]
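
A hedged usage sketch, continuing the config and data-dict examples above; passing global_valid_toks=None is an assumption that holds when the SFT term is disabled (sft_loss_weight = 0.0).

    import torch
    from nemo_rl.algorithms.loss_functions import DPOLossFn

    dpo_loss_fn = DPOLossFn(dpo_cfg)  # dpo_cfg from the config example above
    vocab_size = 32000
    next_token_logits = torch.randn(batch, seq_len, vocab_size)
    loss, metrics = dpo_loss_fn(
        next_token_logits,
        dpo_data,  # BatchedDataDict from the data-dict example above
        global_valid_seqs=dpo_data["sample_mask"].sum(),
        global_valid_toks=None,
    )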

Initialization

split_output_tensor(
tensor: nemo_rl.algorithms.loss_functions.Tensor,
) tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor][source]#
_preference_loss(
next_token_logits: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict],
global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor][source]#
__call__(
next_token_logits: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict],
global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor,
global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) tuple[torch.Tensor, dict[str, Any]][source]#
class nemo_rl.algorithms.loss_functions.SequencePackingLossWrapper(
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
cu_seqlens_q: nemo_rl.algorithms.loss_functions.Tensor,
cu_seqlens_q_padded: Optional[nemo_rl.algorithms.loss_functions.Tensor] = None,
)[source]#

Initialization

__call__(
next_token_logits: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor | None,
global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) tuple[nemo_rl.algorithms.loss_functions.Tensor, dict[str, Any]][source]#

Wraps a loss function to handle sequence packing by computing the loss one sequence at a time, avoiding excessive padding.
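
Conceptually, cu_seqlens_q holds cumulative sequence lengths, so consecutive entries delimit each packed sequence. Below is a sketch of the idea using a plain dict of per-token tensors (an illustration, not the actual wrapper, which also handles padded cu_seqlens and global normalization; tensor layouts are assumptions).

    import torch

    def per_sequence_losses(loss_fn, next_token_logits, packed_data, cu_seqlens_q):
        # cu_seqlens_q = [0, len_0, len_0 + len_1, ...] delimits packed sequences
        # along the token dimension of the packed batch.
        losses = []
        for i in range(cu_seqlens_q.numel() - 1):
            start, end = int(cu_seqlens_q[i]), int(cu_seqlens_q[i + 1])
            seq_logits = next_token_logits[:, start:end]
            # Slice every per-token tensor to the current sequence's span.
            seq_data = {k: v[:, start:end] for k, v in packed_data.items()}
            losses.append(loss_fn(seq_logits, seq_data, None, None)[0])
        return torch.stack(losses).sum()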