nemo_rl.algorithms.loss.interfaces#

Module Contents#

Classes#

LossType

LossInputType

LossFunction

Signature for loss functions used in reinforcement learning algorithms.

API#

class nemo_rl.algorithms.loss.interfaces.LossType(*args, **kwds)#

Bases: enum.Enum

TOKEN_LEVEL#

‘token_level’

SEQUENCE_LEVEL#

‘sequence_level’

class nemo_rl.algorithms.loss.interfaces.LossInputType(*args, **kwds)#

Bases: enum.Enum

LOGIT#

‘logit’

LOGPROB#

‘logprob’

DISTILLATION#

‘distillation’

DRAFT#

‘draft’
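Concretely, a caller can dispatch on `input_type` to decide which keyword arguments a loss function expects. The sketch below re-declares the enum locally so it is self-contained, and `EXPECTED_KWARGS` is a hypothetical helper (not part of NeMo RL) whose entries mirror the per-input-type kwargs documented for `LossFunction.__call__` below:

```python
from enum import Enum


class LossInputType(Enum):
    """Mirrors the documented members of nemo_rl.algorithms.loss.interfaces.LossInputType."""

    LOGIT = "logit"
    LOGPROB = "logprob"
    DISTILLATION = "distillation"
    DRAFT = "draft"


# Hypothetical helper: the kwargs each input type expects, per the
# documentation of LossFunction.__call__.
EXPECTED_KWARGS = {
    LossInputType.LOGPROB: ("next_token_logprobs",),
    LossInputType.LOGIT: ("logits",),
    LossInputType.DISTILLATION: (
        "student_topk_logprobs",
        "teacher_topk_logprobs",
        "H_all",
    ),
    LossInputType.DRAFT: ("teacher_logits", "student_logits", "mask"),
}
```

A training loop could use such a table to check that all required tensors were provided before invoking a loss, e.g. `set(EXPECTED_KWARGS[loss_fn.input_type]) <= kwargs.keys()`.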

class nemo_rl.algorithms.loss.interfaces.LossFunction#

Bases: typing.Protocol

Signature for loss functions used in reinforcement learning algorithms.

Loss functions compute a scalar loss value and associated metrics from model logprobs and other data contained in a BatchedDataDict.

loss_type: nemo_rl.algorithms.loss.interfaces.LossType#

None

input_type: nemo_rl.algorithms.loss.interfaces.LossInputType#

None

__call__(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
**kwargs: Any,
) → tuple[torch.Tensor, dict[str, Any]]#

Compute loss and metrics from logprobs and other data.

Parameters:
  • data – Dictionary containing all relevant data for loss computation such as rewards, values, actions, advantages, masks, and other algorithm-specific information needed for the particular loss calculation.

  • global_valid_seqs (torch.Tensor) – The number of valid sequences in the microbatch. It is used for global normalization of losses/metrics that are computed at the sequence level and need to be aggregated across all microbatches.

  • global_valid_toks (torch.Tensor) – The number of valid tokens in the microbatch. It is used for global normalization of losses/metrics that are computed at the token level and need to be aggregated across all microbatches.

  • **kwargs

    Loss function input, which varies by input_type:

    • For LossInputType.LOGPROB: next_token_logprobs (torch.Tensor)

    • For LossInputType.LOGIT: logits (torch.Tensor)

    • For LossInputType.DISTILLATION: student_topk_logprobs, teacher_topk_logprobs, H_all (torch.Tensor)

    • For LossInputType.DRAFT: teacher_logits, student_logits, mask (torch.Tensor)

Returns:

(loss, metrics):

  • loss – A scalar tensor representing the loss value to be minimized during training.

  • metrics – A dictionary of metrics related to the loss computation, which may include component losses, statistics about gradients/rewards, and other diagnostic information.

Return type:

tuple
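Because `LossFunction` is a `typing.Protocol`, any class with matching `loss_type`/`input_type` attributes and `__call__` signature satisfies it structurally, without inheritance. The following is a minimal, self-contained sketch: plain Python floats, lists, and a dict stand in for `torch.Tensor` and `BatchedDataDict`, and `MaskedNLLLoss` is an illustrative token-level loss, not one shipped by NeMo RL:

```python
from enum import Enum
from typing import Any, Protocol


class LossType(Enum):
    TOKEN_LEVEL = "token_level"
    SEQUENCE_LEVEL = "sequence_level"


class LossInputType(Enum):
    LOGIT = "logit"
    LOGPROB = "logprob"


class LossFunction(Protocol):
    loss_type: LossType
    input_type: LossInputType

    def __call__(
        self,
        data: dict[str, Any],      # stands in for BatchedDataDict
        global_valid_seqs: float,  # stands in for a scalar torch.Tensor
        global_valid_toks: float,
        **kwargs: Any,
    ) -> tuple[float, dict[str, Any]]: ...


class MaskedNLLLoss:
    """Illustrative token-level NLL, globally normalized by the valid token count."""

    loss_type = LossType.TOKEN_LEVEL
    input_type = LossInputType.LOGPROB

    def __call__(self, data, global_valid_seqs, global_valid_toks, **kwargs):
        logprobs = kwargs["next_token_logprobs"]  # per-token logprobs
        mask = data["token_mask"]                 # 1 for valid tokens, 0 for padding
        loss = -sum(lp * m for lp, m in zip(logprobs, mask)) / global_valid_toks
        return loss, {"num_valid_toks": global_valid_toks}


fn: LossFunction = MaskedNLLLoss()  # structural match; no subclassing needed
loss, metrics = fn(
    {"token_mask": [1, 1, 0]},
    global_valid_seqs=1,
    global_valid_toks=2,
    next_token_logprobs=[-0.5, -1.5, -9.0],
)
# loss == 1.0: only the two masked-in tokens contribute, normalized by 2
```

Dividing by `global_valid_toks` (rather than the microbatch's own token count) is what makes per-microbatch losses sum to a correctly normalized global loss when gradients are accumulated across microbatches.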