# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
from typing import Any, Tuple, TypedDict
import torch
from nemo_rl.algorithms.interfaces import LossFunction
from nemo_rl.algorithms.utils import (
calculate_kl_penalty_joschu2020,
masked_mean,
)
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.dtensor.parallelize import (
get_logprobs_from_vocab_parallel_logits,
)
class LossType(enum.Enum):
TOKEN_LEVEL = "token_level"
SEQUENCE_LEVEL = "sequence_level"
class ClippedPGLossConfig(TypedDict):
reference_policy_kl_penalty: float
ratio_clip_min: float
ratio_clip_max: float
ratio_clip_c: float
use_on_policy_kl_approximation: bool
use_importance_sampling_correction: bool
token_level_loss: bool
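# Illustrative ClippedPGLossConfig sketch (placeholder values, not recommended defaults):
#
#     cfg: ClippedPGLossConfig = {
#         "reference_policy_kl_penalty": 0.01,
#         "ratio_clip_min": 0.2,
#         "ratio_clip_max": 0.2,
#         "ratio_clip_c": 3.0,  # set to None to disable dual-clipping
#         "use_on_policy_kl_approximation": False,
#         "use_importance_sampling_correction": False,
#         "token_level_loss": True,
#     }
#
# ClippedPGLossFn.__init__ additionally reads an optional "disable_ppo_ratio" key
# (default False) for REINFORCE/RLOO-style losses.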
class ClippedPGLossDataDict(TypedDict):
"""Required keys for the Clipped Policy Gradient loss function."""
input_ids: torch.Tensor
advantages: torch.Tensor
prev_logprobs: torch.Tensor
generation_logprobs: torch.Tensor
reference_policy_logprobs: torch.Tensor
token_mask: torch.Tensor
sample_mask: torch.Tensor
__extra__: Any
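# Shape assumptions (inferred from the slicing in ClippedPGLossFn.__call__ below):
# input_ids, advantages, prev_logprobs, generation_logprobs, reference_policy_logprobs,
# and token_mask are [batch_size, seq_len]; sample_mask is [batch_size].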
class ClippedPGLossFn(LossFunction):
"""Generalized Clipped Policy Gradient loss function w/ KL regularization.
This implements:
- PPO (Clipped) - https://arxiv.org/abs/1707.06347
- GRPO - https://arxiv.org/abs/2402.03300
- REINFORCE/RLOO (set disable_ppo_ratio = True; ratio_clip_min/ratio_clip_max are then ignored) - https://arxiv.org/abs/2402.14740
Formula:
L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref)
where:
- r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio
- A_t is the advantage estimate
- ε is the clip parameter (ratio_clip_min/ratio_clip_max)
- As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476),
we allow distinct minimum and maximum values for the clip parameter (set them to the same value for standard PPO/GRPO).
- ratio_clip_min: minimum value for the clip parameter
- ratio_clip_max: maximum value for the clip parameter
- β is the KL penalty coefficient (reference_policy_kl_penalty)
- KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.)
For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to:
L(θ) = E_t [ log π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref)
Also supports "Dual-Clipping" from https://arxiv.org/pdf/1912.09729, which
imposes an additional upper bound on the probability ratio when advantages are negative.
This prevents excessively large policy updates: when $r_t(θ) A_t \ll 0$, the contribution is clipped to $c A_t$.
The loss function is modified to the following when A_t < 0:
L(θ) = E_t [ max(min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t), c * A_t) ] - β * KL(π_θ || π_ref)
where:
- c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is
usually set to 3 in practice.
Due to potential numerical instability, we cast the logits to float32 before computing the loss.
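Worked sketch of the clipping logic (illustrative numbers only, not from a real run):
for a single token with advantage A = -1.0, ratio r = 5.0,
ratio_clip_min = ratio_clip_max = 0.2, and ratio_clip_c = 3.0, the per-token terms
computed in __call__ are
    loss1 = -A * r                 = 5.0
    loss2 = -A * clip(r, 0.8, 1.2) = 1.2
    clip_loss = max(loss1, loss2)  = 5.0
    loss3 = -A * c                 = 3.0
and, because A < 0, dual-clipping takes min(clip_loss, loss3) = 3.0, capping how much a
single badly off-policy token can contribute to the loss.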
"""
def __init__(self, cfg: ClippedPGLossConfig):
self.ratio_clip_min = cfg["ratio_clip_min"]
self.ratio_clip_max = cfg["ratio_clip_max"]
self.ratio_clip_c = cfg["ratio_clip_c"] # set to None to disable dual-clipping
self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"]
self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False)
self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"]
self.use_importance_sampling_correction = cfg[
"use_importance_sampling_correction"
]
self.loss_type = (
LossType.TOKEN_LEVEL if cfg["token_level_loss"] else LossType.SEQUENCE_LEVEL
)
def __call__(
self,
next_token_logits: torch.Tensor,
data: BatchedDataDict[ClippedPGLossDataDict],
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
) -> Tuple[torch.Tensor, dict]:
"""Clipped Policy Gradient RL loss function."""
token_mask = data["token_mask"][:, 1:]
sample_mask = data["sample_mask"]
advantages = data["advantages"][:, 1:]
prev_logprobs = data["prev_logprobs"][:, 1:]
generation_logprobs = data["generation_logprobs"][:, 1:]
reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:]
mask = token_mask * sample_mask.unsqueeze(-1)
# token_mult_prob_error
# See more details and other metrics in docs/guides/grpo.md#metrics
lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now)
# average over all tokens in the microbatch
mult_prob_error = masked_mean(
torch.exp(lp_error * mask),
mask,
global_normalization_factor=global_valid_toks,
).item()
next_token_logits = next_token_logits.to(torch.float32)
if isinstance(next_token_logits, torch.distributed.tensor.DTensor):
curr_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"]
)
else:
next_token_logits_wo_last = next_token_logits[
:, :-1
] # Remove last position's logits
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits_wo_last, dim=-1
)
next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token
curr_logprobs = next_token_logprobs.gather(
dim=-1, index=next_tokens.unsqueeze(-1)
).squeeze(-1)
# Calculate KL regularization.
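# calculate_kl_penalty_joschu2020 follows Schulman's KL-approximation note
# (http://joschu.net/blog/kl-approx.html); it is assumed here to implement the k3
# estimator, exp(ref - cur) - (ref - cur) - 1 per token, which is non-negative.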
if self.reference_policy_kl_penalty != 0:
if self.use_on_policy_kl_approximation:
# See: docs/guides/grpo.md#on-policy-kl-approximation
kl_importance_weights = torch.exp(
curr_logprobs - generation_logprobs
).detach()
kl_importance_weights = torch.nan_to_num(
kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0
)
else:
kl_importance_weights = torch.ones_like(curr_logprobs)
kl = (
kl_importance_weights
* self.reference_policy_kl_penalty
* calculate_kl_penalty_joschu2020(
logprobs_policy=curr_logprobs,
logprobs_reference=reference_policy_logprobs,
)
)
if self.loss_type == LossType.TOKEN_LEVEL:
kl = masked_mean(
kl, mask, global_normalization_factor=global_valid_toks
)
else:
kl = masked_mean(
masked_mean(kl, token_mask, dim=-1),
sample_mask,
global_normalization_factor=global_valid_seqs,
)
else:
kl = 0
# Calculate clipped loss function if ppo ratio is enabled.
if not self.disable_ppo_ratio:
ratios = (curr_logprobs - prev_logprobs).exp()
ratios_clamped = ratios.clamp(
1.0 - self.ratio_clip_min, 1.0 + self.ratio_clip_max
)
else:
ratios = curr_logprobs
ratios_clamped = curr_logprobs
loss1 = -advantages * ratios
loss2 = -advantages * ratios_clamped
# Determine which value to use for clipping (max for pessimistic estimate)
clip_loss = torch.max(loss1, loss2)
# Dual-clipping see https://arxiv.org/pdf/1912.09729
if self.ratio_clip_c is not None:
assert self.ratio_clip_c > 1, (
f"ratio_clip_c must be greater than 1 (it upper-bounds the ratio when advantages are negative), got {self.ratio_clip_c}."
)
loss3 = -advantages * self.ratio_clip_c
clip_loss = torch.where(
advantages < 0, torch.min(clip_loss, loss3), clip_loss
)
# See: docs/guides/grpo.md#importance-sampling-correction
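# generation_logprobs come from the inference engine while prev_logprobs are the trainer's
# own recomputation of the same policy (cf. the token_mult_prob_error metric above), so
# exp(prev - gen) acts as a per-token importance weight correcting for the mismatch
# between the two backends.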
actor_importance_weights = torch.exp(prev_logprobs - generation_logprobs)
actor_importance_weights = torch.nan_to_num(
actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0
)
if self.use_importance_sampling_correction:
importance_weights_to_use = actor_importance_weights
else:
importance_weights_to_use = torch.ones_like(prev_logprobs)
if self.loss_type == LossType.TOKEN_LEVEL:
actor_loss = masked_mean(
importance_weights_to_use * clip_loss,
mask,
global_normalization_factor=global_valid_toks,
)
else:
actor_loss = masked_mean(
masked_mean(
importance_weights_to_use * clip_loss,
token_mask,
dim=-1,
),
sample_mask,
global_normalization_factor=global_valid_seqs,
)
# See: docs/guides/grpo.md#sampling-importance-ratio
sample_importance_ratio = masked_mean(
actor_importance_weights,
mask,
global_normalization_factor=global_valid_toks,
)
# Approximating entropy as E_{s ~ \pi_{gen}(s)}[-(\pi_{curr}/\pi_{gen})log(\pi_{curr}(s))]
# See more details and other metrics in docs/guides/grpo.md#metrics
with torch.no_grad():
seq_entropy_approx = -masked_mean(
torch.exp(curr_logprobs - generation_logprobs) * curr_logprobs,
mask,
global_normalization_factor=global_valid_toks,
)
loss = actor_loss + kl
with torch.no_grad():
probs_ratio = masked_mean(
ratios.detach(),
mask,
global_normalization_factor=global_valid_toks,
).item()
probs_ratio_clamped = masked_mean(
ratios_clamped.detach(),
mask,
global_normalization_factor=global_valid_toks,
).item()
# If you provided a global_valid_{seqs/toks}, all metrics here are globally normalized
# by either sequence or token count, depending on the particular metric.
# To get the true metric, you'll need to sum over the microbatch.
return (
loss,
{
"loss": loss.item(),
"probs_ratio": probs_ratio,
"probs_ratio_clamped": probs_ratio_clamped,
"kl_penalty": kl.item() / self.reference_policy_kl_penalty if kl else 0,
"token_mult_prob_error": mult_prob_error,
"sampling_importance_ratio": sample_importance_ratio.item(),
"num_valid_samples": sample_mask.sum().item(),
"approx_entropy": seq_entropy_approx.item(),
},
)
class NLLLoss(LossFunction):
"""Negative Log Likelihood Loss function."""
loss_type = LossType.TOKEN_LEVEL
def __call__(
self,
next_token_logits: torch.Tensor,
data: BatchedDataDict,
global_valid_seqs: torch.Tensor | None,
global_valid_toks: torch.Tensor,
dpo_loss: bool = False,
dpo_average_log_probs: bool = False,
) -> Tuple[torch.Tensor, dict]:
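# When dpo_loss=True, this returns a per-sample NLL vector (optionally averaged over each
# row's unmasked tokens); otherwise it returns a single scalar normalized by
# global_valid_toks.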
# logits shape: [batch_size, seq_len, vocab_size]
# Get the next token logits for each position
token_mask = data["token_mask"][:, 1:]
sample_mask = data["sample_mask"]
mask = token_mask * sample_mask.unsqueeze(-1)
next_token_logits = next_token_logits.to(torch.float32)
# Gather the logprobs for the actual next tokens
if isinstance(next_token_logits, torch.distributed.tensor.DTensor):
token_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"]
)
else:
next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
)
logprobs = next_token_logprobs[:, :-1] # Remove last position's logits
token_logprobs = logprobs.gather(
dim=-1, index=next_tokens.unsqueeze(-1)
).squeeze(-1)
if dpo_loss:
## shape: [batch_size]
num_unmasked_tokens = torch.sum(mask, -1)
## multiply by sample_mask to zero out invalid samples
loss = -torch.sum(token_logprobs * mask, dim=-1)
if dpo_average_log_probs:
loss = loss / num_unmasked_tokens.clamp(min=1)
else:
## single scalar loss
## scale by the total number of tokens in the batch
loss = -masked_mean(
token_logprobs,
mask,
global_normalization_factor=global_valid_toks,
)
return loss, {
"loss": loss.item() if loss.ndim == 0 else loss,
"num_unmasked_tokens": mask.sum().item(),
"num_valid_samples": sample_mask.sum().item(),
}
class DPOLossConfig(TypedDict):
reference_policy_kl_penalty: float
preference_loss_weight: float = 1.0
sft_loss_weight: float = 0.0
preference_average_log_probs: bool = False
sft_average_log_probs: bool = False
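# Illustrative DPOLossConfig sketch (placeholder values for the example, matching the
# annotated defaults above where they exist):
#
#     cfg: DPOLossConfig = {
#         "reference_policy_kl_penalty": 0.05,
#         "preference_loss_weight": 1.0,
#         "sft_loss_weight": 0.0,
#         "preference_average_log_probs": False,
#         "sft_average_log_probs": False,
#     }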
class DPOLossDataDict(TypedDict):
"""Required keys for the DPO loss function."""
input_ids: torch.Tensor
reference_policy_logprobs: torch.Tensor
token_mask: torch.Tensor
sample_mask: torch.Tensor
class DPOLossFn(LossFunction):
"""Direct Preference Optimization (DPO) loss function.
This loss function implements the DPO algorithm as described in:
"Direct Preference Optimization: Your Language Model is Secretly a Reward Model"
(https://arxiv.org/abs/2305.18290)
The loss combines two main components:
1. Preference Loss: Optimizes the model to prefer chosen responses over rejected ones
2. SFT Loss (optional): Auxiliary supervised fine-tuning loss on chosen responses
The total loss is computed as:
L(θ) = w_p * L_pref(θ) + w_s * L_sft(θ)
where:
- w_p is the preference_loss_weight
- w_s is the sft_loss_weight
- L_pref(θ) is the preference loss term
- L_sft(θ) is the supervised fine-tuning loss term
The preference loss term is computed as:
L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))]
where:
- σ is the sigmoid function
- β is the reference_policy_kl_penalty
- r_chosen and r_rejected are the rewards for chosen and rejected responses
- The rewards are computed as the sum of log probability differences between
the current policy and reference policy
If preference_average_log_probs is True, the rewards are averaged over tokens:
r = (1/n) * Σ_t (log π_θ(a_t|s_t) - log π_ref(a_t|s_t))
Otherwise, the rewards are summed over tokens.
The SFT loss term is a standard negative log likelihood loss on the chosen responses.
If sft_average_log_probs is True, the loss is averaged over tokens.
Args:
cfg (DPOLossConfig): Configuration dictionary containing:
- reference_policy_kl_penalty (float): Strength of the KL penalty term (β)
- preference_loss_weight (float): Weight for the preference loss term (w_p)
- sft_loss_weight (float): Weight for the SFT loss term (w_s)
- preference_average_log_probs (bool): Whether to average log probs across tokens in preference loss
- sft_average_log_probs (bool): Whether to average log probs across tokens in SFT loss
Returns:
Tuple[torch.Tensor, dict]: A tuple containing:
- The total loss value
- A dictionary with metrics including:
- loss: Total loss value
- sft_loss: SFT loss component
- preference_loss: Preference loss component
- accuracy: Fraction of examples where chosen response has higher reward
"""
def __init__(self, cfg: DPOLossConfig):
self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"]
self.preference_loss_weight = cfg["preference_loss_weight"]
self.sft_loss_weight = cfg["sft_loss_weight"]
self.preference_average_log_probs = cfg["preference_average_log_probs"]
self.sft_average_log_probs = cfg["sft_average_log_probs"]
self.sft_loss = NLLLoss()
self.loss_type = LossType.SEQUENCE_LEVEL
def split_output_tensor(self, tensor: torch.Tensor):
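# The batch is assumed to interleave preference pairs: even rows hold chosen responses
# and odd rows hold the corresponding rejected ones (matching the [::2] / [1::2]
# indexing used in preference_loss and __call__ below).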
return tensor[::2], tensor[1::2]
def preference_loss(
self,
next_token_logits: torch.Tensor,
data: BatchedDataDict[DPOLossDataDict],
global_valid_seqs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor
token_mask = data["token_mask"][:, 1:]
sample_mask = data["sample_mask"]
next_token_logits = next_token_logits.to(torch.float32)
if isinstance(next_token_logits, torch.distributed.tensor.DTensor):
token_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"]
)
else:
next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
)
logprobs = next_token_logprobs[:, :-1] # Remove last position's logits
token_logprobs = logprobs.gather(
dim=-1, index=next_tokens.unsqueeze(-1)
).squeeze(-1)
ref_logprobs = data["reference_policy_logprobs"][:, :-1]
diff = (token_logprobs - ref_logprobs) * token_mask
rewards = diff.sum(-1)
if self.preference_average_log_probs:
rewards = rewards / token_mask.sum(-1).clamp(min=1)
rewards_chosen, rewards_rejected = self.split_output_tensor(rewards)
rewards_delta = rewards_chosen - rewards_rejected
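# Worked sketch (illustrative numbers only): with reference_policy_kl_penalty β = 0.1,
# rewards_chosen = 2.0, and rewards_rejected = 1.0, rewards_delta = 1.0 and the per-pair
# loss below is -logsigmoid(0.1 * 1.0) ≈ 0.644; the pair also counts toward accuracy
# because rewards_chosen > rewards_rejected.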
per_sample_loss = (
-torch.nn.functional.logsigmoid(
self.reference_policy_kl_penalty * rewards_delta
)
* sample_mask[::2]
) ## zero out invalid samples
## divide by 2 because each preference example corresponds to 2 samples (chosen, rejected)
return (
masked_mean(
per_sample_loss,
sample_mask[::2],
global_normalization_factor=global_valid_seqs / 2,
),
masked_mean(
rewards_chosen > rewards_rejected,
sample_mask[::2],
global_normalization_factor=global_valid_seqs / 2,
),
masked_mean(
rewards_chosen,
sample_mask[::2],
global_normalization_factor=global_valid_seqs / 2,
),
masked_mean(
rewards_rejected,
sample_mask[1::2],
global_normalization_factor=global_valid_seqs / 2,
),
)
def __call__(
self,
next_token_logits: torch.Tensor,
data: BatchedDataDict[DPOLossDataDict],
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor | None,
) -> Tuple[torch.Tensor, dict]:
sft_loss_chosen = torch.tensor(0.0)
if self.sft_loss_weight > 0:
assert global_valid_toks is not None, (
"global_valid_toks must be provided for SFT loss"
)
sft_loss, _ = self.sft_loss(
next_token_logits,
data,
global_valid_seqs=global_valid_seqs,
global_valid_toks=global_valid_toks, ## unused because sft loss returned is at the sample level
dpo_loss=True,
dpo_average_log_probs=self.sft_average_log_probs,
)
sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss)
sft_loss_chosen = masked_mean(
sft_loss_chosen,
data["sample_mask"][::2],
global_normalization_factor=global_valid_seqs / 2,
)
(
preference_loss,
accuracy,
rewards_chosen_mean,
rewards_rejected_mean,
) = self.preference_loss(next_token_logits, data, global_valid_seqs)
dpo_loss = (
self.sft_loss_weight * sft_loss_chosen
+ self.preference_loss_weight * preference_loss
)
## divide by 2 because we're summing over (chosen, rejected) pairs
num_valid_samples = data["sample_mask"].sum() / 2
return dpo_loss, {
"loss": dpo_loss.item(),
"sft_loss": sft_loss_chosen.item(),
"preference_loss": preference_loss.item(),
"accuracy": accuracy.item(),
"rewards_chosen_mean": rewards_chosen_mean.item(),
"rewards_rejected_mean": rewards_rejected_mean.item(),
"num_valid_samples": num_valid_samples.item(),
}