nemo_automodel.components.datasets.diffusion.sampler#
Module Contents#
Classes#
SequentialBucketSampler – Production-grade sampler that groups samples by resolution bucket.
Data#
API#
- nemo_automodel.components.datasets.diffusion.sampler.logger#
'getLogger(…)'
- class nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler(
- dataset: nemo_automodel.components.datasets.diffusion.base_dataset.BaseMultiresolutionDataset,
- base_batch_size: int = 32,
- base_resolution: Tuple[int, int] = (512, 512),
- drop_last: bool = True,
- shuffle_buckets: bool = True,
- shuffle_within_bucket: bool = True,
- dynamic_batch_size: bool = False,
- seed: int = 42,
- num_replicas: Optional[int] = None,
- rank: Optional[int] = None,
- )

Bases:
torch.utils.data.Sampler[typing.List[int]]

Production-grade Sampler that:
- Supports Distributed Data Parallel (DDP) – splits data across GPUs
- Deterministic shuffling via torch.Generator (resumable training)
- Lazy batch generation (saves RAM compared to pre-computing all batches)
- Guarantees equal batch counts across all ranks (prevents DDP deadlocks)
- Processes all images in bucket A before moving to bucket B
- Shuffles samples within each bucket (deterministically)
- Drops incomplete batches at the end of each bucket (when drop_last=True)
- Uses dynamic batch sizes based on resolution (when dynamic_batch_size=True)
Initialization
- Parameters:
dataset – BaseMultiresolutionDataset (or any subclass)
base_batch_size – Batch size (fixed if dynamic_batch_size=False, or base for scaling if dynamic_batch_size=True)
base_resolution – Reference resolution for batch size scaling (only used if dynamic_batch_size=True)
drop_last – Drop incomplete batches
shuffle_buckets – Shuffle bucket order
shuffle_within_bucket – Shuffle samples within each bucket
dynamic_batch_size – If True, scale batch size based on resolution. If False (default), use base_batch_size for all buckets.
seed – Random seed for deterministic shuffling (resumable training)
num_replicas – Number of distributed processes (auto-detected if None)
rank – Rank of current process (auto-detected if None)
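To make the iteration contract concrete, here is a minimal, self-contained sketch of the bucket-sequential order the sampler yields. The toy `bucket_sequential_batches` helper and the bucket layout are illustrative stand-ins, not the class's actual implementation:

```python
import random
from typing import Dict, Iterator, List, Tuple


def bucket_sequential_batches(
    buckets: Dict[Tuple[int, int], List[int]],
    batch_size: int,
    seed: int = 42,
    drop_last: bool = True,
) -> Iterator[List[int]]:
    """Illustrative reimplementation of the bucket-sequential iteration order:
    every batch from one bucket is yielded before the next bucket starts, so
    all samples in a batch share the same resolution."""
    rng = random.Random(seed)  # stand-in for the torch.Generator used internally
    order = list(buckets)
    rng.shuffle(order)  # shuffle_buckets=True
    for resolution in order:
        indices = buckets[resolution][:]
        rng.shuffle(indices)  # shuffle_within_bucket=True
        full = len(indices) // batch_size * batch_size
        stop = full if drop_last else len(indices)
        for start in range(0, stop, batch_size):
            yield indices[start:start + batch_size]


# Hypothetical bucket layout: dataset indices grouped by resolution.
buckets = {(512, 512): list(range(0, 10)), (768, 512): list(range(10, 17))}
batches = list(bucket_sequential_batches(buckets, batch_size=4))
```

With `drop_last=True`, the 10-sample bucket yields two full batches and the 7-sample bucket yields one; the two leftover samples in each bucket are dropped.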
- _get_batch_size(resolution: Tuple[int, int]) → int#
Get batch size for resolution (dynamic or fixed based on setting).
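A sketch of the fixed-vs-dynamic choice. The inverse-pixel-count scaling rule below is an assumption for illustration, not the documented formula:

```python
def get_batch_size(resolution, base_batch_size=32, base_resolution=(512, 512),
                   dynamic=False):
    """If dynamic is False, return the fixed base_batch_size. Otherwise scale
    the batch size so the per-batch pixel count stays roughly constant
    relative to base_resolution (assumed rule; floor of 1)."""
    if not dynamic:
        return base_batch_size
    base_pixels = base_resolution[0] * base_resolution[1]
    pixels = resolution[0] * resolution[1]
    return max(1, int(base_batch_size * base_pixels / pixels))
```

Under this rule a 1024×1024 bucket has 4× the pixels of the 512×512 reference, so its batch size drops from 32 to 8.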
- _calculate_total_batches() → int#
Calculate total batches ensuring ALL ranks get the same count. We pad each bucket to be divisible by (num_replicas * batch_size).
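The padding arithmetic can be worked through with a small helper (illustrative, not the method itself):

```python
def padded_batches_per_bucket(bucket_size: int, batch_size: int,
                              num_replicas: int) -> int:
    """Round the bucket size up to a multiple of (num_replicas * batch_size),
    then return the number of batches each rank receives from this bucket.
    Because of the rounding, every rank gets exactly the same count."""
    chunk = num_replicas * batch_size
    padded = ((bucket_size + chunk - 1) // chunk) * chunk  # round up to a multiple
    return padded // chunk  # batches per rank


# e.g. a 1000-sample bucket, batch_size=32, 4 GPUs:
# chunk = 128, padded = 1024 -> 8 batches on every rank
```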
- set_epoch(epoch: int)#
Set the current epoch. Crucial for reproducibility: it reseeds the shuffle so each epoch gets a different but deterministic order.
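The usual pattern behind set_epoch can be sketched with the standard library (`random.Random` stands in for the torch.Generator mentioned above; combining the seed with the epoch is an assumed but conventional detail):

```python
import random


def epoch_shuffle(indices, seed, epoch):
    """Deterministic per-epoch shuffle: reseeding with seed + epoch gives a
    different but fully reproducible order each epoch, identical on every
    rank that calls set_epoch with the same value."""
    rng = random.Random(seed + epoch)
    out = indices[:]
    rng.shuffle(out)
    return out
```

Call set_epoch(epoch) on the sampler before each epoch, exactly as with torch's DistributedSampler, or every epoch will replay the same order.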
- state_dict() → Dict#
Return sampler state for mid-epoch checkpointing.
- load_state_dict(state_dict: Dict) → None#
Restore sampler state; the next iter will skip already-yielded batches.
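The skip-already-yielded behavior can be sketched as follows; the `batches_yielded` key is a hypothetical name for whatever counter the real state_dict() stores:

```python
def resume_iter(batches, state):
    """Mid-epoch resume sketch: skip the batches consumed before the
    checkpoint, as load_state_dict() arranges for the next __iter__()."""
    skip = state.get("batches_yielded", 0)  # hypothetical state key
    for i, batch in enumerate(batches):
        if i < skip:
            continue  # already consumed before checkpointing
        yield batch


all_batches = [[0, 1], [2, 3], [4, 5]]
state = {"batches_yielded": 2}  # hypothetical state_dict() contents
remaining = list(resume_iter(all_batches, state))
```

Because shuffling is seeded deterministically, replaying the same epoch reproduces the same batch order, so skipping by count is sufficient to resume exactly where training stopped.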
- __iter__() → Iterator[List[int]]#
- __len__() → int#
- get_batch_info(batch_idx: int) → Dict#
Get information about a specific batch.
Note: With lazy evaluation, we don’t pre-compute batches, so this returns bucket-level info for the estimated batch.
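Mapping a global batch index back to its bucket can be done from cumulative per-bucket batch counts; the helper below is an illustrative sketch of that estimate, not the method's actual code:

```python
from bisect import bisect_right
from itertools import accumulate


def bucket_for_batch(batch_idx, batches_per_bucket):
    """Estimate which bucket a global batch index falls in by searching the
    cumulative batch counts (bucket-level info only, matching the lazy
    evaluation caveat above)."""
    cumulative = list(accumulate(batches_per_bucket))
    return bisect_right(cumulative, batch_idx)


# Buckets contributing 3, 5, and 2 batches:
# batch indices 0-2 -> bucket 0, 3-7 -> bucket 1, 8-9 -> bucket 2
```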