nemo_rl.algorithms.loss_functions#

Module Contents#

Classes#

ClippedPGLossConfig

ClippedPGLossDataDict

Required keys for the Clipped Policy Gradient loss function.

ClippedPGLossFn

Generalized Clipped Policy Gradient loss function with KL regularization.

NLLLoss

Negative Log Likelihood Loss function.

PreferenceLossDataDict

Required keys for the preference loss function.

PreferenceLoss

Preference Loss function.

DPOLossConfig

DPOLossDataDict

Required keys for the DPO loss function.

DPOLossFn

Direct Preference Optimization (DPO) loss function.

SequencePackingLossWrapper

Data#

Tensor

API#

nemo_rl.algorithms.loss_functions.Tensor#

‘TypeVar(…)’

class nemo_rl.algorithms.loss_functions.ClippedPGLossConfig#

Bases: typing.TypedDict

reference_policy_kl_penalty: float#

None

ratio_clip_min: float#

None

ratio_clip_max: float#

None

ratio_clip_c: float#

None

use_on_policy_kl_approximation: bool#

None

use_importance_sampling_correction: bool#

None

token_level_loss: bool#

None
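As a usage sketch, the config can be written as a plain dict (the keys come from the fields above; the numeric values are illustrative assumptions only, not recommended defaults):

```python
from nemo_rl.algorithms.loss_functions import ClippedPGLossConfig, ClippedPGLossFn

# Illustrative values only; tune for your own training setup.
cfg: ClippedPGLossConfig = {
    "reference_policy_kl_penalty": 0.01,   # β in the KL term
    "ratio_clip_min": 0.2,                 # lower clip ε
    "ratio_clip_max": 0.2,                 # upper clip ε (same as min for symmetric clipping)
    "ratio_clip_c": 3.0,                   # dual-clip parameter, must be > 1
    "use_on_policy_kl_approximation": False,
    "use_importance_sampling_correction": False,
    "token_level_loss": True,
}
loss_fn = ClippedPGLossFn(cfg)
```

Setting ratio_clip_min and ratio_clip_max to the same value recovers the standard symmetric PPO/GRPO clip.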

class nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict#

Bases: typing.TypedDict

Required keys for the Clipped Policy Gradient loss function.

Initialization


input_ids: torch.Tensor#

None

advantages: torch.Tensor#

None

prev_logprobs: torch.Tensor#

None

generation_logprobs: torch.Tensor#

None

reference_policy_logprobs: torch.Tensor#

None

token_mask: torch.Tensor#

None

sample_mask: torch.Tensor#

None

__extra__: Any#

None
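For orientation, a toy batch with these keys might look as follows (shapes, dtypes, and the per-token vs. per-sample layout of the masks are assumptions for illustration; additional keys are permitted via __extra__):

```python
import torch

batch_size, seq_len, vocab_size = 2, 8, 32000  # assumed toy shapes

data = {
    "input_ids": torch.randint(0, vocab_size, (batch_size, seq_len)),
    "advantages": torch.randn(batch_size, seq_len),
    "prev_logprobs": torch.randn(batch_size, seq_len),
    "generation_logprobs": torch.randn(batch_size, seq_len),
    "reference_policy_logprobs": torch.randn(batch_size, seq_len),
    "token_mask": torch.ones(batch_size, seq_len),   # 1.0 for tokens that count toward the loss
    "sample_mask": torch.ones(batch_size),           # 1.0 for valid samples
}
```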

class nemo_rl.algorithms.loss_functions.ClippedPGLossFn(
cfg: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig,
)#

Bases: nemo_rl.algorithms.interfaces.LossFunction

Generalized Clipped Policy Gradient loss function with KL regularization.

This implements:

  • PPO (Clipped) - https://arxiv.org/abs/1707.06347

  • GRPO - https://arxiv.org/abs/2402.03300

  • REINFORCE/RLOO (set disable_ppo_ratio = True, which ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740

Formula: L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref)

where:

  • r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio

  • A_t is the advantage estimate

  • ε is the clip parameter (ratio_clip_min/ratio_clip_max)

    • As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), we allow distinct minimum and maximum values for the clip parameter (set both to the same value for PPO/GRPO/etc.):

      • ratio_clip_min: minimum value for the clip parameter

      • ratio_clip_max: maximum value for the clip parameter

  • β is the KL penalty coefficient (reference_policy_kl_penalty)

  • KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.)

For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref)

Also supports “Dual-Clipping” from https://arxiv.org/pdf/1912.09729, which imposes an additional upper bound on the probability ratio when advantages are negative, preventing excessive policy updates: when r_t(θ) * A_t << 0, the term is clipped to c * A_t. The loss function is modified to the following when A_t < 0: L(θ) = E_t [ max(min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t), c * A_t) ] - β * KL(π_θ || π_ref)

where:

  • c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is typically set to 3 in practice.

Due to potential numerical instability, we cast the logits to float32 before computing the loss.
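The per-token objective above can be sketched in plain PyTorch as follows (a simplified illustration, not the class implementation: the KL penalty, token/sample masking, importance-sampling correction, and distributed reductions are omitted):

```python
import torch

def clipped_pg_surrogate_sketch(
    curr_logprobs: torch.Tensor,   # log π_θ(a_t|s_t),     shape [B, T]
    prev_logprobs: torch.Tensor,   # log π_θ_old(a_t|s_t), shape [B, T]
    advantages: torch.Tensor,      # A_t,                  shape [B, T]
    ratio_clip_min: float = 0.2,
    ratio_clip_max: float = 0.2,
    ratio_clip_c: float = 3.0,
) -> torch.Tensor:
    """Per-token clipped surrogate with dual clipping (illustration only)."""
    ratio = torch.exp(curr_logprobs - prev_logprobs)                    # r_t(θ)
    clipped_ratio = torch.clamp(ratio, 1.0 - ratio_clip_min, 1.0 + ratio_clip_max)
    surrogate = torch.minimum(ratio * advantages, clipped_ratio * advantages)
    # Dual clipping: when A_t < 0, additionally bound the term below by c * A_t.
    surrogate = torch.where(
        advantages < 0,
        torch.maximum(surrogate, ratio_clip_c * advantages),
        surrogate,
    )
    # Negated so the surrogate can be minimized as a loss; the sign convention
    # and reductions of the actual class may differ.
    return -surrogate
```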

Initialization

__call__(
next_token_logits: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict],
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) tuple[torch.Tensor, dict]#

Clipped Policy Gradient RL loss function.

class nemo_rl.algorithms.loss_functions.NLLLoss#

Bases: nemo_rl.algorithms.interfaces.LossFunction

Negative Log Likelihood Loss function.

loss_type#

None

__call__(
next_token_logits: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor | None,
global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
dpo_loss: bool = False,
dpo_average_log_probs: bool = False,
) tuple[torch.Tensor, dict[str, Any]]#
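As a rough standalone sketch of masked token-level negative log likelihood (the next-token shift and mask semantics are assumptions; the actual class additionally handles vocab/context parallelism and the DPO-specific flags above):

```python
import torch
import torch.nn.functional as F

def masked_nll_sketch(
    next_token_logits: torch.Tensor,  # [B, T, V]
    input_ids: torch.Tensor,          # [B, T]
    token_mask: torch.Tensor,         # [B, T], 1.0 for tokens that count toward the loss
) -> torch.Tensor:
    # Assumed convention: logits at position t score the token at position t + 1.
    logprobs = F.log_softmax(next_token_logits.float(), dim=-1)
    targets = input_ids[:, 1:]
    token_logprobs = logprobs[:, :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    mask = token_mask[:, 1:]
    return -(token_logprobs * mask).sum() / mask.sum().clamp(min=1)
```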
class nemo_rl.algorithms.loss_functions.PreferenceLossDataDict#

Bases: typing.TypedDict

Required keys for the preference loss function.

Initialization


input_ids: torch.Tensor#

None

token_mask: torch.Tensor#

None

sample_mask: torch.Tensor#

None

class nemo_rl.algorithms.loss_functions.PreferenceLoss#

Bases: nemo_rl.algorithms.interfaces.LossFunction

Preference Loss function.

Optimizes the model to prefer chosen responses over rejected ones.

The preference loss is computed as: L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))]

where:

  • σ is the sigmoid function

  • β is a scaling factor (ex: reference_policy_kl_penalty in DPO)

  • r_chosen and r_rejected are the rewards for chosen and rejected responses

Returns:

A tuple containing:

  • The preference loss value

  • A dictionary with metrics including:

    • loss: Preference loss

    • accuracy: Fraction of examples where the chosen response has a higher reward

Return type:

tuple[torch.Tensor, dict]
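The core computation can be sketched as follows (a simplified standalone illustration; sample masking and global normalization over valid sequences are omitted):

```python
import torch
import torch.nn.functional as F

def preference_loss_sketch(
    r_chosen: torch.Tensor,    # per-sample reward for chosen responses,  shape [B]
    r_rejected: torch.Tensor,  # per-sample reward for rejected responses, shape [B]
    beta: float = 1.0,
):
    """-log σ(β (r_chosen - r_rejected)), averaged over the batch."""
    margin = beta * (r_chosen - r_rejected)
    loss = -F.logsigmoid(margin).mean()
    accuracy = (r_chosen > r_rejected).float().mean()
    return loss, {"loss": loss.item(), "accuracy": accuracy.item()}
```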

Initialization

split_output_tensor(
tensor: nemo_rl.algorithms.loss_functions.Tensor,
) tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor]#
_preference_loss(
rewards: nemo_rl.algorithms.loss_functions.Tensor,
sample_mask: nemo_rl.algorithms.loss_functions.Tensor,
global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor,
beta: float = 1.0,
) tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor]#
__call__(
rewards: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.PreferenceLossDataDict],
global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor,
global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None,
) tuple[torch.Tensor, dict[str, Any]]#
class nemo_rl.algorithms.loss_functions.DPOLossConfig#

Bases: typing.TypedDict

reference_policy_kl_penalty: float#

None

preference_loss_weight: float#

None

sft_loss_weight: float#

None

preference_average_log_probs: bool#

None

sft_average_log_probs: bool#

None
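A usage sketch with placeholder values (not recommended defaults):

```python
from nemo_rl.algorithms.loss_functions import DPOLossConfig, DPOLossFn

cfg: DPOLossConfig = {
    "reference_policy_kl_penalty": 0.05,   # β
    "preference_loss_weight": 1.0,         # w_p
    "sft_loss_weight": 0.0,                # w_s (set > 0 to add the auxiliary SFT term)
    "preference_average_log_probs": False,
    "sft_average_log_probs": False,
}
dpo_loss = DPOLossFn(cfg)
```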

class nemo_rl.algorithms.loss_functions.DPOLossDataDict#

Bases: typing.TypedDict

Required keys for the DPO loss function.

Initialization


input_ids: torch.Tensor#

None

reference_policy_logprobs: torch.Tensor#

None

token_mask: torch.Tensor#

None

sample_mask: torch.Tensor#

None

class nemo_rl.algorithms.loss_functions.DPOLossFn(cfg: nemo_rl.algorithms.loss_functions.DPOLossConfig)#

Bases: nemo_rl.algorithms.loss_functions.PreferenceLoss

Direct Preference Optimization (DPO) loss function.

This loss function implements the DPO algorithm as described in: “Direct Preference Optimization: Your Language Model is Secretly a Reward Model” (https://arxiv.org/abs/2305.18290)

The loss combines two main components:

  1. Preference Loss: Optimizes the model to prefer chosen responses over rejected ones

  2. SFT Loss (optional): Auxiliary supervised fine-tuning loss on chosen responses

The total loss is computed as: L(θ) = w_p * L_pref(θ) + w_s * L_sft(θ)

where:

  • w_p is the preference_loss_weight

  • w_s is the sft_loss_weight

  • L_pref(θ) is the preference loss term

  • L_sft(θ) is the supervised fine-tuning loss term

The preference loss term is computed as: L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))]

where:

  • σ is the sigmoid function

  • β is the reference_policy_kl_penalty

  • r_chosen and r_rejected are the rewards for chosen and rejected responses

  • The rewards are computed as the sum of log probability differences between the current policy and reference policy

If preference_average_log_probs is True, the rewards are averaged over tokens: r = (1/n) * Σ_t (log π_θ(a_t|s_t) - log π_ref(a_t|s_t))

Otherwise, the rewards are summed over tokens.

The SFT loss term is a standard negative log likelihood loss on the chosen responses. If sft_average_log_probs is True, the loss is averaged over tokens.

Parameters:

cfg (DPOLossConfig) –

Configuration dictionary containing:

  • reference_policy_kl_penalty (float): Strength of the KL penalty term (β)

  • preference_loss_weight (float): Weight for the preference loss term (w_p)

  • sft_loss_weight (float): Weight for the SFT loss term (w_s)

  • preference_average_log_probs (bool): Whether to average log probs across tokens in preference loss

  • sft_average_log_probs (bool): Whether to average log probs across tokens in SFT loss

Returns:

A tuple containing:

  • The total loss value

  • A dictionary with metrics including:

    • loss: Total loss value

    • sft_loss: SFT loss component

    • preference_loss: Preference loss component

    • accuracy: Fraction of examples where the chosen response has a higher reward

Return type:

tuple[torch.Tensor, dict]
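The combination of the two terms can be sketched as follows (an illustration of the formulas above; the chosen/rejected splitting, masking, and distributed reductions of the actual implementation are omitted):

```python
import torch.nn.functional as F

def dpo_loss_sketch(
    policy_lp_chosen, policy_lp_rejected,   # [B, T] token log-probs under π_θ
    ref_lp_chosen, ref_lp_rejected,         # [B, T] token log-probs under π_ref
    beta=0.05, w_p=1.0, w_s=0.0,
    preference_average_log_probs=False,
):
    # Rewards: summed (or averaged) log-prob differences vs. the reference policy.
    reduce = (lambda x: x.mean(-1)) if preference_average_log_probs else (lambda x: x.sum(-1))
    r_chosen = reduce(policy_lp_chosen - ref_lp_chosen)
    r_rejected = reduce(policy_lp_rejected - ref_lp_rejected)
    preference_loss = -F.logsigmoid(beta * (r_chosen - r_rejected)).mean()
    # Auxiliary SFT term: NLL of the chosen responses.
    sft_loss = -policy_lp_chosen.sum(-1).mean()
    return w_p * preference_loss + w_s * sft_loss
```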

Initialization

_dpo_loss(
next_token_logits: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict],
global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor]#
__call__(
next_token_logits: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict],
global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor,
global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) tuple[torch.Tensor, dict[str, Any]]#
class nemo_rl.algorithms.loss_functions.SequencePackingLossWrapper(
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
cu_seqlens_q: nemo_rl.algorithms.loss_functions.Tensor,
cu_seqlens_q_padded: Optional[nemo_rl.algorithms.loss_functions.Tensor] = None,
)#

Initialization

__call__(
next_token_logits: nemo_rl.algorithms.loss_functions.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor | None,
global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) tuple[nemo_rl.algorithms.loss_functions.Tensor, dict[str, Any]]#

Wraps a loss function to handle sequence packing by applying the loss to one sequence at a time, avoiding excessive padding.
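Conceptually, the wrapper slices the packed batch back into individual sequences using the cumulative sequence lengths and applies the wrapped loss per sequence. A rough sketch of the slicing idea (not the actual implementation):

```python
import torch

def iter_packed_sequences(packed_logits: torch.Tensor, cu_seqlens_q: torch.Tensor):
    """Yield per-sequence views of a packed [1, total_tokens, vocab] logits tensor."""
    for i in range(cu_seqlens_q.numel() - 1):
        start, end = int(cu_seqlens_q[i]), int(cu_seqlens_q[i + 1])
        yield packed_logits[:, start:end]
```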