nemo_rl.models.megatron.common

Module Contents
Functions

| Function | Description |
|---|---|
| `_pack_sequences_for_megatron` | Pack sequences for Megatron model processing with optional context parallelism. |
| `_unpack_sequences_from_megatron` | Unpack sequences from Megatron output format. |
| `forward_step_arbitrary_loss` | Forward training step with support for packed sequences and context parallelism. |
| `broadcast_tensor` | Broadcasts a tensor from src_rank to all ranks in the group, using `broadcast_object_list` for metadata. |
API
```python
nemo_rl.models.megatron.common._pack_sequences_for_megatron(
    input_ids: torch.Tensor,
    seq_lengths: torch.Tensor,
    pad_individual_seqs_to_multiple_of: int = 1,
    pad_packed_seq_to: Optional[int] = None,
    cp_rank: int = 0,
    cp_size: int = 1,
)
```
Pack sequences for Megatron model processing with optional context parallelism.
- Parameters:
  - input_ids – Input token IDs [batch_size, seq_length]
  - seq_lengths – Actual sequence lengths for each sample [batch_size]
  - pad_individual_seqs_to_multiple_of – Pad individual sequences to a multiple of this value
  - pad_packed_seq_to – Pad the packed sequence to this total length (before CP sharding)
  - cp_rank – Context parallelism rank of the current process
  - cp_size – Context parallelism size
- Returns:
  - packed_input_ids: Packed input tensor [1, T]
  - input_ids_cp_sharded: Sharded input tensor [cp_size, T // cp_size]
  - packed_seq_params: PackedSeqParams object
  - cu_seqlens: Cumulative sequence lengths
  - cu_seqlens_padded: Padded cumulative sequence lengths
- Return type:
  Tuple of (packed_input_ids, input_ids_cp_sharded, packed_seq_params, cu_seqlens, cu_seqlens_padded)
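A minimal usage sketch, assuming only the signature documented above. The token values are illustrative, and the expected `cu_seqlens` values in the comments follow from the docstring semantics rather than a verified run:

```python
# Illustrative sketch only: values are made up, and the function is private
# to nemo_rl, so exact outputs may differ from this reading of the docs.
import torch

from nemo_rl.models.megatron.common import _pack_sequences_for_megatron

input_ids = torch.tensor([
    [101, 7, 8, 9, 0, 0, 0, 0],  # length 4, right-padded to 8
    [101, 5, 6, 0, 0, 0, 0, 0],  # length 3, right-padded to 8
])
seq_lengths = torch.tensor([4, 3])

(
    packed_input_ids,      # [1, T] packed tokens
    input_ids_cp_sharded,  # [cp_size, T // cp_size]
    packed_seq_params,     # megatron.core PackedSeqParams
    cu_seqlens,            # actual boundaries: expected tensor([0, 4, 7])
    cu_seqlens_padded,     # padded boundaries: expected tensor([0, 4, 8])
) = _pack_sequences_for_megatron(
    input_ids=input_ids,
    seq_lengths=seq_lengths,
    pad_individual_seqs_to_multiple_of=4,  # e.g. cp_size * 2 when CP is enabled
    cp_rank=0,
    cp_size=1,
)
```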
```python
nemo_rl.models.megatron.common._unpack_sequences_from_megatron(
    output_tensor: torch.Tensor,
    seq_lengths: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cu_seqlens_padded: Optional[torch.Tensor],
    original_batch_size: int,
    original_seq_length: int,
)
```
Unpack sequences from Megatron output format.
- Parameters:
  - output_tensor – Packed output tensor [1, T, vocab_size]
  - seq_lengths – Actual sequence lengths for each sample
  - cu_seqlens – Cumulative sequence lengths
  - cu_seqlens_padded – Padded cumulative sequence lengths (if CP was used)
  - original_batch_size – Original batch size
  - original_seq_length – Original maximum sequence length
- Returns:
Unpacked output tensor [batch_size, seq_length, vocab_size]
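Continuing the packing sketch above, a hedged round trip: a random tensor stands in for the model's packed logits, and the unpacked result regains the documented [batch_size, seq_length, vocab_size] layout.

```python
# Continues the packing sketch; torch.randn stands in for GPTModel output.
import torch

from nemo_rl.models.megatron.common import _unpack_sequences_from_megatron

vocab_size = 32  # illustrative
total_packed_len = int(cu_seqlens_padded[-1])
packed_logits = torch.randn(1, total_packed_len, vocab_size)

unpacked = _unpack_sequences_from_megatron(
    output_tensor=packed_logits,
    seq_lengths=seq_lengths,
    cu_seqlens=cu_seqlens,
    cu_seqlens_padded=cu_seqlens_padded,  # may be None when CP was not used
    original_batch_size=input_ids.size(0),
    original_seq_length=input_ids.size(1),
)
# Expected shape: [2, 8, 32] == [batch_size, seq_length, vocab_size]
```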
```python
nemo_rl.models.megatron.common.forward_step_arbitrary_loss(
    state: nemo.tron.state.GlobalState,
    global_valid_seqs: torch.Tensor,
    global_valid_toks: torch.Tensor,
    data_iterator: Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]],
    model: megatron.core.models.gpt.GPTModel,
    loss_fn: nemo_rl.algorithms.loss_functions.LossFunction,
    pack_sequences: bool = False,
    seq_length_key: Optional[str] = None,
    pad_individual_seqs_to_multiple_of: int = 1,
    pad_full_seq_to: Optional[int] = None,
    cp_normalize: bool = True,
)
```
Forward training step with support for packed sequences and context parallelism.
- Parameters:
  - state (GlobalState) – Global state for the run
  - global_valid_seqs – Global count of valid sequences
  - global_valid_toks – Global count of valid tokens
  - data_iterator – Input data iterator
  - model (GPTModel) – The GPT model
  - loss_fn (LossFunction) – Loss function to apply
  - pack_sequences (bool) – Whether to pack sequences for efficiency
  - seq_length_key (Optional[str]) – Key in data_dict containing actual sequence lengths
  - pad_individual_seqs_to_multiple_of (int) – Pad individual sequences to a multiple of this value
  - pad_full_seq_to (Optional[int]) – Pad the full packed sequence to this length
  - cp_normalize (bool) – Whether to normalize the loss by cp_size
Notes on packed sequences with context parallelism (CP):
- When CP > 1, each sequence is padded to a multiple of (cp_size * 2); see the padding sketch below
- The factor of 2 ensures load balancing for causal attention
- cu_seqlens tracks actual sequence boundaries
- cu_seqlens_padded tracks padded sequence boundaries for CP
- Requires TransformerEngine >= 1.10 for CP support
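To make the padding rule concrete, here is a small hypothetical helper (not part of the module) that rounds a sequence length up to the next multiple of cp_size * 2:

```python
# Hypothetical helper illustrating the CP padding rule; not part of nemo_rl.
def padded_len(seq_len: int, cp_size: int) -> int:
    multiple = cp_size * 2  # factor of 2 balances causal-attention work across CP ranks
    return ((seq_len + multiple - 1) // multiple) * multiple

assert padded_len(13, cp_size=2) == 16  # rounded up to a multiple of 4
assert padded_len(16, cp_size=4) == 16  # already a multiple of 8
assert padded_len(17, cp_size=4) == 24  # rounded up to a multiple of 8
```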
```python
nemo_rl.models.megatron.common.broadcast_tensor(
    tensor: torch.Tensor | None,
    src_rank: int,
    group: torch.distributed.ProcessGroup,
)
```
Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata.
Handles the case where the input tensor might be None on non-source ranks. If the input tensor is provided on non-source ranks, it must have the correct shape and dtype matching the tensor on the source rank.
- Parameters:
  - tensor – The tensor to broadcast on the source rank. Can be None on non-source ranks (it will be created with the correct shape/dtype). If not None on non-source ranks, it's used as the buffer for the broadcast and must match the source tensor's metadata.
  - src_rank (int) – The global rank of the source process.
  - group – The process group for communication.
- Returns:
The broadcasted tensor. On non-source ranks, this will be the tensor received from the source.
- Return type:
torch.Tensor
- Raises:
  - ValueError – If the tensor is None on the source rank, or if a tensor provided on a non-source rank has mismatched shape/dtype/device.
  - TypeError – If broadcasting metadata fails (e.g., due to pickling issues).
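A hedged usage sketch, assuming an already-initialized torch.distributed process group (e.g. NCCL via torchrun); the payload contents and device are illustrative:

```python
# Sketch assuming torch.distributed is initialized (e.g. torchrun + NCCL).
import torch
import torch.distributed as dist

from nemo_rl.models.megatron.common import broadcast_tensor

SRC_RANK = 0
group = dist.group.WORLD

if dist.get_rank() == SRC_RANK:
    payload = torch.arange(8, device="cuda")  # only the source needs real data
else:
    payload = None  # receivers may pass None; shape/dtype arrive via metadata

payload = broadcast_tensor(payload, src_rank=SRC_RANK, group=group)
# Every rank now holds an identical tensor received from SRC_RANK.
```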