nemo_rl.algorithms.interfaces

Module Contents

Classes

LossFunction

Signature for loss functions used in reinforcement learning algorithms.

API

class nemo_rl.algorithms.interfaces.LossFunction

Bases: typing.Protocol

Signature for loss functions used in reinforcement learning algorithms.

Loss functions compute a scalar loss value and associated metrics from the model's next-token logits and other data contained in a BatchedDataDict.
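Because LossFunction is a typing.Protocol, a loss does not need to inherit from it: a static type checker accepts any object whose __call__ matches the signature. The sketch below illustrates that structural matching with a hypothetical stand-in protocol, LossFunctionLike, whose tensor and BatchedDataDict types are replaced by Any so it runs without torch or nemo_rl installed; ConstantLoss and NotALoss are likewise illustrative names. Note that a runtime_checkable isinstance check only confirms that __call__ exists, not that its parameters or return type match.

```python
from typing import Any, Dict, Protocol, Tuple, runtime_checkable


# Hypothetical stand-in mirroring the documented __call__ signature.
# Tensor and BatchedDataDict types are replaced by Any so the sketch
# runs without torch or nemo_rl.
@runtime_checkable
class LossFunctionLike(Protocol):
    def __call__(
        self,
        next_token_logits: Any,
        data: Any,
        global_valid_seqs: Any,
        global_valid_toks: Any,
    ) -> Tuple[Any, Dict[str, Any]]: ...


class ConstantLoss:
    """Does not inherit from LossFunctionLike; it matches structurally."""

    def __call__(self, next_token_logits, data, global_valid_seqs, global_valid_toks):
        loss = 0.0
        return loss, {"loss": loss}


class NotALoss:
    """Has no __call__, so it does not satisfy the protocol."""


# isinstance on a runtime_checkable Protocol only checks that __call__ exists.
print(isinstance(ConstantLoss(), LossFunctionLike))  # True
print(isinstance(NotALoss(), LossFunctionLike))      # False
```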

__call__(
next_token_logits: torch.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, Any]]

Compute the loss and metrics from the next-token logits and other data.

Parameters:
  • next_token_logits – Logits from the model, typically with shape [batch_size, seq_len, vocab_size]. For each position (b, i), next_token_logits[b, i] holds the logits over the entire vocabulary for predicting the next token (at position i+1). For example, when processing “The cat sat on” (assuming one token per word), next_token_logits[b, 3] contains the logits for predicting the word that follows “on”.

  • data – BatchedDataDict containing all data needed for the loss computation, such as rewards, values, actions, advantages, masks, and any other algorithm-specific fields the particular loss requires.

  • global_valid_seqs – torch.Tensor containing the number of valid sequences in the global batch, i.e. counted across all microbatches rather than only the current one. It is used to normalize losses/metrics that are computed at the sequence level and need to be aggregated across all microbatches.

  • global_valid_toks – torch.Tensor containing the number of valid tokens in the global batch, i.e. counted across all microbatches rather than only the current one. It is used to normalize losses/metrics that are computed at the token level and need to be aggregated across all microbatches.
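The arithmetic behind global normalization can be sketched in plain Python. The sketch assumes, as the parameter names suggest, that global_valid_toks is the total number of valid tokens across all microbatches; the per-token losses and masks are made-up numbers, and masked_sum is a hypothetical helper. Dividing each microbatch's masked sum by the global count makes the per-microbatch contributions add up to the true per-token mean, whereas averaging per-microbatch means over-weights the microbatch with fewer valid tokens.

```python
# Per-token losses and validity masks for two microbatches (made-up numbers).
microbatches = [
    {"token_loss": [1.0, 2.0, 3.0, 0.0], "token_mask": [1, 1, 1, 0]},  # 3 valid tokens
    {"token_loss": [4.0, 0.0, 0.0, 0.0], "token_mask": [1, 0, 0, 0]},  # 1 valid token
]


def masked_sum(values, mask):
    """Sum of the values whose mask entry is 1."""
    return sum(v * m for v, m in zip(values, mask))


# Total valid tokens across all microbatches (the role global_valid_toks plays).
global_valid_toks = sum(sum(mb["token_mask"]) for mb in microbatches)

# Global normalization: each microbatch divides its sum by the global count,
# so the contributions add up to the mean over all valid tokens.
global_norm = sum(
    masked_sum(mb["token_loss"], mb["token_mask"]) / global_valid_toks
    for mb in microbatches
)

# Local normalization for comparison: mean within each microbatch, then average.
local_norm = sum(
    masked_sum(mb["token_loss"], mb["token_mask"]) / sum(mb["token_mask"])
    for mb in microbatches
) / len(microbatches)

print(global_norm)  # 2.5  (= (1 + 2 + 3 + 4) / 4)
print(local_norm)   # 3.0  (= (2.0 + 4.0) / 2)
```

The same reasoning applies to global_valid_seqs for losses and metrics computed per sequence.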

Returns:

(loss, metrics), where:

  • loss – A scalar tensor representing the loss value to be minimized during training.

  • metrics – A dictionary of metrics related to the loss computation, which may include component losses, statistics about gradients/rewards, and other diagnostic information.

Return type:

tuple
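Putting the parameters and return value together, the following is a minimal sketch of a token-level negative log-likelihood loss whose __call__ matches the documented signature. It is not NeMo RL's own implementation: a plain dict stands in for BatchedDataDict, the input_ids and token_mask field names are assumptions, and MaskedNLLLoss is a hypothetical name. It shifts the targets by one position to match the next-token layout of next_token_logits and divides by global_valid_toks, assuming that value holds the token count for the whole global batch. global_valid_seqs is accepted but unused because this loss is token-level.

```python
from typing import Any, Dict, Tuple

import torch


class MaskedNLLLoss:
    """Hypothetical token-level NLL matching the LossFunction __call__ signature."""

    def __call__(
        self,
        next_token_logits: torch.Tensor,  # [batch, seq_len, vocab]
        data: Dict[str, torch.Tensor],    # plain dict standing in for BatchedDataDict
        global_valid_seqs: torch.Tensor,  # unused: this loss is token-level
        global_valid_toks: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        input_ids = data["input_ids"]            # assumed field: [batch, seq_len]
        token_mask = data["token_mask"].float()  # assumed field: 1 = token to score

        # Logits at position i predict token i + 1, so drop the last logit
        # position and the first target token to align them.
        logprobs = torch.log_softmax(next_token_logits[:, :-1].float(), dim=-1)
        targets = input_ids[:, 1:]
        token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

        mask = token_mask[:, 1:]
        # Divide by the global token count (not this microbatch's count) so that
        # summing the losses of all microbatches gives the global per-token mean.
        loss = -(token_logprobs * mask).sum() / global_valid_toks
        metrics = {"loss": loss.item(), "num_valid_tokens": int(mask.sum().item())}
        return loss, metrics


torch.manual_seed(0)
logits = torch.randn(2, 5, 11)  # [batch=2, seq_len=5, vocab=11]
data = {"input_ids": torch.randint(0, 11, (2, 5)), "token_mask": torch.ones(2, 5)}
# With a single microbatch, the global counts equal this batch's counts.
loss, metrics = MaskedNLLLoss()(logits, data, torch.tensor(2), torch.tensor(8))
```

With every token valid and a single microbatch, this loss reduces to the mean cross-entropy over the shifted targets.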