nemo_automodel.components.datasets.diffusion.collate_fns#

Flux-compatible collate function that wraps the multiresolution dataloader output to match the batch format expected by FlowMatchingPipeline.

Module Contents#

Functions#

collate_fn_flux

Flux-compatible collate function that transforms multiresolution batch output into the format expected by FlowMatchingPipeline.

build_flux_multiresolution_dataloader

Build a Flux-compatible multiresolution dataloader for TrainDiffusionRecipe.

Data#

logger

API#

nemo_automodel.components.datasets.diffusion.collate_fns.logger#

‘getLogger(…)’

nemo_automodel.components.datasets.diffusion.collate_fns.collate_fn_flux(batch: List[Dict]) → Dict#

Flux-compatible collate function that transforms multiresolution batch output into the format expected by FlowMatchingPipeline.

Parameters:

batch – List of samples from TextToImageDataset

Returns:

Dict compatible with FlowMatchingPipeline.step()
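
The List[Dict] to Dict shape of a collate function can be illustrated without the library. The following is a minimal sketch, not the implementation of collate_fn_flux: the function name toy_collate, the keys "latent" and "caption", and the use of plain lists in place of tensors are all assumptions made for illustration. The keys that collate_fn_flux actually produces are not listed on this page.

```python
from typing import Any, Dict, List


def toy_collate(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Gather each per-sample field into one batch-level list.

    Hypothetical stand-in for a collate function; a real one would
    typically stack tensors rather than build Python lists.
    """
    if not batch:
        raise ValueError("batch must contain at least one sample")
    keys = batch[0].keys()
    for sample in batch:
        # Every sample must expose the same fields to be batched together.
        if sample.keys() != keys:
            raise KeyError("all samples must share the same keys")
    return {key: [sample[key] for sample in batch] for key in keys}


# Hypothetical samples; the real dataset fields may differ.
samples = [
    {"latent": [0.1, 0.2], "caption": "a red fox"},
    {"latent": [0.3, 0.4], "caption": "a blue bird"},
]
collated = toy_collate(samples)
```

The input is one dict per sample and the output is one dict per batch, which is the same contract that the signature of collate_fn_flux describes.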

nemo_automodel.components.datasets.diffusion.collate_fns.build_flux_multiresolution_dataloader(
*,
cache_dir: str,
train_text_encoder: bool = False,
batch_size: int = 1,
dp_rank: int = 0,
dp_world_size: int = 1,
base_resolution: Tuple[int, int] = (256, 256),
drop_last: bool = True,
shuffle: bool = True,
dynamic_batch_size: bool = False,
num_workers: int = 4,
pin_memory: bool = True,
prefetch_factor: int = 2,
) → Tuple[torch.utils.data.DataLoader, nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler]#

Build a Flux-compatible multiresolution dataloader for TrainDiffusionRecipe.

This wraps the existing TextToImageDataset and SequentialBucketSampler with a Flux-compatible collate function.

Parameters:
  • cache_dir – Directory containing preprocessed cache (metadata.json, shards, and resolution subdirs)

  • train_text_encoder – If True, returns tokens instead of embeddings

  • batch_size – Batch size per GPU

  • dp_rank – Data parallel rank

  • dp_world_size – Data parallel world size

  • base_resolution – Base resolution for dynamic batch sizing

  • drop_last – Drop incomplete batches

  • shuffle – Shuffle data

  • dynamic_batch_size – Scale batch size by resolution

  • num_workers – DataLoader workers

  • pin_memory – Pin memory for GPU transfer

  • prefetch_factor – Prefetch batches per worker

Returns:

Tuple of (DataLoader, SequentialBucketSampler)
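
A call might look like the sketch below. It uses only the parameter names and import path shown in the signature above. The helper names loader_kwargs and make_flux_loader, the chosen values, and the rank range check are assumptions for illustration rather than library behavior. The import is deferred so the helper that assembles the arguments can be exercised without the package installed.

```python
from typing import Any, Dict


def loader_kwargs(cache_dir: str, dp_rank: int, dp_world_size: int) -> Dict[str, Any]:
    """Collect keyword arguments for one data-parallel rank (hypothetical helper)."""
    # Sanity check added for this sketch; not a documented library check.
    if not 0 <= dp_rank < dp_world_size:
        raise ValueError("dp_rank must be in [0, dp_world_size)")
    return {
        "cache_dir": cache_dir,
        "batch_size": 2,
        "dp_rank": dp_rank,
        "dp_world_size": dp_world_size,
        "base_resolution": (256, 256),
        "dynamic_batch_size": True,
        "num_workers": 4,
    }


def make_flux_loader(cache_dir: str, dp_rank: int = 0, dp_world_size: int = 1):
    """Build the dataloader and sampler for one rank (hypothetical wrapper)."""
    # Imported lazily so the sketch can be read without the package installed.
    from nemo_automodel.components.datasets.diffusion.collate_fns import (
        build_flux_multiresolution_dataloader,
    )

    # Every parameter is keyword-only (note the bare * in the signature).
    dataloader, sampler = build_flux_multiresolution_dataloader(
        **loader_kwargs(cache_dir, dp_rank, dp_world_size)
    )
    return dataloader, sampler
```

The function returns a tuple, so the sampler can be unpacked alongside the dataloader and kept by the caller.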