nemo_rl.algorithms.loss.utils
Module Contents
Functions

- prepare_loss_input: Prepare loss input for a loss function.
- _pack_input_ids: Pack input_ids from [B, S] to [1, T_packed // CP] using sequence boundaries.
- prepare_packed_loss_input: Prepare loss input from packed logits in a single fused pass.

API
- nemo_rl.algorithms.loss.utils.prepare_loss_input(
- logits: torch.Tensor,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- vocab_parallel_rank: Optional[int] = None,
- vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
- d2t: Optional[torch.Tensor] = None,
- )
Prepare loss input for a loss function.
- Parameters:
logits – Logits from the model.
data – Microbatch data. Will be updated if sampling_params is not None.
loss_fn – Loss function.
vocab_parallel_rank – Vocab parallel rank.
vocab_parallel_group – Vocab parallel group.
context_parallel_group – Context parallel group.
sampling_params – Sampling parameters.
d2t – Draft to target token mapping.
Notes

- vocab_parallel_rank, vocab_parallel_group, and context_parallel_group are only used by the Megatron policy worker.
- sampling_params is only used for LossInputType.LOGPROB and is currently supported only for ClippedPGLossFn.
- d2t is only used for LossInputType.DRAFT.
- Returns:
tuple(loss_input, maybe_updated_data)
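To make the role of the loss function's input type concrete, here is a minimal, hypothetical sketch of the kind of dispatch the Notes above imply. The LossInputType enum and the function body below are illustrative assumptions, not the actual nemo_rl implementation (the real enum lives in nemo_rl.algorithms.loss.interfaces):

```python
from enum import Enum, auto

class LossInputType(Enum):
    # Illustrative stand-in for nemo_rl.algorithms.loss.interfaces.LossInputType.
    LOGPROB = auto()
    DRAFT = auto()

def prepare_loss_input_sketch(logits, data, input_type, d2t=None):
    """Hypothetical dispatch mirroring the Notes: LOGPROB inputs need
    per-token log-probabilities, DRAFT inputs need the draft-to-target map."""
    if input_type is LossInputType.LOGPROB:
        # The real code would reduce (vocab-parallel) logits to token logprobs here.
        return {"logprobs": logits, "data": data}
    if input_type is LossInputType.DRAFT:
        if d2t is None:
            raise ValueError("d2t is required for LossInputType.DRAFT")
        return {"logits": logits, "d2t": d2t, "data": data}
    raise NotImplementedError(input_type)
```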
- nemo_rl.algorithms.loss.utils._pack_input_ids(
- input_ids: torch.Tensor,
- cu_seqlens_q: torch.Tensor,
- cu_seqlens_q_padded: torch.Tensor,
- cp_rank: int = 0,
- cp_size: int = 1,
- roll_shift: int = 0,
- )
Pack input_ids from [B, S] to [1, T_packed // CP] using sequence boundaries.
Each sequence is individually padded to its padded length (from cu_seqlens_q_padded), optionally rolled, and CP-sharded at that padded length before being placed into the packed output. This matches how Megatron packs and CP-shards sequences in _pack_sequences_for_megatron.
- Parameters:
input_ids – Unpacked input IDs [B, S].
cu_seqlens_q – Unpadded cumulative sequence lengths [B+1].
cu_seqlens_q_padded – Padded cumulative sequence lengths [B+1].
cp_rank – Context parallelism rank.
cp_size – Context parallelism size.
roll_shift – If non-zero, roll each padded sequence by this amount before CP-sharding. Use -1 to build shifted targets for next-token prediction.
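The pad/roll/pack behavior described above can be sketched in plain Python for the cp_size == 1 case (CP sharding omitted; the function name is hypothetical and the real implementation operates on torch tensors):

```python
def pack_input_ids_sketch(input_ids, cu_seqlens_q, cu_seqlens_q_padded, roll_shift=0):
    """Pack [B, S] rows into one flat list, mirroring the cp_size == 1 case.

    Each sequence is truncated to its unpadded length, zero-padded to its
    padded length, optionally rolled (roll_shift=-1 builds shifted targets
    for next-token prediction), and concatenated into the packed output.
    """
    packed = []
    for b, row in enumerate(input_ids):
        seq_len = cu_seqlens_q[b + 1] - cu_seqlens_q[b]
        padded_len = cu_seqlens_q_padded[b + 1] - cu_seqlens_q_padded[b]
        seq = list(row[:seq_len]) + [0] * (padded_len - seq_len)
        if roll_shift:
            # np.roll-style semantics: result[i] = seq[(i - roll_shift) % padded_len]
            k = (-roll_shift) % padded_len
            seq = seq[k:] + seq[:k]
        packed.extend(seq)
    return packed

# Two sequences of unpadded lengths 3 and 2, each padded to length 4:
ids = [[1, 2, 3, 0], [4, 5, 0, 0]]
print(pack_input_ids_sketch(ids, [0, 3, 5], [0, 4, 8]))                 # [1, 2, 3, 0, 4, 5, 0, 0]
print(pack_input_ids_sketch(ids, [0, 3, 5], [0, 4, 8], roll_shift=-1))  # [2, 3, 0, 1, 5, 0, 0, 4]
```

Note that each sequence is rolled within its own padded span, not across the whole packed buffer, matching the per-sequence roll described above.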
- nemo_rl.algorithms.loss.utils.prepare_packed_loss_input(
- logits: torch.Tensor,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- cu_seqlens_q: torch.Tensor,
- cu_seqlens_q_padded: torch.Tensor,
- vocab_parallel_rank: Optional[int] = None,
- vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
- )
Prepare loss input from packed logits in a single fused pass.
Unlike prepare_loss_input, which operates on unpacked [B, S] data, this function computes log probabilities from packed logits across all sequences at once using from_parallel_logits_to_logprobs_packed_sequences.
Currently only supports LossInputType.LOGPROB.
- Parameters:
logits – Packed logits from the model [1, T_packed // CP, V // TP].
data – Microbatch data (unpacked, [B, S]).
loss_fn – Loss function (must have input_type == LossInputType.LOGPROB).
cu_seqlens_q – Unpadded cumulative sequence lengths [B+1].
cu_seqlens_q_padded – Padded cumulative sequence lengths [B+1].
vocab_parallel_rank – Vocab parallel rank.
vocab_parallel_group – Vocab parallel group.
context_parallel_group – Context parallel group.
sampling_params – Sampling parameters.
- Returns:
tuple(loss_input, maybe_updated_data)
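The fused log-probability computation that the packed path relies on can be illustrated with a small dependency-free sketch: for each position in the packed sequence, the log-probability of its target token is that token's logit minus the log-sum-exp over the vocabulary. This is only an illustration with a hypothetical helper name; vocab/tensor parallelism and the CP sharding of the real [1, T_packed // CP, V // TP] layout are omitted:

```python
import math

def token_logprobs_sketch(packed_logits, targets):
    """packed_logits: [T_packed, V] rows of raw logits; targets: [T_packed] token ids.

    Returns log_softmax(logits)[target] per position, computed in one pass
    over the packed buffer rather than sequence by sequence.
    """
    out = []
    for row, tok in zip(packed_logits, targets):
        m = max(row)  # subtract the max for numerical stability
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[tok] - lse)
    return out

# Uniform logits over a 2-token vocab: the logprob of either token is -log(2).
print(token_logprobs_sketch([[0.0, 0.0]], [1]))  # [-0.693...]
```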