nemo_automodel.components.datasets.diffusion.sampler#

Module Contents#

Classes#

SequentialBucketSampler

Production-grade Sampler that:

Functions#

collate_fn_production

Production collate function with verification.

build_multiresolution_dataloader

Build production dataloader with sequential bucket iteration and distributed training support.

Data#

API#

nemo_automodel.components.datasets.diffusion.sampler.logger#

'getLogger(…)'

class nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler(
dataset: nemo_automodel.components.datasets.diffusion.text_to_image_dataset.TextToImageDataset,
base_batch_size: int = 32,
base_resolution: Tuple[int, int] = (512, 512),
drop_last: bool = True,
shuffle_buckets: bool = True,
shuffle_within_bucket: bool = True,
dynamic_batch_size: bool = False,
seed: int = 42,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
)#

Bases: torch.utils.data.Sampler[typing.List[int]]

Production-grade Sampler that:

  1. Supports Distributed Data Parallel (DDP) - splits data across GPUs

  2. Shuffles deterministically via torch.Generator (resumable training)

  3. Generates batches lazily (saves RAM compared to pre-computing all batches)

  4. Guarantees equal batch counts across all ranks (prevents DDP deadlocks)

  • Processes all images in bucket A before moving to bucket B

  • Shuffles samples within each bucket (deterministically)

  • Drops incomplete batches at end of each bucket

  • Optionally uses dynamic batch sizes based on resolution (when dynamic_batch_size=True)
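
The bucket-sequential batching described above can be sketched in plain Python (a hypothetical illustration of the iteration order, not the library's implementation — the real sampler also handles shuffling and rank splitting):

```python
# Group sample indices by resolution, then emit fixed-size batches one
# bucket at a time, dropping each bucket's incomplete tail (drop_last=True).
from collections import defaultdict

def sequential_bucket_batches(resolutions, batch_size):
    """resolutions: list mapping sample index -> (h, w) bucket key."""
    buckets = defaultdict(list)
    for idx, res in enumerate(resolutions):
        buckets[res].append(idx)
    batches = []
    for res, indices in buckets.items():
        # All batches from bucket A are emitted before bucket B starts.
        n_full = len(indices) // batch_size
        for b in range(n_full):
            batches.append(indices[b * batch_size:(b + 1) * batch_size])
    return batches

# Example: 5 samples at 512x512 and 3 at 768x768, batch_size=2
res = [(512, 512)] * 5 + [(768, 768)] * 3
batches = sequential_bucket_batches(res, batch_size=2)
# -> [[0, 1], [2, 3], [5, 6]]  (indices 4 and 7 dropped as incomplete tails)
```

Because every batch comes from a single bucket, all images in a batch share one resolution and can be stacked without padding.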

Initialization

Parameters:
  • dataset – TextToImageDataset

  • base_batch_size – Batch size (fixed if dynamic_batch_size=False, or base for scaling if dynamic_batch_size=True)

  • base_resolution – Reference resolution for batch size scaling (only used if dynamic_batch_size=True)

  • drop_last – Drop incomplete batches

  • shuffle_buckets – Shuffle bucket order

  • shuffle_within_bucket – Shuffle samples within each bucket

  • dynamic_batch_size – If True, scale batch size based on resolution. If False (default), use base_batch_size for all buckets.

  • seed – Random seed for deterministic shuffling (resumable training)

  • num_replicas – Number of distributed processes (auto-detected if None)

  • rank – Rank of current process (auto-detected if None)

_get_batch_size(resolution: Tuple[int, int]) int#

Get batch size for resolution (dynamic or fixed based on setting).
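
When dynamic_batch_size=True, a natural scaling rule is to shrink the batch as pixel count grows. The helper below is a hypothetical sketch of that idea (the library's exact formula may differ):

```python
# Scale base_batch_size by the pixel-count ratio between base_resolution
# and the bucket's resolution, keeping at least one sample per batch.
def scaled_batch_size(base_batch_size, base_resolution, resolution):
    base_pixels = base_resolution[0] * base_resolution[1]
    pixels = resolution[0] * resolution[1]
    return max(1, int(base_batch_size * base_pixels / pixels))

# 1024x1024 has 4x the pixels of 512x512, so the batch shrinks 4x:
# scaled_batch_size(32, (512, 512), (1024, 1024)) -> 8
```

With dynamic_batch_size=False (the default), base_batch_size is used for every bucket regardless of resolution.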

_calculate_total_batches() int#

Calculate total batches ensuring ALL ranks get the same count. We pad each bucket to be divisible by (num_replicas * batch_size).
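
The padding arithmetic can be sketched as follows (a hypothetical helper illustrating the rule stated above, not the library's code):

```python
import math

# Round each bucket's sample count up to a multiple of
# num_replicas * batch_size, so every rank draws the same number of
# batches from every bucket and no rank stalls at a collective op.
def batches_per_rank(bucket_size, batch_size, num_replicas):
    global_batch = num_replicas * batch_size
    padded = math.ceil(bucket_size / global_batch) * global_batch
    return padded // global_batch

# 100 samples, batch_size=32, 2 replicas: padded to 128 -> 2 batches/rank
```

Equal per-rank batch counts are what prevent the DDP deadlocks mentioned in the class description: a rank that ran out of batches early would hang the others at the next gradient all-reduce.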

set_epoch(epoch: int)#

Crucial for reproducibility and different shuffles per epoch.
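
The epoch-dependent deterministic shuffle that set_epoch enables can be sketched with torch.Generator (a hypothetical helper mirroring the seed + epoch scheme; the sampler's internals may differ):

```python
import torch

def shuffled_order(n, seed, epoch):
    g = torch.Generator()
    g.manual_seed(seed + epoch)  # new shuffle each epoch, identical on resume
    return torch.randperm(n, generator=g).tolist()

# Calling set_epoch(epoch) before each epoch gives every rank the same,
# reproducible order:
#   for epoch in range(num_epochs):
#       sampler.set_epoch(epoch)
#       for batch in dataloader:
#           ...
assert shuffled_order(16, seed=42, epoch=0) == shuffled_order(16, seed=42, epoch=0)
```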

__iter__() Iterator[List[int]]#
__len__() int#
get_batch_info(batch_idx: int) Dict#

Get information about a specific batch.

Note: With lazy evaluation, we don't pre-compute batches, so this returns bucket-level info for the estimated batch.

nemo_automodel.components.datasets.diffusion.sampler.collate_fn_production(batch: List[Dict]) Dict#

Production collate function with verification.
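
A hedged sketch of what a verifying collate function might look like (the actual collate_fn_production and its sample keys may differ): stack the image tensors only after checking that every sample in the batch shares one shape, which the bucket sampler is supposed to guarantee.

```python
import torch

def collate_with_verification(batch):
    # Verify the sampler's invariant: one resolution per batch.
    shapes = {tuple(sample["image"].shape) for sample in batch}
    assert len(shapes) == 1, f"mixed resolutions in batch: {shapes}"
    return {
        "images": torch.stack([sample["image"] for sample in batch]),
        "captions": [sample["caption"] for sample in batch],
    }

batch = [{"image": torch.zeros(3, 512, 512), "caption": "a cat"} for _ in range(4)]
out = collate_with_verification(batch)
# out["images"].shape -> torch.Size([4, 3, 512, 512])
```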

nemo_automodel.components.datasets.diffusion.sampler.build_multiresolution_dataloader(
*,
dataset: nemo_automodel.components.datasets.diffusion.text_to_image_dataset.TextToImageDataset,
base_batch_size: int,
dp_rank: int,
dp_world_size: int,
base_resolution: Tuple[int, int] = (512, 512),
drop_last: bool = True,
shuffle: bool = True,
dynamic_batch_size: bool = False,
num_workers: int = 4,
pin_memory: bool = True,
prefetch_factor: int = 2,
) Tuple[torch.utils.data.DataLoader, nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler]#

Build production dataloader with sequential bucket iteration and distributed training support.

Parameters:
  • dataset – TextToImageDataset instance

  • base_batch_size – Batch size (fixed, or base for scaling if dynamic_batch_size=True)

  • dp_rank – Rank of current process in data parallel group

  • dp_world_size – Total number of processes in data parallel group

  • base_resolution – Reference resolution (only used if dynamic_batch_size=True)

  • drop_last – Drop incomplete batches

  • shuffle – Shuffle bucket order and samples within buckets each epoch

  • dynamic_batch_size – If True, scale batch size based on resolution. If False (default), use base_batch_size for all buckets.

  • num_workers – Number of data loading workers

  • pin_memory – Pin memory for faster GPU transfer

  • prefetch_factor – How many batches to prefetch per worker

Returns:

Tuple of (DataLoader, SequentialBucketSampler) for production training