nemo_automodel.components.datasets.diffusion.sampler#
Module Contents#
Classes#
SequentialBucketSampler – Production-grade sampler with DDP support, deterministic shuffling, and lazy batch generation.
Functions#
collate_fn_production – Production collate function with verification.
build_multiresolution_dataloader – Build production dataloader with sequential bucket iteration and distributed training support.
Data#
API#
- nemo_automodel.components.datasets.diffusion.sampler.logger#
'getLogger(...)'
- class nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler(
- dataset: nemo_automodel.components.datasets.diffusion.text_to_image_dataset.TextToImageDataset,
- base_batch_size: int = 32,
- base_resolution: Tuple[int, int] = (512, 512),
- drop_last: bool = True,
- shuffle_buckets: bool = True,
- shuffle_within_bucket: bool = True,
- dynamic_batch_size: bool = False,
- seed: int = 42,
- num_replicas: Optional[int] = None,
- rank: Optional[int] = None,
- )
Bases:
torch.utils.data.Sampler[typing.List[int]]
Production-grade Sampler that:
- Supports Distributed Data Parallel (DDP): splits data across GPUs
- Deterministic shuffling via torch.Generator (resumable training)
- Lazy batch generation (saves RAM compared to pre-computing all batches)
- Guarantees equal batch counts across all ranks (prevents DDP deadlocks)
- Processes all images in bucket A before moving to bucket B
- Shuffles samples within each bucket (deterministically)
- Drops incomplete batches at the end of each bucket
- Uses dynamic batch sizes based on resolution
Initialization
- Parameters:
dataset – TextToImageDataset
base_batch_size – Batch size (fixed if dynamic_batch_size=False, or base for scaling if dynamic_batch_size=True)
base_resolution – Reference resolution for batch size scaling (only used if dynamic_batch_size=True)
drop_last – Drop incomplete batches
shuffle_buckets – Shuffle bucket order
shuffle_within_bucket – Shuffle samples within each bucket
dynamic_batch_size – If True, scale batch size based on resolution. If False (default), use base_batch_size for all buckets.
seed – Random seed for deterministic shuffling (resumable training)
num_replicas – Number of distributed processes (auto-detected if None)
rank – Rank of current process (auto-detected if None)
- _get_batch_size(resolution: Tuple[int, int]) → int#
Get batch size for resolution (dynamic or fixed based on setting).
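A plausible sketch of the fixed-vs-dynamic logic. The dynamic scaling rule is an assumption (not confirmed by the source): batch size scales inversely with pixel count relative to base_resolution, floored at 1.

```python
from typing import Tuple


def get_batch_size(
    resolution: Tuple[int, int],
    base_batch_size: int = 32,
    base_resolution: Tuple[int, int] = (512, 512),
    dynamic_batch_size: bool = False,
) -> int:
    """Hypothetical re-implementation: fixed batch size, or scaled by pixel count."""
    if not dynamic_batch_size:
        return base_batch_size
    base_pixels = base_resolution[0] * base_resolution[1]
    pixels = resolution[0] * resolution[1]
    # Larger images get proportionally smaller batches, floored at 1.
    return max(1, int(base_batch_size * base_pixels / pixels))
```

With these assumed defaults, a 1024x1024 bucket would get a batch size of 8 (one quarter of the 512x512 base), since it has four times the pixels per image.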
- _calculate_total_batches() → int#
Calculate total batches ensuring ALL ranks get the same count. We pad each bucket to be divisible by (num_replicas * batch_size).
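The padding arithmetic can be illustrated with a small sketch (assumed behavior: with drop_last, each bucket is truncated to a multiple of num_replicas * batch_size; otherwise it is padded up to one):

```python
import math


def batches_per_bucket(
    num_samples: int, batch_size: int, num_replicas: int, drop_last: bool = True
) -> int:
    # Each "step" consumes one batch on every rank simultaneously, so a
    # bucket must contribute a multiple of num_replicas * batch_size samples
    # for all ranks to see the same batch count (no DDP deadlock).
    samples_per_step = num_replicas * batch_size
    if drop_last:
        usable = (num_samples // samples_per_step) * samples_per_step
    else:
        usable = math.ceil(num_samples / samples_per_step) * samples_per_step
    return usable // samples_per_step  # batches seen by each rank
```

For example, a bucket of 1000 samples on 4 ranks with batch size 32 consumes 128 samples per step: 7 full steps with drop_last (104 samples dropped), 8 steps with padding.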
- set_epoch(epoch: int)#
Crucial for reproducibility and different shuffles per epoch.
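The deterministic, epoch-dependent shuffle that set_epoch enables is commonly implemented by reseeding a torch.Generator with seed + epoch; that pattern is an assumption here, not taken from the source:

```python
import torch


def shuffled_indices(n: int, seed: int, epoch: int) -> list:
    # Same (seed, epoch) always yields the same permutation, so a resumed
    # run can call set_epoch(epoch) and replay the exact sample order.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    return torch.randperm(n, generator=g).tolist()
```

Calling the sampler's set_epoch(epoch) once per epoch before iterating gives a fresh but fully reproducible shuffle each epoch.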
- __iter__() → Iterator[List[int]]#
- __len__() → int#
- get_batch_info(batch_idx: int) → Dict#
Get information about a specific batch.
Note: With lazy evaluation, we don't pre-compute batches, so this returns bucket-level info for the estimated batch.
- nemo_automodel.components.datasets.diffusion.sampler.collate_fn_production(batch: List[Dict]) → Dict#
Production collate function with verification.
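The exact verification performed is not documented; below is a minimal sketch of what a verifying collate might do, assuming each sample dict carries a `pixel_values` tensor and a `caption` string (both key names are hypothetical):

```python
from typing import Dict, List

import torch


def collate_with_verification(batch: List[Dict]) -> Dict:
    # Bucketed sampling should guarantee one resolution per batch;
    # verify that before stacking, so shape bugs fail loudly here.
    shapes = {tuple(s["pixel_values"].shape) for s in batch}
    if len(shapes) != 1:
        raise ValueError(f"Mixed resolutions in one batch: {shapes}")
    return {
        "pixel_values": torch.stack([s["pixel_values"] for s in batch]),
        "caption": [s["caption"] for s in batch],
    }
```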
- nemo_automodel.components.datasets.diffusion.sampler.build_multiresolution_dataloader(
- *,
- dataset: nemo_automodel.components.datasets.diffusion.text_to_image_dataset.TextToImageDataset,
- base_batch_size: int,
- dp_rank: int,
- dp_world_size: int,
- base_resolution: Tuple[int, int] = (512, 512),
- drop_last: bool = True,
- shuffle: bool = True,
- dynamic_batch_size: bool = False,
- num_workers: int = 4,
- pin_memory: bool = True,
- prefetch_factor: int = 2,
- )
Build production dataloader with sequential bucket iteration and distributed training support.
- Parameters:
dataset – TextToImageDataset instance
base_batch_size – Batch size (fixed, or base for scaling if dynamic_batch_size=True)
dp_rank – Rank of current process in data parallel group
dp_world_size – Total number of processes in data parallel group
base_resolution – Reference resolution (only used if dynamic_batch_size=True)
drop_last – Drop incomplete batches
shuffle – Shuffle bucket order and samples within buckets each epoch
dynamic_batch_size – If True, scale batch size based on resolution. If False (default), use base_batch_size for all buckets.
num_workers – Number of data loading workers
pin_memory – Pin memory for faster GPU transfer
prefetch_factor – How many batches to prefetch per worker
- Returns:
Tuple of (DataLoader, SequentialBucketSampler) for production training
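Since the sampler yields whole index lists, the builder presumably wires it in via DataLoader's `batch_sampler` argument (with `batch_size` left unset). A toy, self-contained illustration of that pattern:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Stand-in for a TextToImageDataset: returns one small tensor per index."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, i: int) -> torch.Tensor:
        return torch.tensor([i])


# A batch sampler yields lists of indices; DataLoader then fetches and
# collates each list as one batch -- the same way a bucket sampler plugs in.
precomputed_batches = [[0, 1], [2, 3]]
loader = DataLoader(ToyDataset(), batch_sampler=precomputed_batches)
batches = [b for b in loader]
```

In real training you would iterate the returned DataLoader while calling the returned sampler's set_epoch(epoch) at the start of each epoch for a fresh deterministic shuffle.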