nemo_automodel.components.datasets.diffusion.collate_fns
Flux-compatible collate function that wraps the multiresolution dataloader output to match the batch format expected by FlowMatchingPipeline.
Module Contents
Functions
| collate_fn_flux | Flux-compatible collate function that transforms multiresolution batch output to match the batch format expected by FlowMatchingPipeline. |
| build_flux_multiresolution_dataloader | Build a Flux-compatible multiresolution dataloader for TrainDiffusionRecipe. |
Data
API
- nemo_automodel.components.datasets.diffusion.collate_fns.logger
'getLogger(…)'
- nemo_automodel.components.datasets.diffusion.collate_fns.collate_fn_flux(batch: List[Dict]) → Dict
Flux-compatible collate function that transforms the multiresolution batch output to match the batch format expected by FlowMatchingPipeline.
- Parameters:
batch – List of samples from TextToImageDataset
- Returns:
Dict compatible with FlowMatchingPipeline.step()
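To make the idea concrete, here is a minimal pure-Python sketch of what a bucket-aware collate function does. It is not the library's implementation: the sample keys (latent, text_embedding, resolution) and the output keys (image_latents, text_embeddings) are hypothetical, and plain lists stand in for the tensors a real collate function would combine with torch.stack.

```python
from typing import Any, Dict, List


def collate_flux_sketch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical collate: regroup per-sample dicts into one batch dict.

    Assumes every sample carries the hypothetical keys "latent",
    "text_embedding" and "resolution", and that the sampler only groups
    samples that share a resolution.
    """
    if not batch:
        raise ValueError("batch must contain at least one sample")

    resolutions = {tuple(sample["resolution"]) for sample in batch}
    if len(resolutions) != 1:
        # Samples of different sizes could not be stacked into one tensor.
        raise ValueError(f"expected one resolution bucket, got {sorted(resolutions)}")

    # Rename keys to the (hypothetical) names a training step expects and
    # gather each field into a list; real code would torch.stack tensors here.
    return {
        "image_latents": [sample["latent"] for sample in batch],
        "text_embeddings": [sample["text_embedding"] for sample in batch],
        "resolution": resolutions.pop(),
    }
```

The resolution check is why the collate function pairs naturally with a bucketed sampler: stacking only works when every sample in the batch has the same shape.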
- nemo_automodel.components.datasets.diffusion.collate_fns.build_flux_multiresolution_dataloader(
    *,
    cache_dir: str,
    train_text_encoder: bool = False,
    batch_size: int = 1,
    dp_rank: int = 0,
    dp_world_size: int = 1,
    base_resolution: Tuple[int, int] = (256, 256),
    drop_last: bool = True,
    shuffle: bool = True,
    dynamic_batch_size: bool = False,
    num_workers: int = 4,
    pin_memory: bool = True,
    prefetch_factor: int = 2,
)
Build a Flux-compatible multiresolution dataloader for TrainDiffusionRecipe.
This wraps the existing TextToImageDataset and SequentialBucketSampler with a Flux-compatible collate function.
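The sampling side can be sketched in the same hedged way. The function below is a hypothetical stand-in for SequentialBucketSampler, not its actual code: it assumes each sample's resolution is known up front, groups indices into per-resolution buckets so that every batch holds a single resolution, and gives each data-parallel rank a disjoint, equally sized share of batches by striding with dp_rank and dp_world_size.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Resolution = Tuple[int, int]


def bucket_batches_for_rank(
    resolutions: Sequence[Resolution],
    batch_size: int,
    dp_rank: int,
    dp_world_size: int,
    shuffle: bool = True,
    drop_last: bool = True,
    seed: int = 0,
) -> List[List[int]]:
    """Hypothetical sketch: bucket sample indices by resolution, batch each
    bucket, then hand every data-parallel rank a disjoint share of batches."""
    # Group sample indices by their resolution bucket.
    buckets: Dict[Resolution, List[int]] = defaultdict(list)
    for idx, res in enumerate(resolutions):
        buckets[res].append(idx)

    # The same seed on every rank yields the same global batch order.
    rng = random.Random(seed)
    batches: List[List[int]] = []
    for res in sorted(buckets):
        indices = buckets[res]
        if shuffle:
            rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            chunk = indices[start:start + batch_size]
            if drop_last and len(chunk) < batch_size:
                continue  # skip this bucket's incomplete tail
            batches.append(chunk)

    if shuffle:
        rng.shuffle(batches)  # mix buckets; each batch stays single-resolution

    # Trim so every rank gets the same number of batches, then stride by rank.
    usable = len(batches) - len(batches) % dp_world_size
    return batches[dp_rank:usable:dp_world_size]
```

Seeding every rank identically is the design choice that makes the strided split safe: all ranks compute the same global batch list, so slicing it by rank gives non-overlapping shards without any communication.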
- Parameters:
cache_dir – Directory containing the preprocessed cache (metadata.json, shards, and per-resolution subdirectories)
train_text_encoder – If True, batches contain text tokens instead of text embeddings, so the text encoder can be trained
batch_size – Batch size per GPU
dp_rank – Data-parallel rank of the current process
dp_world_size – Number of data-parallel ranks
base_resolution – Reference resolution for dynamic batch sizing
drop_last – Whether to drop incomplete batches
shuffle – Whether to shuffle the data
dynamic_batch_size – If True, scale the batch size according to each bucket's resolution
num_workers – Number of DataLoader worker processes
pin_memory – Whether to pin host memory for faster GPU transfer
prefetch_factor – Number of batches prefetched per worker
- Returns:
Tuple of (DataLoader, SequentialBucketSampler)
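The exact rule behind dynamic_batch_size is not documented here. One plausible scheme, shown below purely as an assumption (the helper scaled_batch_size is hypothetical), keeps the number of pixels per batch roughly constant relative to base_resolution, so a bucket with four times the pixel count gets a quarter of the batch size.

```python
from typing import Tuple


def scaled_batch_size(
    batch_size: int,
    resolution: Tuple[int, int],
    base_resolution: Tuple[int, int] = (256, 256),
    dynamic_batch_size: bool = False,
) -> int:
    """Hypothetical rule: keep pixels per batch roughly constant."""
    if not dynamic_batch_size:
        return batch_size  # fixed batch size for every bucket
    base_pixels = base_resolution[0] * base_resolution[1]
    pixels = resolution[0] * resolution[1]
    # Larger images get smaller batches; never go below one sample.
    return max(1, (batch_size * base_pixels) // pixels)
```

Under this assumed rule, batch_size=8 with the default base_resolution yields 2 samples per batch for a 512×512 bucket and 32 for a 128×128 bucket, which keeps per-step memory use roughly even across buckets.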