nemo_automodel.components.datasets.diffusion.sampler#

Module Contents#

Classes#

SequentialBucketSampler

Production-grade Sampler that:

Data#

API#

nemo_automodel.components.datasets.diffusion.sampler.logger#

'getLogger(...)'

class nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler(
dataset: nemo_automodel.components.datasets.diffusion.base_dataset.BaseMultiresolutionDataset,
base_batch_size: int = 32,
base_resolution: Tuple[int, int] = (512, 512),
drop_last: bool = True,
shuffle_buckets: bool = True,
shuffle_within_bucket: bool = True,
dynamic_batch_size: bool = False,
seed: int = 42,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
)#

Bases: torch.utils.data.Sampler[typing.List[int]]

Production-grade Sampler that:

  1. Supports Distributed Data Parallel (DDP) by splitting data across GPUs

  2. Shuffles deterministically via torch.Generator (supports resumable training)

  3. Generates batches lazily (saves RAM compared to pre-computing all batches)

  4. Guarantees equal batch counts across all ranks (prevents DDP deadlocks)

  • Processes all images in bucket A before moving to bucket B

  • Shuffles samples within each bucket (deterministically)

  • Drops incomplete batches at the end of each bucket (when drop_last=True)

  • Optionally uses dynamic batch sizes based on resolution (when dynamic_batch_size=True)
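The bucket-sequential iteration described above can be sketched in plain Python. This is a hypothetical `sequential_bucket_batches` helper, not the library API, and it uses `random.Random` as a stand-in for the class's `torch.Generator`:

```python
import random

def sequential_bucket_batches(bucket_to_indices, batch_size,
                              shuffle_buckets=True, drop_last=True, seed=42):
    """Hypothetical sketch of bucket-sequential batching."""
    rng = random.Random(seed)
    bucket_keys = list(bucket_to_indices)
    if shuffle_buckets:
        rng.shuffle(bucket_keys)          # shuffle bucket order
    batches = []
    for key in bucket_keys:
        indices = list(bucket_to_indices[key])
        rng.shuffle(indices)              # deterministic shuffle within bucket
        for i in range(0, len(indices), batch_size):
            batch = indices[i:i + batch_size]
            if drop_last and len(batch) < batch_size:
                continue                  # drop the incomplete tail batch
            batches.append(batch)
    return batches
```

Because every batch is drawn from a single bucket, all samples in a batch share the same resolution and can be stacked into one tensor without padding.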

Initialization

Parameters:
  • dataset – BaseMultiresolutionDataset (or any subclass)

  • base_batch_size – Batch size (fixed if dynamic_batch_size=False, or base for scaling if dynamic_batch_size=True)

  • base_resolution – Reference resolution for batch size scaling (only used if dynamic_batch_size=True)

  • drop_last – Drop incomplete batches

  • shuffle_buckets – Shuffle bucket order

  • shuffle_within_bucket – Shuffle samples within each bucket

  • dynamic_batch_size – If True, scale batch size based on resolution. If False (default), use base_batch_size for all buckets.

  • seed – Random seed for deterministic shuffling (resumable training)

  • num_replicas – Number of distributed processes (auto-detected if None)

  • rank – Rank of current process (auto-detected if None)
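The DDP-related parameters can be illustrated with a small sketch. These are hypothetical helpers: the real class auto-detects `num_replicas`/`rank` via torch.distributed, whereas this stand-in falls back to common environment variables, and the round-robin sharding shown is one plausible scheme for giving every rank the same batch count:

```python
import os

def detect_distributed():
    # Stand-in for auto-detection: default to a single process.
    num_replicas = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))
    return num_replicas, rank

def shard_batches(batches, num_replicas, rank):
    # Each rank takes every num_replicas-th batch; with an equal total
    # count, no rank runs out of batches early (avoids DDP deadlocks).
    return batches[rank::num_replicas]
```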

_get_batch_size(resolution: Tuple[int, int]) → int#

Get the batch size for a resolution (dynamic or fixed, depending on dynamic_batch_size).
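One plausible dynamic scheme keeps the per-batch pixel budget roughly constant by scaling `base_batch_size` with the pixel ratio to `base_resolution`. The exact formula is an assumption, not taken from the source; `scaled_batch_size` is a hypothetical name:

```python
def scaled_batch_size(resolution, base_batch_size=32, base_resolution=(512, 512)):
    # Assumed scaling rule: batch size shrinks as resolution grows,
    # so total pixels per batch stay near the base budget.
    base_pixels = base_resolution[0] * base_resolution[1]
    pixels = resolution[0] * resolution[1]
    return max(1, int(base_batch_size * base_pixels / pixels))
```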

_calculate_total_batches() → int#

Calculate total batches ensuring ALL ranks get the same count. We pad each bucket to be divisible by (num_replicas * batch_size).
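The padding rule from the docstring can be worked through directly. Assuming a fixed batch size per bucket, padding the bucket to a multiple of `num_replicas * batch_size` gives every rank an identical batch count (`batches_per_rank` is a hypothetical helper):

```python
import math

def batches_per_rank(bucket_size, batch_size, num_replicas):
    # Pad the bucket up to the next multiple of (num_replicas * batch_size);
    # after the even split, each rank gets exactly the same number of batches.
    chunk = num_replicas * batch_size
    padded = math.ceil(bucket_size / chunk) * chunk
    return padded // chunk
```

For example, a bucket of 100 samples with batch size 8 on 4 ranks is padded to 128 samples, yielding 4 batches on every rank.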

set_epoch(epoch: int)#

Set the epoch for this sampler. Crucial for reproducibility: each epoch gets a different but deterministic shuffle.
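A common pattern behind `set_epoch` is to fold the epoch into the seed before shuffling. The sketch below uses `random.Random` as a stand-in for the class's `torch.Generator`, and the `seed + epoch` combination is an assumption about the scheme, not the confirmed implementation:

```python
import random

def epoch_shuffle(indices, seed, epoch):
    # Fold the epoch into the seed so every epoch produces a different
    # but fully reproducible permutation (analogous to
    # generator.manual_seed(seed + epoch) in torch).
    rng = random.Random(seed + epoch)
    out = list(indices)
    rng.shuffle(out)
    return out
```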

state_dict() → Dict#

Return sampler state for mid-epoch checkpointing.

load_state_dict(state_dict: Dict) → None#

Restore sampler state; the next iter will skip already-yielded batches.
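Because batch generation is deterministic, mid-epoch resume only needs a counter of batches already yielded. The minimal class below is a hypothetical illustration of that contract, assuming the state dict carries such a counter:

```python
class ResumableIterSketch:
    """Hypothetical sketch of the state_dict/load_state_dict contract."""

    def __init__(self, batches):
        self.batches = batches            # deterministic batch sequence
        self.batches_yielded = 0

    def state_dict(self):
        # Enough state to resume: how far into the epoch we got.
        return {"batches_yielded": self.batches_yielded}

    def load_state_dict(self, state):
        self.batches_yielded = state["batches_yielded"]

    def __iter__(self):
        # Regenerate the same sequence and skip already-yielded batches.
        for i, batch in enumerate(self.batches):
            if i < self.batches_yielded:
                continue
            self.batches_yielded = i + 1
            yield batch
```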

__iter__() → Iterator[List[int]]#
__len__() → int#
get_batch_info(batch_idx: int) → Dict#

Get information about a specific batch.

Note: With lazy evaluation, we don’t pre-compute batches, so this returns bucket-level info for the estimated batch.