nemo_automodel.components.distributed.thd_utils
Module Contents#
Functions#
| Function | Description |
|---|---|
| `process_input_for_thd` | Process inputs for THD (total, hidden, depth) format. |
| `split_batch_into_thd_chunks` | Process inputs for THD format by splitting batch into chunks for context parallelism. |
API#
- nemo_automodel.components.distributed.thd_utils.process_input_for_thd(
- batch: dict[str, torch.Tensor],
- seq_lens_padding_value: int = -1000,
- padding_token_id: int = 0,
- )
Process inputs for THD (total, hidden, depth) format.
This function converts batched inputs from BSHD (batch, sequence, hidden, depth) format to THD format for packed sequence processing. In THD format, the batch dimension is collapsed and all sequences are concatenated along the sequence dimension. This supports both 2D token IDs and 3D embeddings for pipeline parallelism scenarios.
The function filters out padding values in seq_lens and seq_lens_padded (indicated by seq_lens_padding_value) and computes cumulative sequence lengths for efficient attention computation with Transformer Engine or other packed sequence implementations.
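For intuition, the filter-and-cumsum step described above can be sketched as follows. This is an illustrative helper written for this page, not the module's internal code; the function name is hypothetical.

```python
import torch

def cu_seqlens_from_padded(seq_lens_padded: torch.Tensor, pad_value: int = -1000) -> torch.Tensor:
    """Illustrative sketch: drop padding entries, then build cumulative lengths (int32)."""
    # Keep only real pack lengths; entries equal to pad_value mark unused slots.
    valid = seq_lens_padded[seq_lens_padded != pad_value]
    # Prepend 0 and accumulate: [0, l0, l0 + l1, ...], matching the 'cu_seqlens' output.
    return torch.cat([torch.zeros(1, dtype=torch.int32),
                      torch.cumsum(valid, dim=0).to(torch.int32)])

# For seq_lens_padded = [[4, 2], [6, -1000]], this yields tensor([0, 4, 6, 12], dtype=torch.int32).
```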
- Parameters:
batch –
Dictionary containing:
- 'input_ids': Input tensor of shape [batch_size, seq_len] for token IDs or [batch_size, seq_len, hidden_dim] for embeddings (in pipeline parallel scenarios)
- 'labels': Labels tensor of shape [batch_size, seq_len]
- 'position_ids': Position IDs tensor of shape [batch_size, seq_len] (required)
- 'seq_lens': Sequence lengths tensor of shape [batch_size, num_packs] containing actual sequence lengths (excluding padding/separators). Values matching seq_lens_padding_value indicate padding and are filtered out.
- 'seq_lens_padded': Padded sequence lengths tensor of shape [batch_size, num_packs] containing lengths including separator tokens. Values matching seq_lens_padding_value indicate padding and are filtered out.
seq_lens_padding_value – Value used to indicate padding in seq_lens/seq_lens_padded tensors that should be filtered out (default: -1000)
padding_token_id – Token ID used for padding in input_ids to generate padding_mask (default: 0)
- Returns:
- 'input_ids': Reshaped tensor of shape [total_tokens] for 2D token IDs or [total_tokens, hidden_dim] for 3D embeddings
- 'labels': Reshaped labels tensor of shape [total_tokens]
- 'position_ids': Reshaped tensor of shape [total_tokens]
- 'cu_seqlens': Cumulative padded sequence lengths tensor of shape [num_sequences + 1] (int32), where num_sequences is the total count of non-padded sequences across the batch. NOTE: this contains cumulative lengths from seq_lens_padded (not seq_lens), since CP does not support padding between sequences (which results in NaNs). The labels or loss mask ensure that the loss is computed correctly.
- 'padding_mask': Boolean tensor of shape [total_tokens] indicating padding positions
- Non-tensor keys from the input batch are preserved (e.g., 'qkv_format')
- Return type:
Dictionary containing
Example

batch_size, seq_len = 2, 6

2D token IDs case with packed sequences:

batch = {
    'input_ids': torch.tensor([[1, 2, 3, 99, 4, 5], [6, 7, 8, 9, 10, 11]]),
    'labels': torch.tensor([[2, 3, 99, 4, 5, 6], [7, 8, 9, 10, 11, 12]]),
    'position_ids': torch.tensor([[0, 1, 2, 0, 0, 1], [0, 1, 2, 3, 4, 5]]),
    'seq_lens': torch.tensor([[3, 2], [6, -1000]]),        # second batch item has only 1 sequence
    'seq_lens_padded': torch.tensor([[4, 2], [6, -1000]]),
}
result = process_input_for_thd(batch)

result['input_ids'].shape: [12] (2D input collapsed to 1D)
result['labels'].shape: [12]
result['position_ids'].shape: [12]
result['cu_seqlens']: tensor([0, 4, 6, 12], dtype=torch.int32)
    Breakdown: [0] + cumsum([4, 2, 6]) = [0, 4, 6, 12] (from seq_lens_padded)
result['padding_mask'].shape: [12]
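For intuition, the flattening and padding-mask steps of the example above can be reproduced by hand as below. This is a hedged sketch of the same arithmetic, not the module's internal code.

```python
import torch

input_ids = torch.tensor([[1, 2, 3, 99, 4, 5], [6, 7, 8, 9, 10, 11]])
padding_token_id = 0

# Collapse the batch dimension: [batch_size, seq_len] -> [total_tokens].
flat_ids = input_ids.reshape(-1)               # shape [12]

# The padding mask marks positions equal to padding_token_id (none in this example).
padding_mask = flat_ids == padding_token_id    # shape [12], all False
```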
- nemo_automodel.components.distributed.thd_utils.split_batch_into_thd_chunks(
- batch: dict[str, torch.Tensor],
- num_chunks: int,
- seq_lens_padding_value: int = -1000,
- padding_token_id: int = 0,
- )
Process inputs for THD format by splitting batch into chunks for context parallelism.
This function splits the batch along the batch dimension into num_chunks chunks, processes each chunk with process_input_for_thd, and stacks the tensor results. This is useful for context parallelism where different chunks are processed on different devices/ranks.
The cu_seqlens tensors from different chunks may have different lengths depending on the number of sequences in each chunk. These are padded with seq_lens_padding_value to ensure uniform length across chunks for stacking.
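A minimal sketch of this behavior is shown below; it calls the documented process_input_for_thd per chunk, but the splitting/padding helper itself is illustrative and assumes batch_size is evenly divisible by num_chunks and that every chunk flattens to the same number of tokens.

```python
import torch
from nemo_automodel.components.distributed.thd_utils import process_input_for_thd

def split_and_stack_sketch(batch, num_chunks, pad_value=-1000):
    """Illustrative sketch: split along the batch dim, process each chunk, then stack."""
    chunk_size = batch["input_ids"].shape[0] // num_chunks   # assumes even divisibility
    chunks = []
    for i in range(num_chunks):
        # Slice tensor entries along the batch dimension; pass non-tensor keys through.
        sub = {k: (v[i * chunk_size:(i + 1) * chunk_size] if torch.is_tensor(v) else v)
               for k, v in batch.items()}
        chunks.append(process_input_for_thd(sub, seq_lens_padding_value=pad_value))

    stacked_ids = torch.stack([c["input_ids"] for c in chunks])
    # cu_seqlens may differ in length per chunk; right-pad with pad_value before stacking.
    max_len = max(c["cu_seqlens"].numel() for c in chunks)
    stacked_cu = torch.stack([
        torch.nn.functional.pad(c["cu_seqlens"],
                                (0, max_len - c["cu_seqlens"].numel()),
                                value=pad_value)
        for c in chunks
    ])
    return stacked_ids, stacked_cu   # [num_chunks, tokens_per_chunk], [num_chunks, max_seqs + 1]
```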
- Parameters:
batch –
Dictionary containing input tensors with the same structure as process_input_for_thd:
- 'input_ids': [batch_size, seq_len] or [batch_size, seq_len, hidden_dim]
- 'labels': [batch_size, seq_len]
- 'position_ids': [batch_size, seq_len] (required)
- 'seq_lens': [batch_size, num_packs]
- 'seq_lens_padded': [batch_size, num_packs]
num_chunks – Number of chunks to split the batch into. Must evenly divide batch_size. If num_chunks <= 1, returns the result from process_input_for_thd directly.
seq_lens_padding_value – Value used to indicate padding in seq_lens/seq_lens_padded tensors and for padding cu_seqlens to uniform length (default: -1000)
padding_token_id – Token ID used for padding in input_ids to generate padding_mask (default: 0)
- Returns:
When num_chunks > 1:
- 'input_ids': [num_chunks, tokens_per_chunk] or [num_chunks, tokens_per_chunk, hidden_dim]
- 'labels': [num_chunks, tokens_per_chunk]
- 'position_ids': [num_chunks, tokens_per_chunk]
- 'cu_seqlens': [num_chunks, max_sequences_per_chunk + 1] (padded with seq_lens_padding_value). Contains cumulative lengths from seq_lens_padded for CP compatibility.
- 'padding_mask': [num_chunks, tokens_per_chunk]
- Non-tensor keys from the input batch are preserved
When num_chunks <= 1: Returns the same format as process_input_for_thd (no chunk dimension)
- Return type:
Dictionary containing
Example

batch_size, seq_len = 4, 6
batch = {
    'input_ids': torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12],
                               [13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24]]),
    'labels': torch.tensor([[2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13],
                            [14, 15, 16, 17, 18, 19], [20, 21, 22, 23, 24, 25]]),
    'position_ids': torch.tensor([[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5],
                                  [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]),
    'seq_lens': torch.tensor([[6], [6], [6], [6]]),
    'seq_lens_padded': torch.tensor([[6], [6], [6], [6]]),
}
result = split_batch_into_thd_chunks(batch, num_chunks=2)

result['input_ids'].shape: [2, 12] (2 chunks, each with 2*6=12 tokens)
result['cu_seqlens'].shape: [2, 3] (2 chunks, each with [0, 6, 12])
result['cu_seqlens'][0]: tensor([0, 6, 12], dtype=torch.int32)
result['cu_seqlens'][1]: tensor([0, 6, 12], dtype=torch.int32)
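As a hedged usage note, each context-parallel rank would typically index its own slice from the stacked result; the rank bookkeeping below is assumed caller-side logic and is not part of this module.

```python
# cp_rank is assumed to come from the caller's parallel state; shown here as a constant.
cp_rank = 0
local_input_ids = result["input_ids"][cp_rank]     # shape [12]
local_cu_seqlens = result["cu_seqlens"][cp_rank]   # tensor([0, 6, 12], dtype=torch.int32)
```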