nemo_rl.algorithms.loss.wrapper#

Module Contents#

Classes#

SequencePackingLossWrapper

SequencePackingFusionLossWrapper

Fused sequence packing loss wrapper that processes all sequences in one forward pass.

DraftLossWrapper

Combine policy loss with draft soft cross-entropy loss.

Functions#

wrap_loss_fn_with_input_preparation

Wraps a loss function to handle input preparation for the Megatron policy worker.

Data#

API#

nemo_rl.algorithms.loss.wrapper.Tensor#

‘TypeVar(…)’

class nemo_rl.algorithms.loss.wrapper.SequencePackingLossWrapper(
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
prepare_fn: Callable[..., Any],
cu_seqlens_q: nemo_rl.algorithms.loss.wrapper.Tensor,
cu_seqlens_q_padded: Optional[nemo_rl.algorithms.loss.wrapper.Tensor] = None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Initialization

Wrap a loss function to handle sequence packing.

Parameters:
  • loss_fn – Loss function to wrap.

  • prepare_fn – Function that prepares the inputs for the loss function.

  • cu_seqlens_q – Cumulative query sequence lengths, unpadded.

  • cu_seqlens_q_padded – Cumulative query sequence lengths, including padding.

  • vocab_parallel_rank – Vocab parallel rank. Only used for the Megatron policy worker.

  • vocab_parallel_group – Vocab parallel process group. Only used for the Megatron policy worker.

  • context_parallel_group – Context parallel process group. Only used for the Megatron policy worker.

Returns:

Sequence packing loss wrapper.

__call__(
next_token_logits: nemo_rl.algorithms.loss.wrapper.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
global_valid_seqs: nemo_rl.algorithms.loss.wrapper.Tensor | None,
global_valid_toks: nemo_rl.algorithms.loss.wrapper.Tensor | None,
) → tuple[nemo_rl.algorithms.loss.wrapper.Tensor, dict[str, Any]]#

Wraps a loss function to handle sequence packing by processing one sequence at a time, avoiding excessive padding.
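To make the `cu_seqlens_q` parameter concrete, the sketch below shows how cumulative sequence lengths delimit individual sequences inside a packed buffer. The function and variable names are illustrative only, not part of the nemo_rl API:

```python
# Hypothetical illustration of how cumulative sequence lengths (cu_seqlens)
# delimit the individual sequences inside a packed batch.

def split_packed(packed_tokens, cu_seqlens):
    """Slice a packed 1-D token list into its constituent sequences."""
    return [
        packed_tokens[start:end]
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
    ]

# Three sequences of lengths 3, 2, and 4 packed back to back:
packed = [11, 12, 13, 21, 22, 31, 32, 33, 34]
cu_seqlens_q = [0, 3, 5, 9]  # sequence boundaries: [0, 3), [3, 5), [5, 9)

sequences = split_packed(packed, cu_seqlens_q)
# sequences == [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
```

The wrapper's `__call__` applies the wrapped loss to each such slice in turn, so no slice ever needs to be padded to the longest sequence in the batch.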

class nemo_rl.algorithms.loss.wrapper.SequencePackingFusionLossWrapper(
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
prepare_fn: Callable[..., Any],
cu_seqlens_q: nemo_rl.algorithms.loss.wrapper.Tensor,
cu_seqlens_q_padded: Optional[nemo_rl.algorithms.loss.wrapper.Tensor] = None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Fused sequence packing loss wrapper that processes all sequences in one forward pass.

Unlike SequencePackingLossWrapper which iterates over sequences one at a time, this wrapper calls prepare_fn once on the packed logits to compute log probabilities in a single shot, then calls the loss function once with the pre-computed result.

This avoids per-sequence kernel launches and TP/CP communication overhead while producing numerically identical results.

The prepare_fn should be prepare_packed_loss_input (from nemo_rl.algorithms.loss.utils), which currently only supports LossInputType.LOGPROB.
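The difference between the two call patterns can be sketched with toy stand-ins. `toy_prepare` and `toy_loss` below are illustrative assumptions, not the real nemo_rl functions (the real `prepare_fn` computes log probabilities):

```python
# Contrast the per-sequence and fused call patterns by counting calls.
calls = {"prepare": 0, "loss": 0}

def toy_prepare(xs):
    calls["prepare"] += 1
    return xs  # the real prepare_fn would compute log probabilities here

def toy_loss(xs):
    calls["loss"] += 1
    return sum(xs) / len(xs)

packed = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
cu_seqlens = [0, 2, 6]  # two packed sequences: [0, 2) and [2, 6)

# SequencePackingLossWrapper-style path: prepare + loss once per sequence.
for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
    toy_loss(toy_prepare(packed[start:end]))
per_sequence_calls = dict(calls)   # {"prepare": 2, "loss": 2}

# SequencePackingFusionLossWrapper-style path: one prepare call and one
# loss call over the whole packed buffer.
calls["prepare"] = calls["loss"] = 0
toy_loss(toy_prepare(packed))
fused_calls = dict(calls)          # {"prepare": 1, "loss": 1}
```

With N packed sequences, the fused path replaces N prepare/loss call pairs with one of each, which is where the kernel-launch and TP/CP communication savings come from.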

Initialization

__call__(
next_token_logits: nemo_rl.algorithms.loss.wrapper.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
global_valid_seqs: nemo_rl.algorithms.loss.wrapper.Tensor | None,
global_valid_toks: nemo_rl.algorithms.loss.wrapper.Tensor | None,
) → tuple[nemo_rl.algorithms.loss.wrapper.Tensor, dict[str, Any]]#

Compute loss for all packed sequences in one forward pass.

class nemo_rl.algorithms.loss.wrapper.DraftLossWrapper(
loss_fn: Callable[..., tuple[torch.Tensor, dict[str, Any]]],
prepare_fn: Callable[..., Any],
data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
loss_weight: float = 1.0,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Combine policy loss with draft soft cross-entropy loss.
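A minimal sketch of the combination this wrapper is described as performing, assuming `loss_weight` scales the draft term before it is added to the policy loss (the exact weighting semantics and metric names here are illustrative assumptions, not the nemo_rl implementation):

```python
from typing import Any

def combine_losses(
    policy_loss: float, draft_loss: float, loss_weight: float = 1.0
) -> tuple[float, dict[str, Any]]:
    """Add a weighted draft soft cross-entropy term to the policy loss."""
    total = policy_loss + loss_weight * draft_loss
    metrics = {
        "policy_loss": policy_loss,
        "draft_loss": draft_loss,
        "total_loss": total,
    }
    return total, metrics

total, metrics = combine_losses(policy_loss=0.8, draft_loss=0.4, loss_weight=0.5)
# total -> 1.0
```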

Initialization

__call__(
next_token_logits: torch.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
global_valid_seqs: torch.Tensor | None,
global_valid_toks: torch.Tensor | None,
**kwargs: Any,
) → tuple[torch.Tensor, dict[str, Any]]#
nemo_rl.algorithms.loss.wrapper.wrap_loss_fn_with_input_preparation(
next_token_logits: nemo_rl.algorithms.loss.wrapper.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
global_valid_seqs: nemo_rl.algorithms.loss.wrapper.Tensor | None,
global_valid_toks: nemo_rl.algorithms.loss.wrapper.Tensor | None,
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
prepare_fn: Callable[..., Any],
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) → tuple[nemo_rl.algorithms.loss.wrapper.Tensor, dict[str, Any]]#

Wraps a loss function to handle input preparation for the Megatron policy worker.
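The prepare-then-call pattern this function implements can be sketched generically. Everything below (the wrapper name, the toy `prepare`/`loss` pair) is an illustrative assumption; only the pattern itself comes from the docs:

```python
from typing import Any, Callable

def wrap_with_preparation(
    loss_fn: Callable[..., tuple[float, dict[str, Any]]],
    prepare_fn: Callable[..., Any],
) -> Callable[..., tuple[float, dict[str, Any]]]:
    """Return a loss callable that runs prepare_fn on the logits first."""
    def wrapped(next_token_logits, data, **kwargs):
        prepared = prepare_fn(next_token_logits, data)
        return loss_fn(prepared, data, **kwargs)
    return wrapped

# Toy prepare/loss pair to exercise the wrapper:
def prepare(logits, data):
    # e.g. shift logits so the maximum is zero (stand-in for log-softmax)
    top = max(logits)
    return [x - top for x in logits]

def loss(prepared, data, **kwargs):
    return -sum(prepared) / len(prepared), {"n": len(prepared)}

fn = wrap_with_preparation(loss, prepare)
value, metrics = fn([1.0, 2.0, 3.0], data={})
# value -> 1.0, metrics -> {"n": 3}
```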