nemo_automodel.components.datasets.llm.length_grouped_sampler#
Length-grouped sampler for LLM training.
Groups samples by token count so that batches contain similar-length
sequences, minimizing padding waste. Adapted from the VLM
LengthGroupedSampler but simplified for text-only datasets.
Usage::

    sampler = LengthGroupedSampler(
        dataset=ds,
        batch_size=4,
        seed=42,
        num_replicas=world_size,
        rank=rank,
    )
    dataloader = DataLoader(ds, sampler=sampler, batch_size=4)
Module Contents#
Classes#
LengthGroupedSampler: Sampler that groups samples by sequence length for balanced batches.
Data#
API#
- nemo_automodel.components.datasets.llm.length_grouped_sampler.logger#
'getLogger(…)'
- class nemo_automodel.components.datasets.llm.length_grouped_sampler.LengthGroupedSampler(
- dataset: torch.utils.data.Dataset,
- batch_size: int = 1,
- seed: int = 42,
- num_replicas: int | None = None,
- rank: int | None = None,
- drop_last: bool = True,
Bases: torch.utils.data.Sampler[int]

Sampler that groups samples by sequence length for balanced batches.

Sorts samples by length, chunks them into groups of batch_size, then shuffles at the chunk level each epoch. This preserves intra-batch length similarity (less padding) while adding per-epoch randomness.

For distributed training, each rank gets an interleaved shard of the sorted indices. All ranks use the same seed + epoch, so chunk K on every rank corresponds to similar-length samples, keeping cross-rank padding minimal.

- Parameters:
  dataset – The dataset to sample from. Samples must have an input_ids key (list or tensor) whose length is used for sorting.
  batch_size – Local batch size per rank.
seed – Base random seed (must be the same on all ranks).
num_replicas – Number of distributed ranks (default: world size).
rank – This rank’s index (default: current rank).
drop_last – Drop the tail indices that don’t fill a full batch across all ranks.
Initialization
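The sort, chunk, and shuffle behavior described above can be sketched in a few lines. This is an illustrative stand-in, not the actual implementation; `chunked_length_order` is a hypothetical helper name:

```python
import random

def chunked_length_order(lengths, batch_size, seed=42, epoch=0):
    # 1. Sort sample indices so neighbours have similar token counts.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    # 2. Chunk the sorted order into groups of batch_size.
    chunks = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    # 3. Shuffle at the chunk level, deterministically per (seed, epoch).
    random.Random(seed + epoch).shuffle(chunks)
    # Flatten back into a single index sequence; each consecutive
    # batch_size-sized slice still holds similar-length samples.
    return [idx for chunk in chunks for idx in chunk]
```

Because the shuffle is seeded with `seed + epoch`, replaying the same epoch yields the same order, while different epochs see different chunk orderings.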
- static _compute_lengths(dataset: torch.utils.data.Dataset) → list[int]#
Compute token lengths for all samples.
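Per the parameter docs, each sample's input_ids length drives the sort. A minimal stand-in (not the actual implementation) could look like:

```python
def compute_lengths(dataset):
    # len() works for both Python lists and 1-D token-id tensors.
    return [len(sample["input_ids"]) for sample in dataset]

toy = [{"input_ids": [1, 2, 3]}, {"input_ids": [7, 8]}]
assert compute_lengths(toy) == [3, 2]
```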
- set_epoch(epoch: int) → None#
Set the epoch for deterministic per-epoch shuffling.
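The contract set_epoch relies on can be demonstrated with a toy stand-in (EpochSeededSampler below is hypothetical, not this module's class): all ranks call set_epoch with the same value before each epoch, so every rank derives the same deterministic ordering from seed + epoch.

```python
import random

class EpochSeededSampler:
    def __init__(self, num_samples, seed=42):
        self.num_samples = num_samples
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # Same contract as LengthGroupedSampler.set_epoch: call once per
        # epoch, with the same value on every rank.
        self.epoch = epoch

    def __iter__(self):
        order = list(range(self.num_samples))
        random.Random(self.seed + self.epoch).shuffle(order)
        return iter(order)

s = EpochSeededSampler(8)
s.set_epoch(0)
epoch0 = list(s)
s.set_epoch(1)
epoch1 = list(s)
s.set_epoch(0)
assert list(s) == epoch0  # replaying epoch 0 is deterministic
```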
- state_dict() → Dict[str, Any]#
- load_state_dict(state_dict: Dict[str, Any]) → None#
- __iter__() → Iterator[int]#
- __len__() → int#