nemo_rl.algorithms.interfaces
Module Contents
Classes
LossFunction | Signature for loss functions used in reinforcement learning algorithms.
API
- class nemo_rl.algorithms.interfaces.LossFunction
Bases: typing.Protocol
Signature for loss functions used in reinforcement learning algorithms.
Loss functions compute a scalar loss value and associated metrics from the model's next-token logits and other data contained in a BatchedDataDict.
- __call__(
    next_token_logits: torch.Tensor,
    data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
    global_valid_seqs: torch.Tensor,
    global_valid_toks: torch.Tensor,
  )
Compute loss and metrics from next-token logits and other data.
- Parameters:
next_token_logits – Logits from the model, typically with shape [batch_size, seq_len, vocab_size]. For each position (b, i), contains the logit distribution over the entire vocabulary for predicting the next token (at position i+1). For example, if processing "The cat sat on", then next_token_logits[b, 3] would contain the logits for predicting the word that follows "on".
data – Dictionary containing all data relevant to the loss computation, such as rewards, values, actions, advantages, masks, and other algorithm-specific information needed for the particular loss calculation.
global_valid_seqs – A torch.Tensor that should contain the number of valid sequences in the microbatch. It is used for global normalization of losses and metrics that are computed at the sequence level and need to be aggregated across all microbatches.
global_valid_toks – A torch.Tensor that should contain the number of valid tokens in the microbatch. It is used for global normalization of losses and metrics that are computed at the token level and need to be aggregated across all microbatches.
- Returns:
(loss, metrics)
- loss: A scalar tensor representing the loss value to be minimized during training.
- metrics: A dictionary of metrics related to the loss computation, which may include component losses, statistics about gradients or rewards, and other diagnostic information.
- Return type:
tuple
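As an illustration of the call signature documented above, here is a minimal sketch of a token-level negative log-likelihood loss that matches the protocol structurally. Because LossFunction is a typing.Protocol, a conforming class does not have to subclass it. Several things here are assumptions made for the example and are not taken from NeMo RL: the class name MaskedNLLLoss, the keys "input_ids" and "token_mask", and the use of a plain dict as a stand-in for BatchedDataDict.

```python
from typing import Any

import torch
import torch.nn.functional as F


class MaskedNLLLoss:
    """Hypothetical loss with the same call signature as LossFunction."""

    def __call__(
        self,
        next_token_logits: torch.Tensor,  # [batch_size, seq_len, vocab_size]
        data: dict[str, torch.Tensor],    # assumption: dict stand-in for BatchedDataDict
        global_valid_seqs: torch.Tensor,  # unused by this token-level loss
        global_valid_toks: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        input_ids = data["input_ids"]    # assumed key: token ids, [batch_size, seq_len]
        token_mask = data["token_mask"]  # assumed key: 1.0 where a token counts

        # Logits at position i predict the token at position i+1, so drop the
        # last logit position and the first target position to align them.
        logprobs = F.log_softmax(next_token_logits[:, :-1].float(), dim=-1)
        targets = input_ids[:, 1:]
        token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        mask = token_mask[:, 1:].to(token_logprobs.dtype)

        # Normalize by the token count supplied by the caller rather than a
        # count computed locally inside the loss.
        loss = -(token_logprobs * mask).sum() / global_valid_toks.clamp(min=1)
        metrics = {
            "loss": loss.item(),
            "num_valid_toks": int(mask.sum().item()),
        }
        return loss, metrics


# Usage with random inputs (shapes chosen only for the demonstration).
torch.manual_seed(0)
batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab, requires_grad=True)
data = {
    "input_ids": torch.randint(0, vocab, (batch, seq_len)),
    "token_mask": torch.ones(batch, seq_len),
}
valid_toks = data["token_mask"][:, 1:].sum()
loss, metrics = MaskedNLLLoss()(logits, data, torch.tensor(batch), valid_toks)
loss.backward()
```

With an all-ones mask, the result should agree with torch.nn.functional.cross_entropy applied to the same shifted logits and targets, which makes a convenient sanity check when writing a new loss against this interface.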