nemo_rl.algorithms.loss.interfaces
Module Contents
Classes
| LossFunction | Signature for loss functions used in reinforcement learning algorithms. |
API
- class nemo_rl.algorithms.loss.interfaces.LossType(*args, **kwds)
Bases: enum.Enum
- TOKEN_LEVEL = 'token_level'
- SEQUENCE_LEVEL = 'sequence_level'
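Token-level and sequence-level losses differ mainly in the quantity they are divided by. The plain-Python sketch below makes two assumptions. A token-level loss sums the masked per-token losses and divides by global_valid_toks. A sequence-level loss first averages within each sequence and then divides by global_valid_seqs. reduce_microbatch is a hypothetical helper, and real loss functions operate on torch tensors rather than lists.

```python
from enum import Enum


class LossType(Enum):
    # Local mirror of nemo_rl.algorithms.loss.interfaces.LossType
    TOKEN_LEVEL = "token_level"
    SEQUENCE_LEVEL = "sequence_level"


def reduce_microbatch(per_token_losses, masks, loss_type, global_valid_seqs, global_valid_toks):
    """Hypothetical helper: reduce one microbatch's per-token losses to a scalar.

    per_token_losses and masks are lists of per-sequence lists; a mask value
    of 1 marks a valid token and 0 marks padding.
    """
    if loss_type is LossType.TOKEN_LEVEL:
        # Every valid token carries equal weight, whatever its sequence length.
        total = sum(
            loss * m
            for seq, mask in zip(per_token_losses, masks)
            for loss, m in zip(seq, mask)
        )
        return total / global_valid_toks
    # Sequence level (assumed): average within each sequence first, so every
    # valid sequence carries equal weight, whatever its length.
    total = 0.0
    for seq, mask in zip(per_token_losses, masks):
        n_valid = sum(mask)
        if n_valid > 0:
            total += sum(loss * m for loss, m in zip(seq, mask)) / n_valid
    return total / global_valid_seqs


# Two sequences: 4 valid tokens with loss 1.0, and 1 valid token with loss 4.0.
losses = [[1.0, 1.0, 1.0, 1.0], [4.0, 0.0, 0.0, 0.0]]
masks = [[1, 1, 1, 1], [1, 0, 0, 0]]
print(reduce_microbatch(losses, masks, LossType.TOKEN_LEVEL, 2, 5))     # 1.6 (8 / 5)
print(reduce_microbatch(losses, masks, LossType.SEQUENCE_LEVEL, 2, 5))  # 2.5 ((1 + 4) / 2)
```

Suppose the two sequences are split into separate microbatches and the two results are summed. The totals are still 1.6 and 2.5 (0.8 + 0.8 and 0.5 + 2.0). Dividing by per-microbatch counts instead would give 4/4 + 4/1 = 5.0 at token level, which is why the normalizers are global counts.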
- class nemo_rl.algorithms.loss.interfaces.LossInputType(*args, **kwds)
Bases: enum.Enum
- LOGIT = 'logit'
- LOGPROB = 'logprob'
- DISTILLATION = 'distillation'
- DRAFT = 'draft'
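Each LossInputType member corresponds to a different set of keyword arguments for LossFunction.__call__. The sketch below mirrors the enum locally and encodes the documented keyword names in a lookup table. EXPECTED_KWARGS and check_loss_inputs are hypothetical names for illustration, not part of nemo_rl.

```python
from enum import Enum


class LossInputType(Enum):
    # Local mirror of nemo_rl.algorithms.loss.interfaces.LossInputType
    LOGIT = "logit"
    LOGPROB = "logprob"
    DISTILLATION = "distillation"
    DRAFT = "draft"


# Keyword arguments documented for LossFunction.__call__, per input type.
EXPECTED_KWARGS = {
    LossInputType.LOGPROB: ("next_token_logprobs",),
    LossInputType.LOGIT: ("logits",),
    LossInputType.DISTILLATION: ("student_topk_logprobs", "teacher_topk_logprobs", "H_all"),
    LossInputType.DRAFT: ("teacher_logits", "student_logits", "mask"),
}


def check_loss_inputs(input_type, **kwargs):
    """Hypothetical helper: raise if a documented keyword argument is missing."""
    input_type = LossInputType(input_type)  # accepts a member or its string value
    missing = [name for name in EXPECTED_KWARGS[input_type] if name not in kwargs]
    if missing:
        raise TypeError(f"{input_type.name} loss is missing inputs: {missing}")
    return input_type


print(check_loss_inputs("logprob", next_token_logprobs=[-0.1, -2.3]))  # LossInputType.LOGPROB
try:
    check_loss_inputs("draft", teacher_logits=None, student_logits=None)
except TypeError as err:
    print(err)  # DRAFT loss is missing inputs: ['mask']
```

The members are plain Enum values, so LossInputType("logprob") and LossInputType.LOGPROB resolve to the same member. This can be convenient when an input type is read from a string.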
- class nemo_rl.algorithms.loss.interfaces.LossFunction
Bases: typing.Protocol
Signature for loss functions used in reinforcement learning algorithms.
Loss functions compute a scalar loss value and associated metrics from model outputs and other data contained in a BatchedDataDict. Depending on input_type, the model outputs are log-probabilities, logits, or distillation or draft tensors.
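Because LossFunction is a typing.Protocol, an implementation does not need to subclass it. For a static type checker, any object with matching loss_type, input_type, and __call__ members satisfies it. The sketch below demonstrates this structural typing with a simplified, hypothetical stand-in protocol. It uses plain strings in place of the enums and a reduced __call__ signature. It adds @runtime_checkable only so the check can be shown at runtime; this is an assumption for illustration, and the real LossFunction may not be runtime-checkable.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SimpleLossFunction(Protocol):
    # Hypothetical stand-in for LossFunction: strings replace the
    # LossType/LossInputType enums, and __call__ is reduced to data + **kwargs.
    loss_type: str
    input_type: str

    def __call__(self, data: dict, **kwargs: Any) -> tuple: ...


class MeanLoss:
    # Note: no inheritance from SimpleLossFunction.
    loss_type = "token_level"
    input_type = "logprob"

    def __call__(self, data: dict, **kwargs: Any) -> tuple:
        values = data["values"]
        loss = sum(values) / len(values)
        return loss, {"loss": loss}


class NotALoss:
    input_type = "logprob"  # missing loss_type and __call__


print(isinstance(MeanLoss(), SimpleLossFunction))  # True
print(isinstance(NotALoss(), SimpleLossFunction))  # False
```

The runtime check only verifies that the members exist, not their types or signatures. Signature compatibility is enforced by a static checker such as mypy.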
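- loss_type: nemo_rl.algorithms.loss.interfaces.LossType
Set by each implementation to a LossType member (TOKEN_LEVEL or SEQUENCE_LEVEL); the protocol itself provides no value.
- input_type: nemo_rl.algorithms.loss.interfaces.LossInputType
Set by each implementation to a LossInputType member, which determines the keyword arguments passed to __call__; the protocol itself provides no value.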
- __call__(data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, **kwargs: Any)
Compute the loss and metrics from the model outputs passed through **kwargs and the batch fields in data.
- Parameters:
data – Dictionary containing all relevant data for loss computation such as rewards, values, actions, advantages, masks, and other algorithm-specific information needed for the particular loss calculation.
global_valid_seqs (torch.Tensor) – The number of valid sequences in the global batch, aggregated across all microbatches. It is the normalization factor for losses and metrics computed at the sequence level, so that the per-microbatch values sum to a correctly normalized total.
global_valid_toks (torch.Tensor) – The number of valid tokens in the global batch, aggregated across all microbatches. It is the normalization factor for losses and metrics computed at the token level, so that the per-microbatch values sum to a correctly normalized total.
**kwargs – Loss function inputs. Which keyword arguments are passed depends on input_type:
- For LossInputType.LOGPROB: next_token_logprobs (torch.Tensor)
- For LossInputType.LOGIT: logits (torch.Tensor)
- For LossInputType.DISTILLATION: student_topk_logprobs, teacher_topk_logprobs, H_all (torch.Tensor)
- For LossInputType.DRAFT: teacher_logits, student_logits, mask (torch.Tensor)
- Returns:
A tuple (loss, metrics), where:
- loss: a scalar tensor representing the loss value to be minimized during training.
- metrics: a dictionary of metrics related to the loss computation, which may include component losses, statistics about gradients or rewards, and other diagnostic information.
- Return type:
tuple
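A minimal token-level LOGPROB loss might look like the sketch below. It makes several assumptions:
- A plain dict stands in for BatchedDataDict.
- The dict carries per-token advantages under a hypothetical "advantages" key and a validity mask under a hypothetical "token_mask" key.
- The enums are mirrored locally, so the sketch runs without nemo_rl (it does require PyTorch).

SimplePGLoss is a simple policy-gradient-style illustration, not a loss shipped with the library. The usage part checks that summing the microbatch losses reproduces the full-batch loss when dividing by global_valid_toks.

```python
from enum import Enum
from typing import Any

import torch


# Local mirrors of the nemo_rl enums so the sketch runs standalone.
class LossType(Enum):
    TOKEN_LEVEL = "token_level"
    SEQUENCE_LEVEL = "sequence_level"


class LossInputType(Enum):
    LOGIT = "logit"
    LOGPROB = "logprob"
    DISTILLATION = "distillation"
    DRAFT = "draft"


class SimplePGLoss:
    """Hypothetical token-level loss: -advantage * logprob, averaged over valid tokens."""

    loss_type = LossType.TOKEN_LEVEL
    input_type = LossInputType.LOGPROB

    def __call__(
        self,
        data: dict[str, torch.Tensor],  # plain dict standing in for BatchedDataDict
        global_valid_seqs: torch.Tensor,  # unused by this token-level loss
        global_valid_toks: torch.Tensor,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        logprobs = kwargs["next_token_logprobs"]  # shape [batch, seq_len]
        mask = data["token_mask"].to(logprobs.dtype)  # 1.0 = valid token, 0.0 = padding
        per_token = -data["advantages"] * logprobs * mask
        # Divide by the global token count (not this microbatch's count), so
        # summing the returned losses over all microbatches gives the mean
        # over every valid token in the global batch.
        loss = per_token.sum() / global_valid_toks
        metrics = {"loss": loss.item(), "num_valid_toks": mask.sum().item()}
        return loss, metrics


# Usage: one global batch of 4 sequences, processed as 2 microbatches.
torch.manual_seed(0)
logprobs = -torch.rand(4, 6)  # fake per-token log-probabilities (all <= 0)
mask = torch.ones(4, 6)
mask[:, 4:] = 0  # last two positions of every sequence are padding
data = {"advantages": torch.randn(4, 1).expand(4, 6), "token_mask": mask}
global_valid_seqs = torch.tensor(4.0)
global_valid_toks = mask.sum()  # 16 valid tokens

loss_fn = SimplePGLoss()
total = torch.tensor(0.0)
for rows in (slice(0, 2), slice(2, 4)):
    microbatch = {key: value[rows] for key, value in data.items()}
    mb_loss, _ = loss_fn(
        microbatch, global_valid_seqs, global_valid_toks,
        next_token_logprobs=logprobs[rows],
    )
    total = total + mb_loss

full_loss, metrics = loss_fn(
    data, global_valid_seqs, global_valid_toks, next_token_logprobs=logprobs
)
print(torch.allclose(total, full_loss))  # True
```

A sequence-level implementation would instead set loss_type to LossType.SEQUENCE_LEVEL and divide by global_valid_seqs.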