nemo_rl.algorithms.loss.utils#

Module Contents#

Functions#

prepare_loss_input

Prepare loss input for a loss function.

_pack_input_ids

Pack input_ids from [B, S] to [1, T_packed // CP] using sequence boundaries.

prepare_packed_loss_input

Prepare loss input from packed logits in a single fused pass.

API#

nemo_rl.algorithms.loss.utils.prepare_loss_input(
logits: torch.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
d2t: Optional[torch.Tensor] = None,
) → tuple[dict[str, Any], nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]]#

Prepare loss input for a loss function.

Parameters:
  • logits – Logits from the model.

  • data – Microbatch data. Will be updated if sampling_params is not None.

  • loss_fn – Loss function.

  • vocab_parallel_rank – Vocab parallel rank.

  • vocab_parallel_group – Vocab parallel group.

  • context_parallel_group – Context parallel group.

  • sampling_params – Sampling parameters.

  • d2t – Draft to target token mapping.

Notes

  • vocab_parallel_rank, vocab_parallel_group, and context_parallel_group are only used by the Megatron policy worker.

  • sampling_params is only used for LossInputType.LOGPROB, and is currently supported only for ClippedPGLossFn.

  • d2t is only used for LossInputType.DRAFT.

Returns:

tuple(loss_input, maybe_updated_data)
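For LossInputType.LOGPROB, the heart of what this prepares is the log-probability of each next token under the model's logits. A minimal pure-Python sketch of that gather (no torch, no vocab/context parallelism; the function name and shapes here are illustrative, not part of the nemo_rl API):

```python
import math

def next_token_logprobs(logits, input_ids):
    """Illustrative only: log P(input_ids[t+1] | logits[t]) for one sequence.

    logits:    per-position logit lists, shape [S][V]
    input_ids: token ids, shape [S]
    Returns a list of length S-1.
    """
    out = []
    for t in range(len(input_ids) - 1):
        row = logits[t]
        # log-softmax with max-subtraction for numerical stability
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[input_ids[t + 1]] - lse)
    return out

# Uniform logits over a 4-token vocab: every next token has logprob log(1/4).
logits = [[0.0, 0.0, 0.0, 0.0]] * 3
print(next_token_logprobs(logits, [1, 2, 3]))  # ≈ [-1.386, -1.386]
```

In the real function this gather happens on sharded logits, which is why the vocab- and context-parallel rank/group arguments exist.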

nemo_rl.algorithms.loss.utils._pack_input_ids(
input_ids: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_q_padded: torch.Tensor,
cp_rank: int = 0,
cp_size: int = 1,
roll_shift: int = 0,
) → torch.Tensor#

Pack input_ids from [B, S] to [1, T_packed // CP] using sequence boundaries.

Each sequence is individually padded to its padded length (from cu_seqlens_q_padded), optionally rolled, and CP-sharded at that padded length before being placed into the packed output. This matches how Megatron packs and CP-shards sequences in _pack_sequences_for_megatron.

Parameters:
  • input_ids – Unpacked input IDs [B, S].

  • cu_seqlens_q – Unpadded cumulative sequence lengths [B+1].

  • cu_seqlens_q_padded – Padded cumulative sequence lengths [B+1].

  • cp_rank – Context parallelism rank.

  • cp_size – Context parallelism size.

  • roll_shift – If non-zero, roll each padded sequence by this amount before CP-sharding. Use -1 to build shifted targets for next-token prediction.
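The pad-and-pack step can be sketched in plain Python. This is an illustration of the index arithmetic described above, not the actual implementation: it assumes cp_size=1 (no CP sharding) and a hypothetical pad_id, and uses lists in place of tensors.

```python
def pack_input_ids_sketch(input_ids, cu_seqlens_q, cu_seqlens_q_padded,
                          roll_shift=0, pad_id=0):
    """Illustrative only: pack a batch of sequences into one flat list.

    input_ids:           [B][S] token-id lists (right-padded)
    cu_seqlens_q:        [B+1] unpadded cumulative lengths
    cu_seqlens_q_padded: [B+1] padded cumulative lengths
    """
    packed = []
    for b in range(len(input_ids)):
        seqlen = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
        padded_len = cu_seqlens_q_padded[b + 1] - cu_seqlens_q_padded[b]
        # pad each sequence individually to its padded length
        seq = input_ids[b][:seqlen] + [pad_id] * (padded_len - seqlen)
        if roll_shift:
            # roll_shift=-1 builds shifted targets for next-token prediction:
            # new[i] = old[i + 1], wrapping at the padded boundary
            k = (-roll_shift) % padded_len
            seq = seq[k:] + seq[:k]
        packed.extend(seq)
    return [packed]  # shape [1, T_packed]

# Two sequences of true lengths 3 and 2, each padded to length 4.
ids = [[5, 6, 7, 0], [8, 9, 0, 0]]
print(pack_input_ids_sketch(ids, [0, 3, 5], [0, 4, 8]))
# [[5, 6, 7, 0, 8, 9, 0, 0]]
print(pack_input_ids_sketch(ids, [0, 3, 5], [0, 4, 8], roll_shift=-1))
# [[6, 7, 0, 5, 9, 0, 0, 8]]
```

With cp_size > 1, each padded sequence is additionally split across CP ranks at its padded length, matching Megatron's sharding; that step is omitted here.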

nemo_rl.algorithms.loss.utils.prepare_packed_loss_input(
logits: torch.Tensor,
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
cu_seqlens_q: torch.Tensor,
cu_seqlens_q_padded: torch.Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
) → tuple[dict[str, Any], nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]]#

Prepare loss input from packed logits in a single fused pass.

Unlike prepare_loss_input, which operates on a single (unpacked) sequence at a time, this function computes log probabilities from packed logits across all sequences at once using from_parallel_logits_to_logprobs_packed_sequences.

Currently only supports LossInputType.LOGPROB.

Parameters:
  • logits – Packed logits from the model [1, T_packed // CP, V // TP].

  • data – Microbatch data (unpacked, [B, S]).

  • loss_fn – Loss function (must have input_type == LossInputType.LOGPROB).

  • cu_seqlens_q – Unpadded cumulative sequence lengths [B+1].

  • cu_seqlens_q_padded – Padded cumulative sequence lengths [B+1].

  • vocab_parallel_rank – Vocab parallel rank.

  • vocab_parallel_group – Vocab parallel group.

  • context_parallel_group – Context parallel group.

  • sampling_params – Sampling parameters.

Returns:

tuple(loss_input, maybe_updated_data)
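The two cumulative-length tensors describe the same batch at different granularities. A small sketch of how they relate (the lengths and the pad multiple below are illustrative; the real pad multiple depends on the CP configuration):

```python
from itertools import accumulate

def cu_seqlens_sketch(seqlens, pad_multiple=4):
    """Illustrative only: build unpadded and padded cumulative lengths.

    cu_seqlens_q[b+1] - cu_seqlens_q[b] is sequence b's true length;
    the padded variant rounds each length up to a multiple of
    pad_multiple (e.g. so each sequence divides evenly across CP ranks).
    """
    padded = [-(-s // pad_multiple) * pad_multiple for s in seqlens]
    cu = [0] + list(accumulate(seqlens))
    cu_padded = [0] + list(accumulate(padded))
    return cu, cu_padded

print(cu_seqlens_sketch([3, 5, 6]))
# ([0, 3, 8, 14], [0, 4, 12, 20])
```

cu_seqlens_q locates the real (loss-contributing) tokens inside the packed layout, while cu_seqlens_q_padded locates the sequence boundaries in the padded [1, T_packed // CP] logits tensor.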