nemo_rl.algorithms.loss_functions
#
Module Contents#
Classes#
| Class | Description |
|---|---|
| ClippedPGLossDataDict | Required keys for the Clipped Policy Gradient loss function. |
| ClippedPGLossFn | Generalized Clipped Policy Gradient loss function w/ KL regularization. |
| NLLLoss | Negative Log Likelihood Loss function. |
| DPOLossDataDict | Required keys for the DPO loss function. |
| DPOLossFn | Direct Preference Optimization (DPO) loss function. |
API#
- class nemo_rl.algorithms.loss_functions.LossType(*args, **kwds)[source]#
Bases:
enum.Enum
- TOKEN_LEVEL#
‘token_level’
- SEQUENCE_LEVEL#
‘sequence_level’
- class nemo_rl.algorithms.loss_functions.ClippedPGLossConfig[source]#
Bases:
typing.TypedDict
- reference_policy_kl_penalty: float#
None
- ratio_clip_min: float#
None
- ratio_clip_max: float#
None
- ratio_clip_c: float#
None
- use_on_policy_kl_approximation: bool#
None
- use_importance_sampling_correction: bool#
None
- token_level_loss: bool#
None
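As an illustration only, a plain dict matching these fields might look like the following; the numeric values are placeholders, not recommended defaults from this module.

```python
# Illustrative ClippedPGLossConfig-style dict; values are placeholders,
# not defaults taken from this module.
clipped_pg_loss_config = {
    "reference_policy_kl_penalty": 0.01,      # beta in the KL term
    "ratio_clip_min": 0.2,                    # lower clip epsilon (clip range 1 - eps)
    "ratio_clip_max": 0.2,                    # upper clip epsilon (clip range 1 + eps)
    "ratio_clip_c": 3.0,                      # dual-clip constant c (> 1)
    "use_on_policy_kl_approximation": False,
    "use_importance_sampling_correction": False,
    "token_level_loss": True,                 # token-level vs. sequence-level aggregation
}
```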
- class nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict[source]#
Bases:
typing.TypedDict
Required keys for the Clipped Policy Gradient loss function.
Initialization
- input_ids: torch.Tensor#
None
- advantages: torch.Tensor#
None
- prev_logprobs: torch.Tensor#
None
- generation_logprobs: torch.Tensor#
None
- reference_policy_logprobs: torch.Tensor#
None
- token_mask: torch.Tensor#
None
- sample_mask: torch.Tensor#
None
- __extra__: Any#
None
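For reference, a dummy batch with these keys could be assembled as below; the [batch, seq_len] shapes (and [batch] for sample_mask) are assumptions for illustration, not a specification of the expected layout.

```python
import torch

# Dummy tensors keyed as ClippedPGLossDataDict expects; shapes are assumed
# to be [batch, seq_len] (and [batch] for sample_mask) for illustration.
batch, seq_len, vocab = 2, 8, 32
data = {
    "input_ids": torch.randint(0, vocab, (batch, seq_len)),
    "advantages": torch.randn(batch, seq_len),
    "prev_logprobs": torch.randn(batch, seq_len),
    "generation_logprobs": torch.randn(batch, seq_len),
    "reference_policy_logprobs": torch.randn(batch, seq_len),
    "token_mask": torch.ones(batch, seq_len),
    "sample_mask": torch.ones(batch),
}
```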
- class nemo_rl.algorithms.loss_functions.ClippedPGLossFn(cfg: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig)[source]#
Bases:
nemo_rl.algorithms.interfaces.LossFunction
Generalized Clipped Policy Gradient loss function w/ KL regularization.
This implements:
PPO (Clipped) - https://arxiv.org/abs/1707.06347
GRPO - https://arxiv.org/abs/2402.03300
REINFORCE/RLOO (set disable_ppo_ratio = True, which ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740
Formula: L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref)
where:
r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio
A_t is the advantage estimate
ε is the clip parameter (ratio_clip_min/ratio_clip_max)
As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.)
ratio_clip_min: minimum value for the clip parameter
ratio_clip_max: maximum value for the clip parameter
β is the KL penalty coefficient (reference_policy_kl_penalty)
KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.)
For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref)
Also supports “Dual-Clipping” from https://arxiv.org/pdf/1912.09729, which imposes an additional upper bound on the probability ratio when advantages are negative. This prevents excessive policy updates: when r_t(θ) * A_t ≪ 0, the clipped value c * A_t is used instead. When A_t < 0, the loss is modified to: L(θ) = E_t [ max(min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t), c * A_t) ] - β * KL(π_θ || π_ref)
where:
c is the dual-clip parameter (ratio_clip_c); it must be greater than 1 and is typically set to 3 in practice.
Due to potential numerical instability, we cast the logits to float32 before computing the loss.
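The formulas above can be condensed into a standalone PyTorch sketch. This is not the library's implementation: it starts from precomputed per-token log-probabilities rather than raw logits, uses the Schulman estimator for the KL term, and assumes token-level aggregation; the argument names mirror the config fields.

```python
import torch

def clipped_pg_loss_sketch(curr_logprobs, prev_logprobs, ref_logprobs,
                           advantages, token_mask,
                           clip_min=0.2, clip_max=0.2, clip_c=3.0,
                           kl_penalty=0.01):
    """Standalone sketch of the clipped PG objective described above."""
    # Probability ratio r_t(theta) = pi_theta / pi_theta_old.
    ratio = torch.exp(curr_logprobs - prev_logprobs)
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_min, 1.0 + clip_max)
    surrogate = torch.minimum(ratio * advantages, clipped_ratio * advantages)
    # Dual-clipping: when A_t < 0, additionally bound the term by c * A_t.
    surrogate = torch.where(advantages < 0,
                            torch.maximum(surrogate, clip_c * advantages),
                            surrogate)
    # Schulman approximation of KL(pi_theta || pi_ref).
    log_diff = ref_logprobs - curr_logprobs
    kl = torch.exp(log_diff) - log_diff - 1.0
    # Token-level aggregation; the loss is the negated objective.
    per_token = -(surrogate - kl_penalty * kl)
    return (per_token * token_mask).sum() / token_mask.sum().clamp(min=1)
```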
Initialization
- __call__(
- next_token_logits: torch.Tensor,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict],
- global_valid_seqs: torch.Tensor,
- global_valid_toks: torch.Tensor,
- )#
Clipped Policy Gradient RL loss function.
- class nemo_rl.algorithms.loss_functions.NLLLoss[source]#
Bases:
nemo_rl.algorithms.interfaces.LossFunction
Negative Log Likelihood Loss function.
- loss_type#
None
- __call__(
- next_token_logits: torch.Tensor,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- global_valid_seqs: torch.Tensor | None,
- global_valid_toks: torch.Tensor,
- dpo_loss: bool = False,
- dpo_average_log_probs: bool = False,
- )#
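As a rough illustration of what a masked negative log-likelihood over next-token logits computes (a sketch, not this class's implementation; the shift-by-one alignment of logits and targets is an assumption):

```python
import torch
import torch.nn.functional as F

def nll_loss_sketch(next_token_logits, input_ids, token_mask):
    """Sketch of a masked NLL loss; next_token_logits[:, t] is assumed
    to score input_ids[:, t + 1]."""
    # Cast to float32 before the log-softmax for numerical stability.
    logprobs = F.log_softmax(next_token_logits.float(), dim=-1)
    # Log-probability of each observed next token.
    targets = input_ids[:, 1:].unsqueeze(-1)
    token_logprobs = logprobs[:, :-1].gather(-1, targets).squeeze(-1)
    mask = token_mask[:, 1:]
    return -(token_logprobs * mask).sum() / mask.sum().clamp(min=1)
```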
- class nemo_rl.algorithms.loss_functions.DPOLossConfig[source]#
Bases:
typing.TypedDict
- reference_policy_kl_penalty: float#
None
- preference_loss_weight: float#
1.0
- sft_loss_weight: float#
0.0
- preference_average_log_probs: bool#
False
- sft_average_log_probs: bool#
False
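An illustrative config dict for these fields is shown below; the weight and averaging values mirror the defaults listed above, while the KL penalty is a placeholder.

```python
# Illustrative DPOLossConfig-style dict. The weights and averaging flags
# mirror the defaults above; the KL penalty value is a placeholder.
dpo_loss_config = {
    "reference_policy_kl_penalty": 0.1,   # beta in the preference loss
    "preference_loss_weight": 1.0,        # w_p
    "sft_loss_weight": 0.0,               # w_s (0 disables the auxiliary SFT term)
    "preference_average_log_probs": False,
    "sft_average_log_probs": False,
}
```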
- class nemo_rl.algorithms.loss_functions.DPOLossDataDict[source]#
Bases:
typing.TypedDict
Required keys for the DPO loss function.
Initialization
- input_ids: torch.Tensor#
None
- reference_policy_logprobs: torch.Tensor#
None
- token_mask: torch.Tensor#
None
- sample_mask: torch.Tensor#
None
- class nemo_rl.algorithms.loss_functions.DPOLossFn(cfg: nemo_rl.algorithms.loss_functions.DPOLossConfig)[source]#
Bases:
nemo_rl.algorithms.interfaces.LossFunction
Direct Preference Optimization (DPO) loss function.
This loss function implements the DPO algorithm as described in: “Direct Preference Optimization: Your Language Model is Secretly a Reward Model” (https://arxiv.org/abs/2305.18290)
The loss combines two main components:
Preference Loss: Optimizes the model to prefer chosen responses over rejected ones
SFT Loss (optional): Auxiliary supervised fine-tuning loss on chosen responses
The total loss is computed as: L(θ) = w_p * L_pref(θ) + w_s * L_sft(θ)
where:
w_p is the preference_loss_weight
w_s is the sft_loss_weight
L_pref(θ) is the preference loss term
L_sft(θ) is the supervised fine-tuning loss term
The preference loss term is computed as: L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))]
where:
σ is the sigmoid function
β is the reference_policy_kl_penalty
r_chosen and r_rejected are the rewards for chosen and rejected responses
The rewards are computed from the per-token log-probability differences between the current policy and the reference policy.
If preference_average_log_probs is True, the rewards are averaged over tokens: r = (1/n) * Σ_t (log π_θ(a_t|s_t) - log π_ref(a_t|s_t))
Otherwise, the rewards are summed over tokens.
The SFT loss term is a standard negative log likelihood loss on the chosen responses. If sft_average_log_probs is True, the loss is averaged over tokens.
- Parameters:
cfg (DPOLossConfig) – Configuration dictionary containing:
- reference_policy_kl_penalty (float): Strength of the KL penalty term (β)
- preference_loss_weight (float): Weight for the preference loss term (w_p)
- sft_loss_weight (float): Weight for the SFT loss term (w_s)
- preference_average_log_probs (bool): Whether to average log probs across tokens in the preference loss
- sft_average_log_probs (bool): Whether to average log probs across tokens in the SFT loss
- Returns:
A tuple containing:
- The total loss value
- A dictionary with metrics including:
  - loss: Total loss value
  - sft_loss: SFT loss component
  - preference_loss: Preference loss component
  - accuracy: Fraction of examples where the chosen response has a higher reward
- Return type:
Tuple[torch.Tensor, dict]
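The preference term and the accuracy metric described above can be sketched in plain PyTorch as follows (not the library's implementation; the per-sequence chosen/rejected log-probability sums are assumed to be precomputed):

```python
import torch
import torch.nn.functional as F

def dpo_preference_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                               ref_chosen_logps, ref_rejected_logps,
                               beta=0.1):
    """Sketch of the DPO preference loss and accuracy metric described above."""
    # Implicit rewards: log-probability difference vs. the reference policy,
    # summed (or averaged) over tokens upstream of this function.
    r_chosen = policy_chosen_logps - ref_chosen_logps
    r_rejected = policy_rejected_logps - ref_rejected_logps
    # L_pref = -E[log sigmoid(beta * (r_chosen - r_rejected))]
    loss = -F.logsigmoid(beta * (r_chosen - r_rejected)).mean()
    # Fraction of pairs where the chosen response receives the higher reward.
    accuracy = (r_chosen > r_rejected).float().mean()
    return loss, accuracy
```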
Initialization
- preference_loss(
- next_token_logits: torch.Tensor,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict],
- global_valid_seqs: torch.Tensor,
- )#
- __call__(
- next_token_logits: torch.Tensor,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict],
- global_valid_seqs: torch.Tensor,
- global_valid_toks: torch.Tensor | None,
- )#