nemo_automodel.components.datasets.diffusion.sampler


Module Contents

Classes

SequentialBucketSampler: Production-grade, DDP-aware batch sampler for bucketed multiresolution datasets.

Data

logger

API

class nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler(
dataset: nemo_automodel.components.datasets.diffusion.base_dataset.BaseMultiresolutionDataset,
base_batch_size: int = 32,
base_resolution: typing.Tuple[int, int] = (512, 512),
drop_last: bool = True,
shuffle_buckets: bool = True,
shuffle_within_bucket: bool = True,
dynamic_batch_size: bool = False,
seed: int = 42,
num_replicas: typing.Optional[int] = None,
rank: typing.Optional[int] = None
)

Bases: Sampler[List[int]]

Production-grade Sampler that:

  1. Supports Distributed Data Parallel (DDP) by splitting data across GPUs (ranks).
  2. Shuffles deterministically via torch.Generator, which makes training resumable.
  3. Generates batches lazily, saving RAM compared with pre-computing all batches.
  4. Guarantees equal batch counts across all ranks, which prevents DDP deadlocks.

Batching behavior:

  • Processes all images in bucket A before moving on to bucket B.
  • Shuffles samples within each bucket (deterministically).
  • Drops incomplete batches at the end of each bucket.
  • Uses resolution-dependent batch sizes when dynamic_batch_size=True.
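The batching behavior listed above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the sampler's implementation: `random.Random` stands in for `torch.Generator`, the helper name `sequential_bucket_batches` is hypothetical, buckets are a plain dict of index lists, a single fixed batch size is used, incomplete tails are dropped, and each global step of `num_replicas * batch_size` samples is split into one contiguous slice per rank.

```python
import random
from typing import Dict, Iterator, List


def sequential_bucket_batches(
    bucket_groups: Dict[str, List[int]],
    batch_size: int,
    seed: int,
    epoch: int,
    rank: int,
    num_replicas: int,
) -> Iterator[List[int]]:
    """Yield this rank's batches one bucket at a time (hypothetical sketch)."""
    # Same seed on every rank, so all ranks agree on the ordering.
    rng = random.Random(seed + epoch)  # stand-in for torch.Generator
    bucket_keys = sorted(bucket_groups)
    rng.shuffle(bucket_keys)  # shuffle bucket order (shuffle_buckets=True)
    step = num_replicas * batch_size  # samples consumed by all ranks per step
    for key in bucket_keys:
        indices = list(bucket_groups[key])
        rng.shuffle(indices)  # shuffle inside the bucket (shuffle_within_bucket=True)
        usable = len(indices) // step * step  # drop the incomplete tail
        for start in range(0, usable, step):
            chunk = indices[start:start + step]
            # Each rank takes its own contiguous slice of the global step.
            yield chunk[rank * batch_size:(rank + 1) * batch_size]
```

Because every rank seeds the same generator, all ranks agree on bucket and sample order and differ only in which slice they take, so no index is assigned to two ranks and every batch comes from a single bucket.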
_batches_to_skip = 0
_batches_yielded = 0
_total_batches = self._calculate_total_batches()
bucket_groups = dataset.bucket_groups
bucket_keys = dataset.sorted_bucket_keys
calculator = dataset.calculator
epoch = 0
nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler.__iter__() -> typing.Iterator[typing.List[int]]
nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler.__len__() -> int
nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler._calculate_total_batches() -> int

Calculate the total number of batches so that every rank gets the same count. Each bucket is padded to a multiple of num_replicas * batch_size.
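The padding rule reduces to simple arithmetic. The sketch below assumes one fixed batch size for every bucket (with dynamic_batch_size=True the step would differ per bucket), and `total_batches_per_rank` is a hypothetical helper, not part of the module.

```python
import math
from typing import Dict


def total_batches_per_rank(bucket_sizes: Dict[str, int], batch_size: int, num_replicas: int) -> int:
    """Batches each rank sees when every bucket is padded up to a multiple
    of num_replicas * batch_size (hypothetical helper)."""
    step = num_replicas * batch_size  # samples consumed across all ranks per step
    total = 0
    for size in bucket_sizes.values():
        padded = math.ceil(size / step) * step  # round the bucket up to full steps
        total += padded // step  # each step gives every rank exactly one batch
    return total
```

For buckets of 10 and 6 samples with batch_size=2 on 2 ranks, the step is 4: 10 pads to 12 (3 steps) and 6 pads to 8 (2 steps), so `total_batches_per_rank({"a": 10, "b": 6}, 2, 2)` returns 5 on every rank.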

nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler._get_batch_size(
resolution: typing.Tuple[int, int]
) -> int

Return the batch size for the given resolution: dynamic when dynamic_batch_size=True, otherwise fixed.
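The source does not say how dynamic sizes are derived. One common rule, shown here purely as an assumption, keeps the number of pixels per batch roughly constant relative to base_batch_size at base_resolution. `dynamic_batch_size` is a hypothetical helper; the real sampler may delegate to dataset.calculator instead.

```python
from typing import Tuple


def dynamic_batch_size(
    resolution: Tuple[int, int],
    base_batch_size: int = 32,
    base_resolution: Tuple[int, int] = (512, 512),
) -> int:
    """Hypothetical rule: scale the batch size inversely with pixel count."""
    base_pixels = base_resolution[0] * base_resolution[1]
    pixels = resolution[0] * resolution[1]
    # Keep base_batch_size * base_pixels roughly constant; never go below 1.
    return max(1, (base_batch_size * base_pixels) // pixels)
```

With the defaults, 1024x1024 gives 8 and 256x256 gives 128, so larger images get proportionally smaller batches.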

nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler.get_batch_info(
batch_idx: int
) -> typing.Dict

Get information about a specific batch.

Note: Because batches are generated lazily rather than pre-computed, this returns bucket-level information for the bucket the batch is estimated to fall in.
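Locating the bucket for a batch index without materializing any batches only requires per-bucket batch counts. A minimal sketch, assuming buckets are consumed in order and their batch counts are known up front; `locate_batch` and its returned keys are hypothetical and are not the dict that get_batch_info actually returns.

```python
import bisect
from itertools import accumulate
from typing import Dict, List


def locate_batch(batch_idx: int, bucket_keys: List[str], batches_per_bucket: List[int]) -> Dict:
    """Map a global batch index to (bucket, offset) without building batches."""
    ends = list(accumulate(batches_per_bucket))  # cumulative batch count per bucket
    if not ends or not 0 <= batch_idx < ends[-1]:
        raise IndexError(f"batch_idx {batch_idx} out of range")
    pos = bisect.bisect_right(ends, batch_idx)  # first bucket ending past batch_idx
    start = ends[pos - 1] if pos > 0 else 0
    return {"bucket_key": bucket_keys[pos], "batch_in_bucket": batch_idx - start}
```

The lookup is O(log B) in the number of buckets, and buckets with zero batches are skipped automatically because their cumulative end equals the previous one.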

nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler.load_state_dict(
state_dict: typing.Dict
) -> None

Restore the sampler state; the next call to __iter__ skips batches that were already yielded.
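The skip-on-resume pattern can be illustrated with a toy class that mirrors the `_batches_yielded` and `_batches_to_skip` attributes listed above. This is a sketch, not the sampler's code: `ResumableBatches` and its state-dict keys are hypothetical, and a fixed list of batches stands in for lazily generated ones.

```python
from itertools import islice
from typing import Dict, Iterator, List


class ResumableBatches:
    """Toy stand-in showing mid-epoch checkpoint and resume (hypothetical)."""

    def __init__(self, batches: List[List[int]]) -> None:
        self.batches = batches
        self.epoch = 0
        self._batches_yielded = 0
        self._batches_to_skip = 0

    def __iter__(self) -> Iterator[List[int]]:
        # Consume the pending skip once, then reset it for later epochs.
        skip, self._batches_to_skip = self._batches_to_skip, 0
        self._batches_yielded = skip
        for batch in islice(self.batches, skip, None):
            self._batches_yielded += 1  # count the batch as handed out
            yield batch

    def state_dict(self) -> Dict:
        return {"epoch": self.epoch, "batches_yielded": self._batches_yielded}

    def load_state_dict(self, state_dict: Dict) -> None:
        self.epoch = state_dict["epoch"]
        self._batches_to_skip = state_dict["batches_yielded"]
```

Saving the state after two batches and loading it into a fresh instance makes the next iteration start at the third batch, so no sample is seen twice within the resumed epoch.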

nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler.set_epoch(
epoch: int
)

Set the current epoch, which determines the shuffle order. Call this at the start of every epoch so that shuffles are reproducible and differ between epochs.
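Per-epoch shuffles are commonly made reproducible by seeding a fresh generator from the base seed plus the epoch, as torch.utils.data.DistributedSampler does with `seed + epoch`. The sketch below assumes the same scheme and uses `random.Random` in place of `torch.Generator`; `epoch_order` is a hypothetical helper.

```python
import random
from typing import Iterable, List


def epoch_order(indices: Iterable[int], seed: int, epoch: int) -> List[int]:
    """Deterministic shuffle derived from (seed, epoch) (hypothetical sketch)."""
    rng = random.Random(seed + epoch)  # fresh generator per epoch
    order = list(indices)
    rng.shuffle(order)
    return order
```

The same (seed, epoch) pair always yields the same order, which is what makes resuming reproducible, while a new epoch yields a new order. In this scheme, skipping set_epoch would replay the epoch-0 order every epoch.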

nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler.state_dict() -> typing.Dict

Return sampler state for mid-epoch checkpointing.

nemo_automodel.components.datasets.diffusion.sampler.logger = logging.getLogger(__name__)