nemo_automodel.components.distributed.thd_utils


Module Contents

Functions

  • process_input_for_thd: Process inputs for THD (total, hidden, depth) format.
  • split_batch_into_thd_chunks: Process inputs for THD format by splitting batch into chunks for context parallelism.

API

nemo_automodel.components.distributed.thd_utils.process_input_for_thd(
batch: dict[str, torch.Tensor],
seq_lens_padding_value: int = -1000,
padding_token_id: int = 0
) -> dict[str, torch.Tensor]

Process inputs for THD (total, hidden, depth) format.

This function converts batched inputs from BSHD (batch, sequence, hidden, depth) format to THD format for packed sequence processing. In THD format, the batch dimension is collapsed and all sequences are concatenated along the sequence dimension. Both 2D token IDs and 3D embeddings are supported; 3D embeddings arise in pipeline-parallel scenarios.
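As a rough illustration of the layout change, the pure-Python sketch below collapses the batch dimension of nested lists. The helper `collapse_batch_dim` is hypothetical and not part of nemo_automodel; with real contiguous tensors, the same row-major result is what `input_ids.reshape(-1)` (token IDs) or `input_ids.reshape(-1, hidden_dim)` (embeddings) produces.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def collapse_batch_dim(rows: Sequence[Sequence[T]]) -> List[T]:
    """Hypothetical helper: concatenate per-sample rows along the sequence axis.

    [batch_size, seq_len] token IDs become [batch_size * seq_len]; if each
    element is a hidden vector, [batch_size, seq_len, hidden_dim] becomes
    [batch_size * seq_len, hidden_dim].
    """
    flat: List[T] = []
    for row in rows:
        flat.extend(row)  # row-major order, same as a contiguous reshape
    return flat


# 2D token IDs: shape [2, 4] -> [8]
ids = [[11, 12, 13, 0], [21, 22, 0, 0]]
print(collapse_batch_dim(ids))  # [11, 12, 13, 0, 21, 22, 0, 0]

# 3D embeddings: shape [2, 2, 3] -> [4, 3]
emb = [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
       [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]]
flat_emb = collapse_batch_dim(emb)
print(len(flat_emb), len(flat_emb[0]))  # 4 3
```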

The function drops every entry of seq_lens and seq_lens_padded that equals seq_lens_padding_value, then computes cumulative sequence lengths for efficient attention with Transformer Engine (TE) or other packed-sequence implementations.
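The filter-then-cumulative-sum step can be sketched with plain lists. `cumulative_seqlens` is a hypothetical helper, not the library's code; it assumes the sequences of all rows are laid out back to back and ignores the tail-absorption adjustment described under Returns.

```python
from itertools import accumulate
from typing import List, Sequence


def cumulative_seqlens(
    seq_lens: Sequence[Sequence[int]], padding_value: int = -1000
) -> List[int]:
    """Hypothetical helper: turn a [batch_size, num_packs] table of lengths
    into cumulative boundaries of shape [num_sequences + 1]."""
    # Drop entries that only pad the table to a rectangular shape.
    kept = [n for row in seq_lens for n in row if n != padding_value]
    # Boundary i is the start offset of sequence i; the last one is the end.
    return [0, *accumulate(kept)]


seq_lens = [[3, 2, -1000], [4, -1000, -1000]]
print(cumulative_seqlens(seq_lens))  # [0, 3, 5, 9]
```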

Parameters:

batch
dict[str, torch.Tensor]

Dictionary containing:

  • ‘input_ids’: Input tensor of shape [batch_size, seq_len] for token IDs or [batch_size, seq_len, hidden_dim] for embeddings (in pipeline parallel scenarios)
  • ‘labels’: Labels tensor of shape [batch_size, seq_len]
  • ‘position_ids’: Position IDs tensor of shape [batch_size, seq_len] (required)
  • ‘seq_lens’: Sequence lengths tensor of shape [batch_size, num_packs] containing actual sequence lengths (excluding padding/separators). Values matching seq_lens_padding_value indicate padding and are filtered out.
  • ‘seq_lens_padded’: Padded sequence lengths tensor of shape [batch_size, num_packs] containing lengths including separator tokens. Values matching seq_lens_padding_value indicate padding and are filtered out.

seq_lens_padding_value
int, defaults to -1000

Value that marks padding entries in seq_lens/seq_lens_padded; such entries are filtered out.

padding_token_id
int, defaults to 0

Token ID that marks padding in input_ids; used to build padding_mask.

Returns: dict[str, torch.Tensor]

Dictionary containing:

  • ‘input_ids’: Reshaped tensor of shape [total_tokens] for 2D token IDs or [total_tokens, hidden_dim] for 3D embeddings
  • ‘labels’: Reshaped labels tensor of shape [total_tokens]
  • ‘position_ids’: Reshaped tensor of shape [total_tokens]
  • ‘cu_seqlens’: Cumulative real (unpadded) sequence lengths, an int32 tensor of shape [num_sequences + 1], where num_sequences is the total number of non-padded sequences across the batch. Built from seq_lens. When the pack padding sits only at the end (the cp_size == 1 case), the last entry is extended to total_tokens so that it absorbs the trailing pad and TE’s pad_between_seqs=True path is avoided; the exact condition is defined by the absorption block in the function body.
  • ‘cu_seqlens_padded’: (optional) Cumulative PADDED sequence lengths tensor of the same shape as cu_seqlens. Only emitted when it differs from cu_seqlens after absorption (i.e., when padding lives between sub-sequences, which is the CP case). Forwarded to TE as cu_seqlens_q_padded / cu_seqlens_kv_padded with pad_between_seqs=True so the kernel reads memory offsets from the padded variant while attending only over the real-length slots.
  • ‘max_seqlen’: Scalar int32 tensor equal to max(cu_seqlens[i+1] - cu_seqlens[i]) after any absorption. Honors TE’s contract that max_seqlen_q >= max(cu_seqlens_q[i+1] - cu_seqlens_q[i]).
  • ‘padding_mask’: Boolean tensor of shape [total_tokens] indicating padding positions
  • Non-tensor keys from input batch are preserved (e.g., ‘qkv_format’)
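The sketch below shows how cu_seqlens, cu_seqlens_padded, max_seqlen, and padding_mask relate for a single packed row, using plain lists and ints instead of int32 tensors. The helper name and, in particular, the absorption test (every boundary except the last agrees between the real and padded variants) are assumptions standing in for the gate in the function body; the real implementation may differ.

```python
from itertools import accumulate
from typing import Dict, List, Sequence, Union

Meta = Dict[str, Union[int, List[int], List[bool]]]


def thd_metadata_for_pack(
    input_ids: Sequence[int],
    seq_lens: Sequence[int],
    seq_lens_padded: Sequence[int],
    padding_token_id: int = 0,
) -> Meta:
    """Hypothetical helper for ONE packed row (padding entries already removed)."""
    total_tokens = len(input_ids)
    cu = [0, *accumulate(seq_lens)]
    cu_padded = [0, *accumulate(seq_lens_padded)]

    # ASSUMED absorption gate: if real and padded boundaries agree everywhere
    # except at the end, the only padding is the trailing pack pad, so extend
    # the final boundary to cover it.
    if cu[:-1] == cu_padded[:-1]:
        cu[-1] = total_tokens
        cu_padded[-1] = total_tokens

    out: Meta = {
        "cu_seqlens": cu,
        # Longest segment after absorption.
        "max_seqlen": max((b - a for a, b in zip(cu, cu[1:])), default=0),
        "padding_mask": [tok == padding_token_id for tok in input_ids],
    }
    # Emitted only when padding remains between sub-sequences.
    if cu_padded != cu:
        out["cu_seqlens_padded"] = cu_padded
    return out


# Tail padding only: absorbed, no cu_seqlens_padded.
a = thd_metadata_for_pack([5, 6, 7, 8, 9, 0, 0, 0], [3, 2], [3, 2])
print(a["cu_seqlens"], a["max_seqlen"])  # [0, 3, 8] 5
print("cu_seqlens_padded" in a)  # False

# Padding between sub-sequences (CP-style layout): padded variant is kept.
b = thd_metadata_for_pack([5, 6, 7, 0, 8, 9, 0, 0], [3, 2], [4, 4])
print(b["cu_seqlens"], b["cu_seqlens_padded"], b["max_seqlen"])  # [0, 3, 5] [0, 4, 8] 3
```

Under these assumptions, absorbing the tail pad is what lets the single-row case omit cu_seqlens_padded entirely, while a layout with gaps between sub-sequences keeps both variants so the kernel can separate memory offsets from real lengths.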

nemo_automodel.components.distributed.thd_utils.split_batch_into_thd_chunks(
batch: dict[str, torch.Tensor],
num_chunks: int,
seq_lens_padding_value: int = -1000,
padding_token_id: int = 0
) -> dict[str, torch.Tensor]

Process inputs for THD format by splitting batch into chunks for context parallelism.

This function splits the batch along the batch dimension into num_chunks chunks, processes each chunk with process_input_for_thd, and stacks the tensor results. This is useful for context parallelism where different chunks are processed on different devices/ranks.
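The chunking step can be sketched as follows. `split_rows` is a hypothetical helper that only slices per-row values along the batch dimension; it assumes every value in the dict is indexed by row, whereas the real batch may also carry non-tensor keys such as ‘qkv_format’.

```python
from typing import Dict, List, Sequence


def split_rows(batch: Dict[str, Sequence], num_chunks: int) -> List[Dict[str, list]]:
    """Hypothetical helper: slice every per-row value into num_chunks equal parts."""
    batch_size = len(next(iter(batch.values())))
    if num_chunks <= 1:
        # Whole batch as a single chunk (the real function skips stacking here).
        return [{k: list(v) for k, v in batch.items()}]
    if batch_size % num_chunks != 0:
        raise ValueError(
            f"num_chunks={num_chunks} must evenly divide batch_size={batch_size}"
        )
    rows_per_chunk = batch_size // num_chunks
    return [
        {k: list(v[i * rows_per_chunk:(i + 1) * rows_per_chunk]) for k, v in batch.items()}
        for i in range(num_chunks)
    ]


batch = {
    "input_ids": [[1, 2], [3, 4], [5, 6], [7, 8]],
    "seq_lens": [[2], [2], [1], [2]],
}
chunks = split_rows(batch, num_chunks=2)
print(len(chunks), chunks[1]["input_ids"])  # 2 [[5, 6], [7, 8]]
```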

The cu_seqlens tensors from different chunks may have different lengths depending on the number of sequences in each chunk. These are padded with seq_lens_padding_value to ensure uniform length across chunks for stacking.
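A minimal sketch of that padding step, assuming (this is not stated above) that shorter cu_seqlens lists are right-padded; `pad_to_uniform` is a hypothetical name. A consumer can recover each chunk's boundaries by dropping entries equal to the padding value.

```python
from typing import List, Sequence


def pad_to_uniform(
    cu_seqlens_per_chunk: Sequence[Sequence[int]], padding_value: int = -1000
) -> List[List[int]]:
    """Hypothetical helper: right-pad variable-length boundary lists so they stack."""
    width = max(len(cu) for cu in cu_seqlens_per_chunk)
    return [list(cu) + [padding_value] * (width - len(cu)) for cu in cu_seqlens_per_chunk]


# Chunk 0 holds 2 sequences, chunk 1 holds 3.
stacked = pad_to_uniform([[0, 3, 8], [0, 2, 5, 8]])
print(stacked)  # [[0, 3, 8, -1000], [0, 2, 5, 8]]

# Recover the original boundaries of chunk 0.
print([v for v in stacked[0] if v != -1000])  # [0, 3, 8]
```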

Parameters:

batch
dict[str, torch.Tensor]

Dictionary containing input tensors with same structure as process_input_for_thd:

  • ‘input_ids’: [batch_size, seq_len] or [batch_size, seq_len, hidden_dim]
  • ‘labels’: [batch_size, seq_len]
  • ‘position_ids’: [batch_size, seq_len] (required)
  • ‘seq_lens’: [batch_size, num_packs]
  • ‘seq_lens_padded’: [batch_size, num_packs]

num_chunks
int

Number of chunks to split the batch into. Must evenly divide batch_size. If num_chunks <= 1, returns the result from process_input_for_thd directly.

seq_lens_padding_value
int, defaults to -1000

Value that marks padding entries in seq_lens/seq_lens_padded and that is used to pad cu_seqlens to a uniform length across chunks.

padding_token_id
int, defaults to 0

Token ID that marks padding in input_ids; used to build padding_mask.

Returns: dict[str, torch.Tensor]

Dictionary containing: