nemo_automodel.components.datasets.diffusion.mock_dataloader
Mock dataloader for automodel WAN training tests.
This module provides a mock dataset and dataloader that generates random tensors with the correct shapes for WAN 2.1 training, allowing functional tests to run without requiring real data.
Module Contents
Classes
Functions
API
Bases: Dataset
Mock dataset that generates random data matching the format expected by WAN 2.1.
Parameters:
- Number of samples in the dataset.
- Number of latent channels (default: 16 for WAN).
- Number of temporal latent frames.
- Height of spatial latents.
- Width of spatial latents.
- Length of text sequence.
- Dimension of text embeddings (default: 4096 for UMT5).
- Device to place tensors on.
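To help size mock inputs, the sketch below derives the per-sample shapes from the sample layout documented for this class and estimates their memory. The helper `mock_sample_footprint`, its parameter names, and the float32 element size are assumptions for illustration; only the shape layout and the defaults of 16 latent channels and a 4096-dimensional UMT5 embedding come from this page.

```python
from math import prod


def mock_sample_footprint(
    num_frame_latents: int,
    spatial_h: int,
    spatial_w: int,
    text_seq_len: int,
    num_channels: int = 16,      # WAN latent channels (default from this page)
    text_embed_dim: int = 4096,  # UMT5 embedding width (default from this page)
    bytes_per_element: int = 4,  # assumption: float32 tensors
) -> dict:
    """Hypothetical helper: per-sample shapes and their total size in bytes."""
    shapes = {
        "text_embeddings": (1, text_seq_len, text_embed_dim),
        "video_latents": (1, num_channels, num_frame_latents, spatial_h, spatial_w),
    }
    total_elements = sum(prod(shape) for shape in shapes.values())
    return {"shapes": shapes, "bytes": total_elements * bytes_per_element}


info = mock_sample_footprint(num_frame_latents=3, spatial_h=8, spatial_w=8, text_seq_len=512)
print(info["shapes"]["video_latents"])  # (1, 16, 3, 8, 8)
print(info["bytes"])                    # 8400896
```

For small latent grids like this one, the text embeddings account for almost all of the memory (2,097,152 of 2,100,224 elements), so text_seq_len is the parameter that most affects the cost of a mock batch.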
Generate a mock sample with random data.
Returns: Dict[str, torch.Tensor]
Dict containing:
- text_embeddings: [1, text_seq_len, text_embed_dim]
- video_latents: [1, num_channels, num_frame_latents, spatial_h, spatial_w]
- metadata: empty dict
- file_info: mock file info
Build a mock dataloader for WAN training tests.
This function follows the same interface as build_dataloader but generates random data instead of loading from .meta files.
Parameters:
- Data parallel rank.
- Data parallel world size.
- Batch size per GPU.
- Number of dataloader workers.
- Device to place tensors on.
- Number of samples in mock dataset.
- Number of latent channels (default: 16).
- Number of temporal latent frames.
- Height of spatial latents.
- Width of spatial latents.
- Length of text sequence.
- Dimension of text embeddings.
- Whether to shuffle data.
Returns: Tuple[DataLoader, Optional[DistributedSampler]]
Tuple of (DataLoader, DistributedSampler or None).
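The data-parallel rank and world size determine which samples each GPU sees. PyTorch's `DistributedSampler` with `shuffle=False` and `drop_last=False` pads the index list by wrapping around until its length divides evenly by the world size, then gives each rank every `world_size`-th index starting at `rank`. The hypothetical helper `shard_indices` below reproduces that partition in plain Python; whether the mock builder enables shuffling, and when it returns `None` instead of a sampler, is not specified on this page.

```python
import math
from typing import List


def shard_indices(dataset_len: int, rank: int, world_size: int) -> List[int]:
    """Mimic DistributedSampler(shuffle=False, drop_last=False) index assignment."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    num_samples = math.ceil(dataset_len / world_size)  # samples per rank
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    padding = total_size - dataset_len
    if padding:
        # Wrap around so every rank receives the same number of samples.
        indices += (indices * math.ceil(padding / dataset_len))[:padding]
    return indices[rank:total_size:world_size]


for r in range(4):
    print(r, shard_indices(10, r, 4))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Because the batch size is per GPU, each step (without gradient accumulation) consumes batch_size × world_size samples across all ranks, and the wrap-around padding means a few samples can appear twice per epoch.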
Collate function for the mock dataset, matching the behavior of the real collate_fn.
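This page does not spell out what the real collate_fn does. Assuming it concatenates the per-sample tensors along their leading size-1 dimension (the equivalent of `torch.cat(..., dim=0)`) and gathers the non-tensor fields into per-sample lists, a list-based sketch looks like this; `collate_mock_sketch` and `zeros` are hypothetical helpers, not part of the module.

```python
from typing import Any, Dict, List


def zeros(shape: tuple) -> Any:
    """Nested list of 0.0 with the given shape (stand-in for torch.zeros)."""
    return [zeros(shape[1:]) for _ in range(shape[0])] if shape else 0.0


def collate_mock_sketch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical collate: turn per-sample leading dims of size 1 into a batch dim."""
    if not batch:
        raise ValueError("empty batch")
    out: Dict[str, Any] = {}
    for key in ("text_embeddings", "video_latents"):
        # Each sample has a leading dimension of 1, so joining the outer lists
        # plays the role of torch.cat(..., dim=0) and yields shape [B, ...].
        out[key] = [row for sample in batch for row in sample[key]]
    # Non-tensor fields are collected per sample rather than merged.
    out["metadata"] = [sample["metadata"] for sample in batch]
    out["file_info"] = [sample["file_info"] for sample in batch]
    return out


samples = [
    {
        "text_embeddings": zeros((1, 3, 4)),
        "video_latents": zeros((1, 16, 1, 2, 2)),
        "metadata": {},
        "file_info": {"idx": i},
    }
    for i in range(2)
]
batch = collate_mock_sketch(samples)
print(len(batch["video_latents"]), len(batch["video_latents"][0]))  # 2 16
print(batch["file_info"])  # [{'idx': 0}, {'idx': 1}]
```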