nemo_automodel.components.datasets.llm.length_grouped_sampler


Length-grouped sampler for LLM training.

Groups samples by token count so that batches contain similar-length sequences, minimizing padding waste. Adapted from the VLM LengthGroupedSampler but simplified for text-only datasets.

Usage::

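    from torch.utils.data import DataLoader
    from nemo_automodel.components.datasets.llm.length_grouped_sampler import LengthGroupedSampler

    # ds is the training dataset; world_size and rank come from the distributed setup.
    sampler = LengthGroupedSampler(
        dataset=ds,
        batch_size=4,
        seed=42,
        num_replicas=world_size,
        rank=rank,
    )
    # Pass the same dataset and the same batch_size to the DataLoader.
    dataloader = DataLoader(ds, sampler=sampler, batch_size=4)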

Module Contents

Classes

| Name | Description |
| --- | --- |
| LengthGroupedSampler | Sampler that groups samples by sequence length for balanced batches. |

Data

logger

API

class nemo_automodel.components.datasets.llm.length_grouped_sampler.LengthGroupedSampler(
dataset: torch.utils.data.Dataset,
batch_size: int = 1,
seed: int = 42,
num_replicas: int | None = None,
rank: int | None = None,
drop_last: bool = True
)

Bases: Sampler[int]

Sampler that groups samples by sequence length for balanced batches.

Sorts samples by length, chunks into groups of batch_size, then shuffles at the chunk level each epoch. This preserves intra-batch length similarity (less padding) while adding per-epoch randomness.
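The sort, chunk, and shuffle steps described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper name `length_grouped_order`, the ascending sort direction, and the use of `random.Random(seed + epoch)` are assumptions made for the example.

```python
import random


def length_grouped_order(lengths, batch_size, seed=42, epoch=0):
    """Hypothetical helper: order indices so each batch holds similar lengths."""
    # 1. Sort sample indices by token count (ascending is an assumption).
    sorted_idx = sorted(range(len(lengths)), key=lambda i: lengths[i])
    # 2. Chunk the sorted indices into groups of batch_size.
    chunks = [
        sorted_idx[start:start + batch_size]
        for start in range(0, len(sorted_idx), batch_size)
    ]
    # 3. Shuffle whole chunks, never the samples inside a chunk.
    rng = random.Random(seed + epoch)
    rng.shuffle(chunks)
    return [idx for chunk in chunks for idx in chunk]


lengths = [12, 3, 9, 4, 11, 2, 10, 5]
order = length_grouped_order(lengths, batch_size=2, seed=42, epoch=0)
batches = [order[i:i + 2] for i in range(0, len(order), 2)]
```

Because only the chunk order is randomized, every batch still pairs neighbours from the sorted list, so the padding needed inside a batch stays small.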

For distributed training, each rank gets an interleaved shard of the sorted indices. All ranks use the same seed + epoch so chunk K on every rank corresponds to similar-length samples, keeping cross-rank padding minimal.
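A small sketch of the interleaved sharding, based on the `sorted_all[rank::num_replicas]` slice listed in the attributes of this class. The helper name `shard_sorted_indices` and the ascending sort are assumptions for illustration only.

```python
def shard_sorted_indices(lengths, num_replicas, rank):
    """Hypothetical helper: give each rank every num_replicas-th sorted index."""
    # Sort all indices by length once; every rank computes the same list.
    sorted_all = sorted(range(len(lengths)), key=lambda i: lengths[i])
    # Strided slice: rank r takes positions r, r + num_replicas, ...
    return sorted_all[rank::num_replicas]


lengths = [12, 3, 9, 4, 11, 2, 10, 5]
shards = [shard_sorted_indices(lengths, num_replicas=2, rank=r) for r in range(2)]
# Lengths seen by each rank, position by position.
shard_lengths = [[lengths[i] for i in shard] for shard in shards]
```

Position k on each rank comes from adjacent entries of the globally sorted list, which is why the same chunk index on different ranks holds sequences of similar length.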

Parameters:

dataset
Dataset

The dataset to sample from. Samples must have an input_ids key (list or tensor) whose length is used for sorting.

batch_size
int, defaults to 1

Local batch size per rank.

seed
int, defaults to 42

Base random seed (must be the same on all ranks).

num_replicas
int | None, defaults to None

Number of distributed ranks (default: world size).

rank
int | None, defaults to None

This rank’s index (default: current rank).

drop_last
bool, defaults to True

Drop the tail indices that don’t fill a full batch across all ranks.
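One way to read the `drop_last` behaviour is that the sorted index list is trimmed to a multiple of `batch_size * num_replicas` before sharding. The sketch below works under that assumption; the helper name `trim_to_global_batches` is hypothetical and the library may compute the cut differently.

```python
def trim_to_global_batches(sorted_all, batch_size, num_replicas):
    """Hypothetical helper: drop the tail that cannot fill a batch on every rank."""
    # One "global batch" is a full local batch on each rank.
    global_batch = batch_size * num_replicas
    usable = len(sorted_all) - len(sorted_all) % global_batch
    return sorted_all[:usable]


sorted_all = list(range(10))
trimmed = trim_to_global_batches(sorted_all, batch_size=2, num_replicas=2)
per_rank = [trimmed[r::2] for r in range(2)]
```

With 10 samples, a local batch size of 2, and 2 ranks, the last 2 indices are dropped so that each rank receives exactly two full batches.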

_next_yielded: int | None = None
batch_size = max(1, batch_size)
epoch = 0
lengths = self._compute_lengths(dataset)
sorted_indices = sorted_all[(self.rank)::(self.num_replicas)]
yielded = 0

nemo_automodel.components.datasets.llm.length_grouped_sampler.LengthGroupedSampler.__iter__() -> typing.Iterator[int]
nemo_automodel.components.datasets.llm.length_grouped_sampler.LengthGroupedSampler.__len__() -> int
nemo_automodel.components.datasets.llm.length_grouped_sampler.LengthGroupedSampler._compute_lengths(
dataset: torch.utils.data.Dataset
) -> list[int]
staticmethod

Compute token lengths for all samples.
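Since samples are documented as carrying an `input_ids` key whose length drives the sort, the length computation can be sketched as below. This is an assumption about the general shape only; the function name `compute_lengths` is hypothetical and the real method may handle other sample layouts.

```python
def compute_lengths(dataset):
    """Hypothetical sketch: token count of every sample's input_ids."""
    # len() works for Python lists and for 1-D tensors alike.
    return [len(dataset[i]["input_ids"]) for i in range(len(dataset))]


# A list of dicts stands in for a map-style dataset here.
toy_dataset = [
    {"input_ids": [101, 7592, 102]},
    {"input_ids": [101, 102]},
    {"input_ids": [101, 2023, 2003, 1037, 102]},
]
toy_lengths = compute_lengths(toy_dataset)
```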

nemo_automodel.components.datasets.llm.length_grouped_sampler.LengthGroupedSampler.load_state_dict(
state_dict: typing.Dict[str, typing.Any]
) -> None
nemo_automodel.components.datasets.llm.length_grouped_sampler.LengthGroupedSampler.set_epoch(
epoch: int
) -> None

Set the epoch for deterministic per-epoch shuffling.
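The usual pattern is to call `set_epoch` once before iterating each epoch, so that every rank derives the same chunk order from `seed + epoch`. The example below uses a hypothetical stand-in class, `ToyChunkSampler`, rather than the library class, so that it runs without any dependencies.

```python
import random


class ToyChunkSampler:
    """Hypothetical stand-in for illustration; not the library class."""

    def __init__(self, chunks, seed=42):
        self.chunks = chunks
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # Chunk order depends only on seed + epoch, so it is reproducible.
        order = list(self.chunks)
        random.Random(self.seed + self.epoch).shuffle(order)
        return (idx for chunk in order for idx in chunk)


sampler = ToyChunkSampler([[0, 1], [2, 3], [4, 5], [6, 7]])
epoch_orders = []
for epoch in range(3):
    sampler.set_epoch(epoch)  # call before iterating this epoch
    epoch_orders.append(list(sampler))
```

Replaying an epoch number reproduces that epoch's order exactly, which is what makes the shuffle deterministic across ranks and across restarts.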

nemo_automodel.components.datasets.llm.length_grouped_sampler.LengthGroupedSampler.state_dict() -> typing.Dict[str, typing.Any]
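The reference does not describe the contents of the state dictionary. The `epoch`, `yielded`, and `_next_yielded` attributes suggest that the sampler can resume part-way through an epoch; the sketch below illustrates that idea with a hypothetical `ResumableOrder` class and assumed dictionary keys, neither of which is confirmed by the source.

```python
class ResumableOrder:
    """Hypothetical illustration of mid-epoch resume; not the library class."""

    def __init__(self, order):
        self.order = list(order)
        self.epoch = 0
        self.yielded = 0
        self._next_yielded = None

    def __iter__(self):
        # Skip indices that were already consumed before the checkpoint.
        start = self._next_yielded or 0
        self._next_yielded = None
        self.yielded = start
        for idx in self.order[start:]:
            self.yielded += 1
            yield idx

    def state_dict(self):
        # Key names are assumptions made for this example.
        return {"epoch": self.epoch, "yielded": self.yielded}

    def load_state_dict(self, state_dict):
        self.epoch = state_dict["epoch"]
        self._next_yielded = state_dict["yielded"]


first = ResumableOrder([3, 1, 4, 1, 5, 9])
it = iter(first)
consumed = [next(it) for _ in range(2)]
saved = first.state_dict()

resumed = ResumableOrder([3, 1, 4, 1, 5, 9])
resumed.load_state_dict(saved)
remaining = list(resumed)
```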
nemo_automodel.components.datasets.llm.length_grouped_sampler.logger = logging.getLogger(__name__)