Source code for nemo_rl.algorithms.interfaces

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Protocol, Tuple

import torch

from nemo_rl.distributed.batched_data_dict import BatchedDataDict


class LossFunction(Protocol):
    """Signature for loss functions used in reinforcement learning algorithms.

    Loss functions compute a scalar loss value and associated metrics from
    model logprobs and other data contained in a BatchedDataDict.
    """
    def __call__(
        self,
        next_token_logits: torch.Tensor,
        data: BatchedDataDict,
        global_valid_seqs: torch.Tensor,
        global_valid_toks: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Compute loss and metrics from logprobs and other data.

        Args:
            next_token_logits: Logits from the model, typically with shape
                [batch_size, seq_len, vocab_size]. For each position (b, i),
                contains the logit distribution over the entire vocabulary for
                predicting the next token (at position i+1). For example, if
                processing "The cat sat on", next_token_logits[b, 3] would
                contain the logits for predicting the word that follows "on".
            data: Dictionary containing all relevant data for loss
                computation, such as rewards, values, actions, advantages,
                masks, and other algorithm-specific information needed for the
                particular loss calculation.
            global_valid_seqs: Tensor containing the number of valid sequences
                in the microbatch. Used for global normalization of
                losses/metrics that are computed at the sequence level and
                need to be aggregated across all microbatches.
            global_valid_toks: Tensor containing the number of valid tokens in
                the microbatch. Used for global normalization of
                losses/metrics that are computed at the token level and need
                to be aggregated across all microbatches.

        Returns:
            tuple: (loss, metrics)
                - loss: A scalar tensor representing the loss value to be
                  minimized during training.
                - metrics: A dictionary of metrics related to the loss
                  computation, which may include component losses, statistics
                  about gradients/rewards, and other diagnostic information.
        """
        ...
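
# --- Illustrative example (not part of the module) --------------------------
# A minimal sketch of a loss function satisfying the LossFunction protocol:
# a token-level negative log-likelihood loss, globally normalized by
# global_valid_toks. The "token_ids" and "token_mask" keys are hypothetical
# assumptions about the BatchedDataDict contents, not part of this interface.


def example_nll_loss(
    next_token_logits: torch.Tensor,
    data: BatchedDataDict,
    global_valid_seqs: torch.Tensor,
    global_valid_toks: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
    # Logits at position i predict the token at position i + 1, so shift the
    # logits and targets against each other by one position.
    logprobs = torch.log_softmax(next_token_logits[:, :-1], dim=-1)
    targets = data["token_ids"][:, 1:]  # hypothetical key
    mask = data["token_mask"][:, 1:]  # hypothetical key

    # Gather the logprob assigned to each actual next token.
    token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # Token-level loss: normalize by the global valid-token count so that
    # per-microbatch losses aggregate to the correct global mean.
    loss = -(token_logprobs * mask).sum() / global_valid_toks

    metrics = {"nll_loss": loss.detach(), "num_valid_toks": global_valid_toks}
    return loss, metrics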