nemo_automodel.components.datasets.diffusion.collate_fns
Collate functions and dataloader builders for multiresolution diffusion training.
Supports both image and video pipelines by producing batches in the format that FlowMatchingPipeline expects.
Module Contents
Functions
Data
API
Internal helper: creates the sampler and the DataLoader from a dataset and a collate function.
Stack text tensors, padding variable sequence lengths on the first dimension.
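The padding step described above can be sketched without any framework. This is only an illustration of the logic, not the module's implementation: nested lists stand in for tensors, and the name `pad_and_stack` and the zero pad value are assumptions made for the example.

```python
from typing import List

Matrix = List[List[float]]  # one sample: [seq_len][hidden_dim]


def pad_and_stack(samples: List[Matrix], pad_value: float = 0.0) -> List[Matrix]:
    """Pad every sample to the longest seq_len, then return them as one batch.

    Hypothetical helper: assumes every sample has at least one row and that
    all samples share the same hidden_dim.
    """
    if not samples:
        return []
    max_len = max(len(s) for s in samples)
    hidden_dim = len(samples[0][0])
    batch = []
    for s in samples:
        # Rows of pad_value appended along the first (sequence) dimension.
        pad_rows = [[pad_value] * hidden_dim for _ in range(max_len - len(s))]
        # Copy the original rows so the inputs are left untouched.
        batch.append([list(row) for row in s] + pad_rows)
    return batch


short = [[1.0, 2.0]]              # seq_len = 1
long = [[3.0, 4.0], [5.0, 6.0]]   # seq_len = 2
batch = pad_and_stack([short, long])
print(batch)
```

After padding, every entry in `batch` has the same sequence length, which is what makes stacking into a single batch possible.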
Build a text-to-image multiresolution dataloader for TrainDiffusionRecipe.
This wraps the existing TextToImageDataset and SequentialBucketSampler with a text-to-image collate function.
Parameters:
Directory containing preprocessed cache (metadata.json, shards, and resolution subdirs)
If True, returns tokens instead of embeddings
Batch size per GPU
Data parallel rank
Data parallel world size
Base resolution for dynamic batch sizing
Drop incomplete batches
Shuffle data
Scale batch size by resolution
DataLoader workers
Pin memory for GPU transfer
Prefetch batches per worker
Returns: Tuple[StatefulDataLoader, SequentialBucketSampler]
Tuple of (DataLoader, SequentialBucketSampler)
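To make the "base resolution" and "scale batch size by resolution" options more concrete, here is a minimal sketch of one plausible scaling rule. The source does not state the formula the sampler uses, so the inverse-pixel-count rule, the function name `scaled_batch_size`, and its parameters are all assumptions for illustration.

```python
from typing import Tuple


def scaled_batch_size(
    base_batch_size: int,
    base_resolution: Tuple[int, int],
    resolution: Tuple[int, int],
) -> int:
    """Scale the batch size inversely with pixel count, never below 1.

    Assumed rule: keep (batch size * pixels per sample) roughly constant,
    so larger resolutions get smaller batches and vice versa.
    """
    base_pixels = base_resolution[0] * base_resolution[1]
    pixels = resolution[0] * resolution[1]
    return max(1, (base_batch_size * base_pixels) // pixels)


same = scaled_batch_size(8, (512, 512), (512, 512))      # same resolution
larger = scaled_batch_size(8, (512, 512), (1024, 1024))  # 4x the pixels
smaller = scaled_batch_size(8, (512, 512), (256, 256))   # 1/4 the pixels
print(same, larger, smaller)
```

Under this assumed rule, a per-GPU batch size of 8 at a 512x512 base becomes 2 at 1024x1024 and 32 at 256x256, which keeps memory use per batch roughly level across buckets.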
Build a multiresolution video dataloader for TrainDiffusionRecipe.
Uses TextToVideoDataset with SequentialBucketSampler for bucket-based multiresolution video training (e.g. Wan, Hunyuan).
Parameters:
Directory containing preprocessed cache (metadata.json + shards + WxH/*.meta)
Model type ("wan", "hunyuan", etc.)
Device to load tensors to
Batch size per GPU
Data parallel rank
Data parallel world size
Base resolution for dynamic batch sizing
Drop incomplete batches
Shuffle data
Scale batch size by resolution
DataLoader workers
Pin memory for GPU transfer
Prefetch batches per worker
Returns: Tuple[StatefulDataLoader, SequentialBucketSampler]
Tuple of (DataLoader, SequentialBucketSampler)
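Bucket-based multiresolution batching, combined with the data parallel rank and world size parameters, can be pictured roughly as follows. This is a simplified sketch and not the algorithm of SequentialBucketSampler: the function `bucket_batches`, the strided split across ranks, and the absence of shuffling are assumptions chosen to keep the example short.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Resolution = Tuple[int, int]


def bucket_batches(
    resolutions: List[Resolution],
    batch_size: int,
    dp_rank: int,
    dp_world_size: int,
    drop_last: bool = True,
) -> List[List[int]]:
    """Return this rank's batches of sample indices, one resolution per batch."""
    # Group sample indices by resolution so a batch never mixes shapes.
    buckets: Dict[Resolution, List[int]] = defaultdict(list)
    for idx, res in enumerate(resolutions):
        buckets[res].append(idx)

    batches = []
    for res in sorted(buckets):
        # Strided split: each rank gets a disjoint slice of the bucket.
        shard = buckets[res][dp_rank::dp_world_size]
        for start in range(0, len(shard), batch_size):
            batch = shard[start:start + batch_size]
            if len(batch) == batch_size or not drop_last:
                batches.append(batch)
    return batches


resolutions = [(512, 512)] * 5 + [(768, 512)] * 4
rank0 = bucket_batches(resolutions, batch_size=2, dp_rank=0, dp_world_size=2)
rank1 = bucket_batches(resolutions, batch_size=2, dp_rank=1, dp_world_size=2)
print(rank0, rank1)
```

A production sampler also has to guarantee that every rank yields the same number of batches per epoch; this sketch does not enforce that and only happens to satisfy it for the sample input.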
Production collate function with verification.
Text-to-image collate function that transforms the multiresolution batch output into the format that FlowMatchingPipeline expects.
Parameters:
List of samples from TextToImageDataset
Returns: Dict
Dict compatible with FlowMatchingPipeline.step()
Video-compatible collate function for multiresolution video training.
Concatenates video_latents (5D) and text_embeddings (3D) along the batch dim, matching the format expected by FlowMatchingPipeline with SimpleAdapter.
Parameters:
List of samples from TextToVideoDataset
Model type for model-specific field handling
Returns: Dict
Dict compatible with FlowMatchingPipeline.step()
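The difference between concatenating and stacking can be shown with a framework-free sketch. It assumes each sample already carries a leading batch dimension of 1, which is what concatenation along the batch dim implies. Nested lists stand in for tensors, and the helpers `shape`, `concat_batch_dim`, and `make_sample` are hypothetical names used only here.

```python
from typing import Any, Dict, List


def shape(x: Any) -> List[int]:
    """Shape of a rectangular, non-empty nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims


def concat_batch_dim(samples: List[Dict[str, list]], keys: List[str]) -> Dict[str, list]:
    """Concatenate each key along dim 0; no new dimension is added."""
    return {k: [item for s in samples for item in s[k]] for k in keys}


def make_sample(fill: float) -> Dict[str, list]:
    # video_latents: [1, C=1, T=2, H=1, W=2]; text_embeddings: [1, L=3, D=2]
    return {
        "video_latents": [[[[[fill, fill]], [[fill, fill]]]]],
        "text_embeddings": [[[fill, fill] for _ in range(3)]],
    }


batch = concat_batch_dim(
    [make_sample(0.0), make_sample(1.0)],
    ["video_latents", "text_embeddings"],
)
print(shape(batch["video_latents"]), shape(batch["text_embeddings"]))
```

Two samples of shape [1, 1, 2, 1, 2] produce a batch of shape [2, 1, 2, 1, 2], so the latents stay 5D and the text embeddings stay 3D, as the description above requires.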