nemo_rl.algorithms.loss.wrapper#
Module Contents#
Classes#
| Class | Description |
|---|---|
| SequencePackingLossWrapper | Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding. |
| SequencePackingFusionLossWrapper | Fused sequence packing loss wrapper that processes all sequences in one forward pass. |
| DraftLossWrapper | Combine policy loss with draft soft cross-entropy loss. |
Functions#
| Function | Description |
|---|---|
| wrap_loss_fn_with_input_preparation | Wraps a loss function to handle input preparation for the Megatron policy worker. |
Data#
API#
- nemo_rl.algorithms.loss.wrapper.Tensor#
'TypeVar(…)'
- class nemo_rl.algorithms.loss.wrapper.SequencePackingLossWrapper(
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- prepare_fn: Callable[..., Any],
- cu_seqlens_q: nemo_rl.algorithms.loss.wrapper.Tensor,
- cu_seqlens_q_padded: Optional[nemo_rl.algorithms.loss.wrapper.Tensor] = None,
- vocab_parallel_rank: Optional[int] = None,
- vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- )
Initialization
Wrap a loss function to handle sequence packing.
- Parameters:
loss_fn – The loss function to wrap.
prepare_fn – Function that prepares the loss inputs (e.g., log probabilities) from the logits.
cu_seqlens_q – Unpadded cumulative query sequence lengths.
cu_seqlens_q_padded – Padded cumulative query sequence lengths.
vocab_parallel_rank – Vocab parallel rank.
vocab_parallel_group – Vocab parallel process group.
context_parallel_group – Context parallel process group.

Note: vocab_parallel_rank, vocab_parallel_group, and context_parallel_group are only used for the Megatron policy worker.
- Returns:
Sequence packing loss wrapper.
- __call__(
- next_token_logits: nemo_rl.algorithms.loss.wrapper.Tensor,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- global_valid_seqs: nemo_rl.algorithms.loss.wrapper.Tensor | None,
- global_valid_toks: nemo_rl.algorithms.loss.wrapper.Tensor | None,
- )
Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding.
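A minimal usage sketch (assumptions: `my_loss_fn`, `my_prepare_fn`, `next_token_logits`, `data`, `global_valid_seqs`, and `global_valid_toks` are hypothetical stand-ins; only the wrapper wiring and the cu_seqlens layout follow this page):

```python
import torch
from nemo_rl.algorithms.loss.wrapper import SequencePackingLossWrapper

# Three sequences of lengths 5, 3, and 7 packed back to back into one row;
# cu_seqlens_q marks the cumulative boundaries [0, 5, 8, 15].
seq_lens = torch.tensor([5, 3, 7])
cu_seqlens_q = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)

wrapped = SequencePackingLossWrapper(
    loss_fn=my_loss_fn,        # hypothetical LossFunction implementation
    prepare_fn=my_prepare_fn,  # hypothetical input-preparation callable
    cu_seqlens_q=cu_seqlens_q,
)

# Invoked with the same signature as the wrapped loss function; internally
# each sequence is sliced out of the packed logits to avoid padding overhead.
loss, metrics = wrapped(next_token_logits, data, global_valid_seqs, global_valid_toks)
```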
- class nemo_rl.algorithms.loss.wrapper.SequencePackingFusionLossWrapper(
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- prepare_fn: Callable[..., Any],
- cu_seqlens_q: nemo_rl.algorithms.loss.wrapper.Tensor,
- cu_seqlens_q_padded: Optional[nemo_rl.algorithms.loss.wrapper.Tensor] = None,
- vocab_parallel_rank: Optional[int] = None,
- vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- )
Fused sequence packing loss wrapper that processes all sequences in one forward pass.
Unlike SequencePackingLossWrapper which iterates over sequences one at a time, this wrapper calls prepare_fn once on the packed logits to compute log probabilities in a single shot, then calls the loss function once with the pre-computed result.
This avoids per-sequence kernel launches and TP/CP communication overhead while producing numerically identical results.
The prepare_fn should be prepare_packed_loss_input (from nemo_rl.algorithms.loss.utils), which currently only supports LossInputType.LOGPROB.
Initialization
- __call__(
- next_token_logits: nemo_rl.algorithms.loss.wrapper.Tensor,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- global_valid_seqs: nemo_rl.algorithms.loss.wrapper.Tensor | None,
- global_valid_toks: nemo_rl.algorithms.loss.wrapper.Tensor | None,
- )
Compute loss for all packed sequences in one forward pass.
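Per the docstring, the fused variant takes the same constructor arguments but expects `prepare_packed_loss_input` as its `prepare_fn`; a sketch reusing the hypothetical names from the previous example:

```python
from nemo_rl.algorithms.loss.utils import prepare_packed_loss_input
from nemo_rl.algorithms.loss.wrapper import SequencePackingFusionLossWrapper

fused = SequencePackingFusionLossWrapper(
    loss_fn=my_loss_fn,                    # hypothetical LossFunction, as above
    prepare_fn=prepare_packed_loss_input,  # computes log-probs once over the packed batch
    cu_seqlens_q=cu_seqlens_q,
)

# One prepare_fn call and one loss_fn call for the whole packed batch,
# numerically identical to the per-sequence wrapper.
loss, metrics = fused(next_token_logits, data, global_valid_seqs, global_valid_toks)
```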
- class nemo_rl.algorithms.loss.wrapper.DraftLossWrapper(
- loss_fn: Callable[..., tuple[torch.Tensor, dict[str, Any]]],
- prepare_fn: Callable[..., Any],
- data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- loss_weight: float = 1.0,
- vocab_parallel_rank: Optional[int] = None,
- vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- )
Combine policy loss with draft soft cross-entropy loss.
Initialization
- __call__(
- next_token_logits: torch.Tensor,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- global_valid_seqs: torch.Tensor | None,
- global_valid_toks: torch.Tensor | None,
- **kwargs: Any,
- )
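A sketch of attaching the draft term to a policy loss; `policy_loss_fn`, `draft_prepare_fn`, and the batch variables are hypothetical, and the reading of `loss_weight` as a scale on the draft term is an assumption:

```python
from nemo_rl.algorithms.loss.wrapper import DraftLossWrapper

combined = DraftLossWrapper(
    loss_fn=policy_loss_fn,       # hypothetical policy LossFunction
    prepare_fn=draft_prepare_fn,  # hypothetical draft input-preparation callable
    data_dict=data,               # BatchedDataDict holding the training batch
    loss_weight=0.5,              # assumed: scales the draft soft cross-entropy term
)

loss, metrics = combined(next_token_logits, data, global_valid_seqs, global_valid_toks)
```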
- nemo_rl.algorithms.loss.wrapper.wrap_loss_fn_with_input_preparation(
- next_token_logits: nemo_rl.algorithms.loss.wrapper.Tensor,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- global_valid_seqs: nemo_rl.algorithms.loss.wrapper.Tensor | None,
- global_valid_toks: nemo_rl.algorithms.loss.wrapper.Tensor | None,
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- prepare_fn: Callable[..., Any],
- vocab_parallel_rank: Optional[int] = None,
- vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- )
Wraps a loss function to handle input preparation for the Megatron policy worker.
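Because the static arguments (`loss_fn`, `prepare_fn`, and the parallel ranks/groups) follow the per-step ones, the function can be specialized once with `functools.partial`; how the Megatron policy worker actually binds it is not shown on this page, so this is only a sketch reusing the hypothetical names from above:

```python
from functools import partial
from nemo_rl.algorithms.loss.wrapper import wrap_loss_fn_with_input_preparation

# Bind the static arguments once; the result keeps the
# (next_token_logits, data, global_valid_seqs, global_valid_toks) convention.
bound_loss_fn = partial(
    wrap_loss_fn_with_input_preparation,
    loss_fn=my_loss_fn,        # hypothetical LossFunction
    prepare_fn=my_prepare_fn,  # hypothetical input-preparation callable
)

loss, metrics = bound_loss_fn(next_token_logits, data, global_valid_seqs, global_valid_toks)
```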