nemo_automodel.components.datasets.diffusion.collate_fns


Collate functions and dataloader builders for multiresolution diffusion training.

Supports both image and video pipelines by producing batches in the format that FlowMatchingPipeline expects.

Module Contents

Functions

Name | Description
--- | ---
_build_multiresolution_dataloader_core | Internal helper: create sampler + DataLoader from dataset and collate fn.
_stack_or_pad_text_tensors | Stack text tensors, padding variable sequence lengths on the first dimension.
build_text_to_image_multiresolution_dataloader | Build a text-to-image multiresolution dataloader for TrainDiffusionRecipe.
build_video_multiresolution_dataloader | Build a multiresolution video dataloader for TrainDiffusionRecipe.
collate_fn_production | Production collate function with verification.
collate_fn_text_to_image | Text-to-image collate function that transforms multiresolution batch output to match the FlowMatchingPipeline expected format.
collate_fn_video | Video-compatible collate function for multiresolution video training.

Data

logger

API

nemo_automodel.components.datasets.diffusion.collate_fns._build_multiresolution_dataloader_core(
dataset,
collate_fn: typing.Callable,
batch_size: int,
dp_rank: int,
dp_world_size: int,
base_resolution: typing.Tuple[int, int] = (512, 512),
drop_last: bool = True,
shuffle: bool = True,
dynamic_batch_size: bool = False,
num_workers: int = 4,
pin_memory: bool = True,
prefetch_factor: int = 2
) -> typing.Tuple[torchdata.stateful_dataloader.StatefulDataLoader, nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler]

Internal helper that builds a SequentialBucketSampler and a StatefulDataLoader from a dataset and a collate function.
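The dp_rank, dp_world_size, drop_last, and shuffle arguments suggest that batches are formed within resolution buckets and then split across data-parallel ranks. The pure-Python sketch below illustrates one plausible way to do that. The function name bucket_batches_for_rank, the seed parameter, and the round-robin split are assumptions for illustration, not the actual logic of SequentialBucketSampler.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def bucket_batches_for_rank(
    resolutions: List[Tuple[int, int]],
    batch_size: int,
    dp_rank: int,
    dp_world_size: int,
    drop_last: bool = True,
    shuffle: bool = True,
    seed: int = 0,
) -> List[List[int]]:
    """Hypothetical sketch, not SequentialBucketSampler's actual code."""
    # Group dataset indices by their (width, height) resolution bucket.
    buckets: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for idx, res in enumerate(resolutions):
        buckets[res].append(idx)

    # Every rank uses the same seed, so all ranks see the same global batch order.
    rng = random.Random(seed)
    batches: List[List[int]] = []
    for res in sorted(buckets):
        indices = buckets[res]
        if shuffle:
            rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            if drop_last and len(batch) < batch_size:
                continue  # skip this bucket's incomplete tail
            batches.append(batch)

    if shuffle:
        rng.shuffle(batches)  # mix buckets; each batch stays single-resolution
    # Truncate so every rank gets the same number of batches, then deal them out.
    usable = len(batches) - len(batches) % dp_world_size
    return batches[:usable][dp_rank::dp_world_size]
```

Because each batch is drawn from a single bucket, a collate function can stack the batch's tensors without resizing them.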

nemo_automodel.components.datasets.diffusion.collate_fns._stack_or_pad_text_tensors(
tensors: typing.List[torch.Tensor],
sequence_length_multiple: int = 1
) -> torch.Tensor

Stack text tensors, padding variable sequence lengths on the first dimension.
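A minimal torch sketch of stacking with padding along the first (sequence) dimension follows. Zero padding, and rounding the target length up to a multiple of sequence_length_multiple, are assumptions inferred from the parameter name. stack_or_pad_sketch is a hypothetical name, not the module's implementation.

```python
from typing import List

import torch
import torch.nn.functional as F


def stack_or_pad_sketch(tensors: List[torch.Tensor], sequence_length_multiple: int = 1) -> torch.Tensor:
    """Hypothetical re-implementation: zero-pad dim 0 to a shared length, then stack."""
    max_len = max(t.shape[0] for t in tensors)
    # Assumption: round the target length up to a multiple of sequence_length_multiple.
    target = -(-max_len // sequence_length_multiple) * sequence_length_multiple
    if all(t.shape[0] == target for t in tensors):
        return torch.stack(tensors)  # fast path: nothing to pad
    padded = []
    for t in tensors:
        pad_rows = target - t.shape[0]
        # F.pad lists trailing dims first, so dim 0 is the last (left, right) pair.
        pad_spec = [0, 0] * (t.dim() - 1) + [0, pad_rows]
        padded.append(F.pad(t, pad_spec))  # constant zero padding at the end of dim 0
    return torch.stack(padded)
```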

nemo_automodel.components.datasets.diffusion.collate_fns.build_text_to_image_multiresolution_dataloader(
cache_dir: str,
train_text_encoder: bool = False,
batch_size: int = 1,
dp_rank: int = 0,
dp_world_size: int = 1,
base_resolution: typing.Tuple[int, int] = (256, 256),
drop_last: bool = True,
shuffle: bool = True,
dynamic_batch_size: bool = False,
num_workers: int = 4,
pin_memory: bool = True,
prefetch_factor: int = 2
) -> typing.Tuple[torchdata.stateful_dataloader.StatefulDataLoader, nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler]

Build a text-to-image multiresolution dataloader for TrainDiffusionRecipe.

Wraps TextToImageDataset and SequentialBucketSampler and pairs them with a text-to-image collate function.

Parameters:

cache_dir
str

Directory containing the preprocessed cache (metadata.json, shards, and per-resolution subdirectories)

train_text_encoder
bool, defaults to False

If True, return text tokens instead of precomputed text embeddings

batch_size
int, defaults to 1

Batch size per GPU

dp_rank
int, defaults to 0

Data-parallel rank of this process

dp_world_size
int, defaults to 1

Data-parallel world size (total number of data-parallel ranks)

base_resolution
Tuple[int, int], defaults to (256, 256)

Reference resolution against which the batch size is scaled when dynamic_batch_size is True

drop_last
bool, defaults to True

Drop incomplete batches

shuffle
bool, defaults to True

Shuffle data

dynamic_batch_size
bool, defaults to False

If True, scale the batch size for each resolution bucket relative to base_resolution

num_workers
int, defaults to 4

Number of DataLoader worker processes

pin_memory
bool, defaults to True

Use pinned host memory to speed up transfers to the GPU

prefetch_factor
int, defaults to 2

Number of batches prefetched by each worker

Returns: Tuple[StatefulDataLoader, SequentialBucketSampler]

Tuple of (DataLoader, SequentialBucketSampler)
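The parameters above describe dynamic_batch_size only as scaling the batch size by resolution, with base_resolution as the reference point. The sketch below assumes the batch size shrinks in inverse proportion to a bucket's pixel count, so that pixels per batch stay roughly constant. Both the formula and the name dynamic_batch_size_sketch are assumptions, not the library's rule.

```python
from typing import Tuple


def dynamic_batch_size_sketch(
    batch_size: int,
    base_resolution: Tuple[int, int],
    bucket_resolution: Tuple[int, int],
) -> int:
    """Hypothetical rule: keep pixels per batch roughly constant across buckets."""
    base_pixels = base_resolution[0] * base_resolution[1]
    bucket_pixels = bucket_resolution[0] * bucket_resolution[1]
    # Larger buckets get proportionally smaller batches, never fewer than one sample.
    return max(1, (batch_size * base_pixels) // bucket_pixels)
```

Under this assumed rule, with batch_size=8 and base_resolution=(256, 256), a 512x512 bucket would get batches of 2 and a 128x128 bucket batches of 32.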

nemo_automodel.components.datasets.diffusion.collate_fns.build_video_multiresolution_dataloader(
cache_dir: str,
model_type: str = 'wan',
device: str = 'cpu',
batch_size: int = 1,
dp_rank: int = 0,
dp_world_size: int = 1,
base_resolution: typing.Tuple[int, int] = (512, 512),
drop_last: bool = True,
shuffle: bool = True,
dynamic_batch_size: bool = False,
num_workers: int = 2,
pin_memory: bool = True,
prefetch_factor: int = 2
) -> typing.Tuple[torchdata.stateful_dataloader.StatefulDataLoader, nemo_automodel.components.datasets.diffusion.sampler.SequentialBucketSampler]

Build a multiresolution video dataloader for TrainDiffusionRecipe.

Uses TextToVideoDataset with SequentialBucketSampler for bucket-based multiresolution video training of models such as Wan and Hunyuan.

Parameters:

cache_dir
str

Directory containing the preprocessed cache (metadata.json, shards, and WxH/*.meta files)

model_type
str, defaults to 'wan'

Model type ('wan', 'hunyuan', etc.)

device
str, defaults to 'cpu'

Device onto which the dataset loads tensors

batch_size
int, defaults to 1

Batch size per GPU

dp_rank
int, defaults to 0

Data-parallel rank of this process

dp_world_size
int, defaults to 1

Data-parallel world size (total number of data-parallel ranks)

base_resolution
Tuple[int, int], defaults to (512, 512)

Reference resolution against which the batch size is scaled when dynamic_batch_size is True

drop_last
bool, defaults to True

Drop incomplete batches

shuffle
bool, defaults to True

Shuffle data

dynamic_batch_size
bool, defaults to False

If True, scale the batch size for each resolution bucket relative to base_resolution

num_workers
int, defaults to 2

Number of DataLoader worker processes

pin_memory
bool, defaults to True

Use pinned host memory to speed up transfers to the GPU

prefetch_factor
int, defaults to 2

Number of batches prefetched by each worker

Returns: Tuple[StatefulDataLoader, SequentialBucketSampler]

Tuple of (DataLoader, SequentialBucketSampler)
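The video builder expects cache_dir to hold metadata.json, shards, and WxH/*.meta files. To illustrate that layout, the hypothetical helper below counts .meta files per resolution directory. It assumes directory names of the form 512x512 (width x height) and is not part of nemo_automodel.

```python
import re
from pathlib import Path
from typing import Dict, Tuple

# Assumption: bucket directories are named "<width>x<height>", e.g. "832x480".
_BUCKET_DIR = re.compile(r"^(\d+)x(\d+)$")


def discover_buckets(cache_dir: str) -> Dict[Tuple[int, int], int]:
    """Hypothetical helper: count *.meta files in each WxH bucket directory."""
    root = Path(cache_dir)
    if not (root / "metadata.json").is_file():
        raise FileNotFoundError(f"missing metadata.json in {root}")
    buckets: Dict[Tuple[int, int], int] = {}
    for child in sorted(root.iterdir()):
        match = _BUCKET_DIR.match(child.name)
        if child.is_dir() and match:  # ignore metadata.json, shards, and anything else
            width, height = int(match.group(1)), int(match.group(2))
            buckets[(width, height)] = len(list(child.glob("*.meta")))
    return buckets
```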

nemo_automodel.components.datasets.diffusion.collate_fns.collate_fn_production(
batch: typing.List[typing.Dict]
) -> typing.Dict

Production collate function with verification.

nemo_automodel.components.datasets.diffusion.collate_fns.collate_fn_text_to_image(
batch: typing.List[typing.Dict]
) -> typing.Dict

Text-to-image collate function that transforms the multiresolution batch output into the format expected by FlowMatchingPipeline.

Parameters:

batch
List[Dict]

List of samples from TextToImageDataset

Returns: Dict

Dict compatible with FlowMatchingPipeline.step()
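A text-to-image collate function turns a list of per-sample dicts into one batched dict. The sketch below shows the general shape of that transform with torch.stack. The keys latent, text_embedding, latents, and text_embeddings are hypothetical placeholders, since the exact fields FlowMatchingPipeline.step() consumes are not documented here. The shape check reflects the assumption that the bucket sampler only groups samples of one resolution.

```python
from typing import Dict, List

import torch


def collate_text_to_image_sketch(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Hypothetical sketch; the key names are illustrative, not the library's."""
    latents = [sample["latent"] for sample in batch]
    # A bucket sampler should yield single-resolution batches; fail loudly otherwise.
    shapes = {tuple(t.shape) for t in latents}
    if len(shapes) != 1:
        raise ValueError(f"mixed latent shapes in one batch: {sorted(shapes)}")
    return {
        "latents": torch.stack(latents),  # [B, C, H, W]
        # Assumes equal-length text embeddings; variable lengths would need padding.
        "text_embeddings": torch.stack([s["text_embedding"] for s in batch]),  # [B, L, D]
    }
```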

nemo_automodel.components.datasets.diffusion.collate_fns.collate_fn_video(
batch: typing.List[typing.Dict],
model_type: str = 'wan'
) -> typing.Dict

Video-compatible collate function for multiresolution video training.

Concatenates video_latents (5D) and text_embeddings (3D) along the batch dimension, matching the format that FlowMatchingPipeline expects when used with SimpleAdapter.

Parameters:

batch
List[Dict]

List of samples from TextToVideoDataset

model_type
str, defaults to 'wan'

Model type for model-specific field handling

Returns: Dict

Dict compatible with FlowMatchingPipeline.step()
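The batch-dimension concatenation described for collate_fn_video can be sketched with torch.cat. The sketch assumes each sample already carries a leading batch dimension of size 1 (video_latents as [1, C, T, H, W], text_embeddings as [1, L, D]) and omits the model-specific handling that model_type controls. collate_video_sketch is a hypothetical name.

```python
from typing import Dict, List

import torch


def collate_video_sketch(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Sketch: concatenate per-sample tensors along the batch dimension (dim 0)."""
    return {
        # Assumed per-sample shape [1, C, T, H, W] -> batched [B, C, T, H, W]
        "video_latents": torch.cat([s["video_latents"] for s in batch], dim=0),
        # Assumed per-sample shape [1, L, D] -> batched [B, L, D]
        "text_embeddings": torch.cat([s["text_embeddings"] for s in batch], dim=0),
    }
```

torch.cat raises a RuntimeError if any non-batch dimension differs, so every sample in a batch must share the same latent shape (one resolution bucket and frame count) and the same text sequence length.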

nemo_automodel.components.datasets.diffusion.collate_fns.logger = logging.getLogger(__name__)