nemo_rl.models.megatron.common#

Module Contents#

Functions#

_pack_sequences_for_megatron

Pack sequences for Megatron model processing with optional context parallelism.

_unpack_sequences_from_megatron

Unpack sequences from Megatron output format.

forward_step_arbitrary_loss

Forward training step with support for packed sequences and context parallelism.

broadcast_tensor

Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata.

API#

nemo_rl.models.megatron.common._pack_sequences_for_megatron(
input_ids: torch.Tensor,
seq_lengths: torch.Tensor,
pad_individual_seqs_to_multiple_of: int = 1,
pad_packed_seq_to: Optional[int] = None,
cp_rank: int = 0,
cp_size: int = 1,
) tuple[torch.Tensor, megatron.core.packed_seq_params.PackedSeqParams, torch.Tensor, Optional[torch.Tensor]]#

Pack sequences for Megatron model processing with optional context parallelism.

Parameters:
  • input_ids – Input token IDs [batch_size, seq_length]

  • seq_lengths – Actual sequence lengths for each sample [batch_size]

  • pad_individual_seqs_to_multiple_of – Pad individual sequences to a multiple of this value

  • pad_packed_seq_to – Pad packed sequences to this value (before CP)

  • cp_rank – Context parallelism rank of the current process

  • cp_size – Context parallelism size

Returns:

  • packed_input_ids: Packed input tensor [1, T]

  • input_ids_cp_sharded: Sharded input tensor [cp_size, T // cp_size]

  • packed_seq_params: PackedSeqParams object

  • cu_seqlens: Cumulative sequence lengths

  • cu_seqlens_padded: Padded cumulative sequence lengths

Return type:

Tuple containing the elements listed above
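A minimal sketch of the packing idea described above, using plain PyTorch only. The example values and local names (packed, cu_seqlens) are illustrative and do not reproduce the function's internals:

```python
import torch

# Two right-padded samples (0 = pad) and their true lengths.
input_ids = torch.tensor([
    [11, 12, 13, 0, 0],   # sample 0, length 3
    [21, 22, 0, 0, 0],    # sample 1, length 2
])
seq_lengths = torch.tensor([3, 2])

# Concatenate only the valid tokens into a single [1, T] packed tensor.
packed = torch.cat(
    [input_ids[i, : int(seq_lengths[i])] for i in range(input_ids.size(0))]
).unsqueeze(0)  # shape [1, 5]

# cu_seqlens marks each sequence's start/end offset inside the packed tensor.
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seq_lengths.cumsum(0).to(torch.int32)]
)

print(packed)      # tensor([[11, 12, 13, 21, 22]])
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)
```

With context parallelism (cp_size > 1), the real function additionally pads each sequence so it can be split evenly across CP ranks; see the notes under forward_step_arbitrary_loss below.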

nemo_rl.models.megatron.common._unpack_sequences_from_megatron(
output_tensor: torch.Tensor,
seq_lengths: torch.Tensor,
cu_seqlens: torch.Tensor,
cu_seqlens_padded: Optional[torch.Tensor],
original_batch_size: int,
original_seq_length: int,
) torch.Tensor#

Unpack sequences from Megatron output format.

Parameters:
  • output_tensor – Packed output tensor [1, T, vocab_size]

  • seq_lengths – Actual sequence lengths for each sample

  • cu_seqlens – Cumulative sequence lengths

  • cu_seqlens_padded – Padded cumulative sequence lengths (if CP was used)

  • original_batch_size – Original batch size

  • original_seq_length – Original maximum sequence length

Returns:

Unpacked output tensor [batch_size, seq_length, vocab_size]
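A matching sketch of the unpacking direction, again in plain PyTorch with locally defined example values rather than a call to the private function:

```python
import torch

vocab = 8
seq_lengths = torch.tensor([3, 2])
cu_seqlens = torch.tensor([0, 3, 5])
packed_output = torch.randn(1, int(cu_seqlens[-1]), vocab)  # packed [1, T, vocab]

# Scatter each packed sequence back into a right-padded [batch, seq, vocab] tensor.
batch_size, seq_length = 2, 4
unpacked = torch.zeros(batch_size, seq_length, vocab)
for i in range(batch_size):
    start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
    unpacked[i, : end - start] = packed_output[0, start:end]

print(unpacked.shape)  # torch.Size([2, 4, 8])
```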

nemo_rl.models.megatron.common.forward_step_arbitrary_loss(
state: nemo.tron.state.GlobalState,
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
data_iterator: Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]],
model: megatron.core.models.gpt.GPTModel,
loss_fn: nemo_rl.algorithms.loss_functions.LossFunction,
pack_sequences: bool = False,
seq_length_key: Optional[str] = None,
pad_individual_seqs_to_multiple_of: int = 1,
pad_full_seq_to: Optional[int] = None,
cp_normalize: bool = True,
)#

Forward training step with support for packed sequences and context parallelism.

Parameters:
  • state (GlobalState) – Global state for the run

  • global_valid_seqs – Global count of valid sequences

  • global_valid_toks – Global count of valid tokens

  • data_iterator – Input data iterator

  • model (GPTModel) – The GPT Model

  • loss_fn (LossFunction) – Loss function to apply

  • pack_sequences (bool) – Whether to pack sequences for efficiency

  • seq_length_key (Optional[str]) – Key in data_dict containing actual sequence lengths

  • pad_individual_seqs_to_multiple_of (int) – Pad individual sequences to a multiple of this value

  • pad_full_seq_to (Optional[int]) – Pad the full packed sequence to this length

  • cp_normalize (bool) – Whether to normalize the loss by the cp_size

Notes on packed sequences with context parallelism (CP):

  • When CP > 1, each sequence is padded to a multiple of (cp_size * 2)

  • The factor of 2 ensures load balancing for causal attention

  • cu_seqlens tracks actual sequence boundaries

  • cu_seqlens_padded tracks padded sequence boundaries for CP

  • Requires TransformerEngine >= 1.10 for CP support
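To make the padding rule above concrete, here is a small numeric illustration. The helper pad_to_multiple is a local sketch, not part of this module:

```python
def pad_to_multiple(length: int, multiple: int) -> int:
    # Round length up to the nearest multiple.
    return ((length + multiple - 1) // multiple) * multiple

cp_size = 4  # each sequence must be divisible by cp_size * 2 = 8
for length in (37, 129, 64):
    print(length, "->", pad_to_multiple(length, cp_size * 2))
# 37 -> 40, 129 -> 136, 64 -> 64
```

Each padded sequence can then be split into 2 * cp_size equal chunks, so a CP rank can hold one chunk from the front and one from the back of the sequence, balancing causal-attention work across ranks.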

nemo_rl.models.megatron.common.broadcast_tensor(
tensor: torch.Tensor | None,
src_rank: int,
group: torch.distributed.ProcessGroup,
) torch.Tensor#

Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata.

Handles the case where the input tensor might be None on non-source ranks. If the input tensor is provided on non-source ranks, it must have the correct shape and dtype matching the tensor on the source rank.

Parameters:
  • tensor – The tensor to broadcast on the source rank. Can be None on non-source ranks (will be created with correct shape/dtype). If not None on non-source ranks, it’s used as the buffer for the broadcast and must match the source tensor’s metadata.

  • src_rank (int) – The global rank of the source process.

  • group – The process group for communication.

Returns:

The broadcasted tensor. On non-source ranks, this will be the tensor received from the source.

Return type:

torch.Tensor

Raises:
  • ValueError – If the tensor is None on the source rank, or if a tensor provided on a non-source rank has mismatched shape/dtype/device.

  • TypeError – If broadcasting metadata fails (e.g., due to pickling issues).
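A hedged usage sketch for broadcast_tensor. It assumes torch.distributed has already been initialized (for example via torchrun) and that group is a valid process group; the helper name share_example is illustrative only:

```python
import torch
import torch.distributed as dist

from nemo_rl.models.megatron.common import broadcast_tensor

def share_example(group: dist.ProcessGroup, src_rank: int = 0) -> torch.Tensor:
    # Assumes dist.init_process_group(...) has already been called; with the
    # NCCL backend the tensor should live on the current CUDA device.
    if dist.get_rank() == src_rank:
        # Only the source rank needs to construct the tensor.
        tensor = torch.arange(10, dtype=torch.float32)
    else:
        # Non-source ranks may pass None; shape and dtype arrive via metadata.
        tensor = None
    # Every rank returns an identical copy of the source rank's tensor.
    return broadcast_tensor(tensor, src_rank, group)
```

Passing a preallocated tensor on non-source ranks avoids the extra allocation, but it must match the source tensor's shape and dtype exactly, per the ValueError conditions above.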