nemo_automodel.components.datasets.llm.length_grouped_sampler#

Length-grouped sampler for LLM training.

Groups samples by token count so that batches contain similar-length sequences, minimizing padding waste. Adapted from the VLM LengthGroupedSampler but simplified for text-only datasets.

Usage::

    sampler = LengthGroupedSampler(
        dataset=ds,
        batch_size=4,
        seed=42,
        num_replicas=world_size,
        rank=rank,
    )
    dataloader = DataLoader(ds, sampler=sampler, batch_size=4)

Module Contents#

Classes#

LengthGroupedSampler

Sampler that groups samples by sequence length for balanced batches.

Data#

API#

nemo_automodel.components.datasets.llm.length_grouped_sampler.logger#

'getLogger(…)'

class nemo_automodel.components.datasets.llm.length_grouped_sampler.LengthGroupedSampler(
dataset: torch.utils.data.Dataset,
batch_size: int = 1,
seed: int = 42,
num_replicas: int | None = None,
rank: int | None = None,
drop_last: bool = True,
)#

Bases: torch.utils.data.Sampler[int]

Sampler that groups samples by sequence length for balanced batches.

Sorts samples by length, chunks into groups of batch_size, then shuffles at the chunk level each epoch. This preserves intra-batch length similarity (less padding) while adding per-epoch randomness.

For distributed training, each rank gets an interleaved shard of the sorted indices. All ranks use the same seed + epoch so chunk K on every rank corresponds to similar-length samples, keeping cross-rank padding minimal.
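The sort-shard-chunk-shuffle scheme described above can be sketched as follows. This is a simplified illustration, not the actual implementation: the function name `length_grouped_order` is hypothetical, and `drop_last` handling is omitted for brevity.

```python
import random

def length_grouped_order(lengths, batch_size, seed, epoch, num_replicas=1, rank=0):
    # 1) Sort sample indices by sequence length (ascending).
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    # 2) Each rank takes an interleaved shard of the sorted indices, so
    #    position k on every rank holds a similar-length sample.
    shard = order[rank::num_replicas]
    # 3) Chunk the shard into groups of batch_size; each chunk becomes a
    #    batch of similar-length samples (minimal padding).
    chunks = [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]
    # 4) Shuffle at the chunk level; seeding with seed + epoch keeps all
    #    ranks in step while giving each epoch a fresh batch order.
    random.Random(seed + epoch).shuffle(chunks)
    return [i for chunk in chunks for i in chunk]
```

Because only the chunk order is shuffled, each batch still draws from a contiguous run of the length-sorted indices.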

Parameters:
  • dataset – The dataset to sample from. Samples must have an input_ids key (list or tensor) whose length is used for sorting.

  • batch_size – Local batch size per rank.

  • seed – Base random seed (must be the same on all ranks).

  • num_replicas – Number of distributed ranks (default: world size).

  • rank – This rank’s index (default: current rank).

  • drop_last – Drop the tail indices that don’t fill a full batch across all ranks.

Initialization

static _compute_lengths(dataset: torch.utils.data.Dataset) list[int]#

Compute token lengths for all samples.
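In essence, this reads the length of each sample's ``input_ids``, which works for both lists and tensors since ``len()`` covers both. A minimal sketch (the helper name `compute_lengths` is illustrative, not the exact internal code):

```python
def compute_lengths(dataset):
    # len() returns the sequence length for a Python list and the
    # first-dimension size for a 1-D tensor, so both sample types work.
    return [len(sample["input_ids"]) for sample in dataset]
```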

set_epoch(epoch: int) None#

Set the epoch for deterministic per-epoch shuffling.
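Since the shuffle is seeded by ``seed + epoch``, the same epoch number reproduces the same order on every rank and every run, while a new epoch reseeds the shuffle. A toy illustration of that property (the function `epoch_order` is hypothetical, standing in for the sampler's internal shuffle):

```python
import random

def epoch_order(n, seed, epoch):
    # The permutation depends only on (seed, epoch), so calling
    # set_epoch(epoch) on every rank keeps them in lockstep.
    idx = list(range(n))
    random.Random(seed + epoch).shuffle(idx)
    return idx
```

In a training loop, call ``sampler.set_epoch(epoch)`` at the start of each epoch, before iterating the dataloader.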

state_dict() Dict[str, Any]#

Return the sampler state for checkpointing.

load_state_dict(state_dict: Dict[str, Any]) None#

Restore sampler state from a checkpoint.

__iter__() Iterator[int]#

Yield this rank's sample indices in length-grouped, chunk-shuffled order.

__len__() int#

Number of indices this rank will yield.