nemo_automodel.components.distributed.thd_utils
Module Contents
Functions
API
Process inputs for THD (total tokens, heads, head dimension) format.
This function converts batched inputs from BSHD (batch, sequence, heads, head dimension) format to THD format for packed sequence processing. In THD format, the batch dimension is collapsed and all sequences are concatenated along the sequence dimension. Both 2D token IDs and 3D embeddings are supported, the latter for pipeline parallelism scenarios.
The function filters out padding values in seq_lens and seq_lens_padded (indicated by seq_lens_padding_value) and computes cumulative sequence lengths for efficient attention computation with Transformer Engine or other packed sequence implementations.
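The filtering and cumulative-length step described above can be sketched with plain Python lists standing in for tensors. This is an illustration of the arithmetic only, not the library's implementation; the helper name cumulative_lengths is hypothetical.

```python
from itertools import accumulate

SEQ_LENS_PADDING_VALUE = -1000  # default sentinel named in the text


def cumulative_lengths(seq_lens, padding_value=SEQ_LENS_PADDING_VALUE):
    """Hypothetical helper: flatten [batch_size, num_packs] lengths,
    drop sentinel entries, and prefix-sum what remains."""
    real = [n for row in seq_lens for n in row if n != padding_value]
    # A leading 0 gives num_sequences + 1 boundaries.
    return [0] + list(accumulate(real))


# Two rows; the first packs two sequences, the second packs one.
seq_lens = [[3, 2, -1000], [4, -1000, -1000]]
cu_seqlens = cumulative_lengths(seq_lens)
print(cu_seqlens)  # [0, 3, 5, 9]
```

Each adjacent pair of entries delimits one sequence, which is the form packed-attention kernels consume.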
Parameters:
Dictionary containing:
- ‘input_ids’: Input tensor of shape [batch_size, seq_len] for token IDs or [batch_size, seq_len, hidden_dim] for embeddings (in pipeline parallel scenarios)
- ‘labels’: Labels tensor of shape [batch_size, seq_len]
- ‘position_ids’: Position IDs tensor of shape [batch_size, seq_len] (required)
- ‘seq_lens’: Sequence lengths tensor of shape [batch_size, num_packs] containing actual sequence lengths (excluding padding/separators). Values matching seq_lens_padding_value indicate padding and are filtered out.
- ‘seq_lens_padded’: Padded sequence lengths tensor of shape [batch_size, num_packs] containing lengths including separator tokens. Values matching seq_lens_padding_value indicate padding and are filtered out.
seq_lens_padding_value: Sentinel that marks padding entries in the seq_lens/seq_lens_padded tensors; entries equal to it are filtered out (default: -1000)
Token ID used for padding in input_ids to generate padding_mask (default: 0)
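As a rough illustration of the input layout, the sketch below builds a toy batch with nested lists in place of tensors, collapses the batch dimension, and derives a padding mask by comparing each token against the padding token ID. The helper flatten_batch and the toy values are assumptions for illustration; the real function operates on torch tensors.

```python
PADDING_TOKEN_ID = 0  # default padding token ID named in the text


def flatten_batch(rows):
    """Hypothetical helper: collapse [batch_size, seq_len] nested lists
    into a single [total_tokens] list."""
    return [tok for row in rows for tok in row]


# Toy batch: 2 rows of length 4, padded with token ID 0.
batch = {
    "input_ids": [[11, 12, 13, 0], [21, 22, 0, 0]],
    "labels": [[12, 13, -100, -100], [22, -100, -100, -100]],
}

flat_ids = flatten_batch(batch["input_ids"])
flat_labels = flatten_batch(batch["labels"])
# True marks a padding position.
padding_mask = [tok == PADDING_TOKEN_ID for tok in flat_ids]

print(flat_ids)      # [11, 12, 13, 0, 21, 22, 0, 0]
print(padding_mask)  # [False, False, False, True, False, False, True, True]
```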
Returns: dict[str, torch.Tensor]
Dictionary containing:
- ‘input_ids’: Reshaped tensor of shape [total_tokens] for 2D token IDs or [total_tokens, hidden_dim] for 3D embeddings
- ‘labels’: Reshaped labels tensor of shape [total_tokens]
- ‘position_ids’: Reshaped tensor of shape [total_tokens]
- ‘cu_seqlens’: Cumulative REAL sequence lengths tensor of shape [num_sequences + 1] (int32), where num_sequences is the total count of non-padded sequences across the batch. Built from seq_lens (the unpadded real lengths). When the trailing pack-pad is purely at the end (cp_size == 1), the last entry is grown to total_tokens to absorb that pad and avoid TE’s pad_between_seqs=True path; see the absorption block in the function body for the gate.
- ‘cu_seqlens_padded’: (optional) Cumulative PADDED sequence lengths tensor of the same shape as cu_seqlens. Only emitted when it differs from cu_seqlens after absorption (i.e., when padding lives between sub-sequences, which is the CP case). Forwarded to TE as cu_seqlens_q_padded/cu_seqlens_kv_padded with pad_between_seqs=True so the kernel reads memory offsets from the padded variant while attending only over the real-length slots.
- ‘max_seqlen’: Scalar int32 tensor equal to max(cu_seqlens[i+1] - cu_seqlens[i]) after any absorption. Honors TE’s contract that max_seqlen_q >= max(cu_seqlens_q[i+1] - cu_seqlens_q[i]).
- ‘padding_mask’: Boolean tensor of shape [total_tokens] indicating padding positions
- Non-tensor keys from input batch are preserved (e.g., ‘qkv_format’)
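The relationship between cu_seqlens, cu_seqlens_padded, and max_seqlen can be sketched as follows, using lists instead of tensors. The real absorption gate lives in the function body and is not reproduced in this page, so the gate below (padding can only sit after the last sequence) is an assumption, as is the helper name build_cu_seqlens. At least one sequence is assumed.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens, seq_lens_padded, total_tokens):
    """Hypothetical sketch of the returned length metadata."""
    cu = [0] + list(accumulate(seq_lens))
    cu_padded = [0] + list(accumulate(seq_lens_padded))

    # ASSUMED gate: all sequences but the last have no padding, so any
    # pad is purely trailing and can be absorbed into the last sequence.
    pad_only_at_end = seq_lens[:-1] == seq_lens_padded[:-1]
    if pad_only_at_end:
        cu[-1] = total_tokens
        cu_padded[-1] = total_tokens

    out = {
        "cu_seqlens": cu,
        # Longest span between adjacent boundaries, after absorption.
        "max_seqlen": max(b - a for a, b in zip(cu, cu[1:])),
    }
    # Emit the padded variant only when it carries extra information.
    if cu_padded != cu:
        out["cu_seqlens_padded"] = cu_padded
    return out


# Trailing pad only: 5 real tokens in an 8-token pack.
tail = build_cu_seqlens([3, 2], [3, 2], total_tokens=8)
# Padding between sequences: each sequence padded to 4 slots.
between = build_cu_seqlens([3, 2], [4, 4], total_tokens=8)
print(tail)
print(between)
```

In the first case the last boundary grows to 8 and no padded variant is needed; in the second the real and padded boundaries differ, so both are returned.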
Process inputs for THD format by splitting batch into chunks for context parallelism.
This function splits the batch along the batch dimension into num_chunks chunks, processes each chunk with process_input_for_thd, and stacks the tensor results. This is useful for context parallelism where different chunks are processed on different devices/ranks.
The cu_seqlens tensors from different chunks may have different lengths depending on the number of sequences in each chunk. These are padded with seq_lens_padding_value to ensure uniform length across chunks for stacking.
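The padding step described above amounts to right-filling shorter boundary lists with the sentinel so every chunk has the same width. A minimal sketch with lists in place of tensors, where pad_to_uniform is a hypothetical helper and not part of the module:

```python
def pad_to_uniform(cu_list, padding_value=-1000):
    """Hypothetical helper: right-pad each chunk's cu_seqlens with the
    sentinel so all rows share the longest length and can be stacked."""
    width = max(len(cu) for cu in cu_list)
    return [cu + [padding_value] * (width - len(cu)) for cu in cu_list]


# Chunk 0 holds three sequences, chunk 1 holds one.
chunk_cu = [[0, 3, 5, 8], [0, 8]]
stacked = pad_to_uniform(chunk_cu)
print(stacked)  # [[0, 3, 5, 8], [0, 8, -1000, -1000]]
```

A consumer of the stacked result would need to drop the sentinel entries again before handing a chunk's boundaries to an attention kernel.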
Parameters:
Dictionary containing input tensors with same structure as process_input_for_thd:
- ‘input_ids’: [batch_size, seq_len] or [batch_size, seq_len, hidden_dim]
- ‘labels’: [batch_size, seq_len]
- ‘position_ids’: [batch_size, seq_len] (required)
- ‘seq_lens’: [batch_size, num_packs]
- ‘seq_lens_padded’: [batch_size, num_packs]
num_chunks: Number of chunks to split the batch into. Must evenly divide batch_size. If num_chunks <= 1, the result of process_input_for_thd is returned directly.
seq_lens_padding_value: Sentinel that marks padding entries in the seq_lens/seq_lens_padded tensors and that pads cu_seqlens to a uniform length (default: -1000)
Token ID used for padding in input_ids to generate padding_mask (default: 0)
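The chunking rule for num_chunks can be sketched as below, again with nested lists in place of tensors. The helper split_rows is hypothetical, and raising ValueError on an uneven split is an assumption; the text only states that num_chunks must evenly divide batch_size.

```python
def split_rows(rows, num_chunks):
    """Hypothetical helper: split along the batch dimension into
    equally sized chunks."""
    batch_size = len(rows)
    if num_chunks <= 1:
        return [rows]  # no splitting; the whole batch is one chunk
    if batch_size % num_chunks != 0:
        # ASSUMED error type; the real function's behavior is not stated.
        raise ValueError(
            f"num_chunks={num_chunks} must evenly divide batch_size={batch_size}"
        )
    step = batch_size // num_chunks
    return [rows[i:i + step] for i in range(0, batch_size, step)]


rows = [[1, 2], [3, 4], [5, 6], [7, 8]]
chunks = split_rows(rows, num_chunks=2)
print(chunks)  # [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
```

Each chunk would then be processed independently and the per-chunk results stacked along a new leading dimension.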
Returns: dict[str, torch.Tensor]
Dictionary containing: