nemo_automodel.components.datasets.diffusion.mock_dataloader#

Mock dataloader for automodel WAN training tests.

This module provides a mock dataset and dataloader that generates random tensors with the correct shapes for WAN 2.1 training, allowing functional tests to run without requiring real data.

Module Contents#

Classes#

MockWanDataset

Mock dataset that generates random data matching WAN 2.1 expected format.

Functions#

mock_collate_fn

Collate function for mock dataset, matching the real collate_fn behavior.

build_mock_dataloader

Build a mock dataloader for WAN training tests.

API#

class nemo_automodel.components.datasets.diffusion.mock_dataloader.MockWanDataset(
length: int = 1024,
num_channels: int = 16,
num_frame_latents: int = 16,
spatial_h: int = 30,
spatial_w: int = 52,
text_seq_len: int = 77,
text_embed_dim: int = 4096,
device: str = 'cpu',
)#

Bases: torch.utils.data.Dataset

Mock dataset that generates random data matching WAN 2.1 expected format.

Parameters:
  • length – Number of samples in the dataset.

  • num_channels – Number of latent channels (default: 16 for WAN).

  • num_frame_latents – Number of temporal latent frames.

  • spatial_h – Height of spatial latents.

  • spatial_w – Width of spatial latents.

  • text_seq_len – Length of text sequence.

  • text_embed_dim – Dimension of text embeddings (default: 4096 for UMT5).

  • device – Device to place tensors on.

Initialization

__len__() → int#
__getitem__(idx: int) → Dict[str, torch.Tensor]#

Generate a mock sample with random data.

Returns:

Dict containing:

  • text_embeddings: [1, text_seq_len, text_embed_dim]

  • video_latents: [1, num_channels, num_frame_latents, spatial_h, spatial_w]

  • metadata: empty dict

  • file_info: mock file info

Return type:

Dict[str, torch.Tensor]
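The returned sample can be sketched with plain torch calls. This is an illustrative stand-in for MockWanDataset.__getitem__, not the actual implementation; the dict keys and the file_info value are assumptions based on the documented shapes.

```python
import torch


def mock_getitem(num_channels=16, num_frame_latents=16, spatial_h=30,
                 spatial_w=52, text_seq_len=77, text_embed_dim=4096,
                 device="cpu"):
    # Illustrative stand-in for MockWanDataset.__getitem__: random tensors
    # with the documented WAN 2.1 shapes (keys and file_info are assumptions).
    return {
        "text_embeddings": torch.randn(1, text_seq_len, text_embed_dim,
                                       device=device),
        "video_latents": torch.randn(1, num_channels, num_frame_latents,
                                     spatial_h, spatial_w, device=device),
        "metadata": {},
        "file_info": {"source": "mock"},
    }


sample = mock_getitem()
print(sample["text_embeddings"].shape)  # torch.Size([1, 77, 4096])
print(sample["video_latents"].shape)    # torch.Size([1, 16, 16, 30, 52])
```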

nemo_automodel.components.datasets.diffusion.mock_dataloader.mock_collate_fn(batch)#

Collate function for mock dataset, matching the real collate_fn behavior.
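Since each sample carries a leading batch dimension of 1, a plausible collate step concatenates samples along that dimension. The sketch below assumes this stacking strategy and these key names; the real mock_collate_fn may differ.

```python
import torch


def collate_sketch(batch):
    # Plausible sketch of mock_collate_fn: concatenate the per-sample
    # leading dim of 1 into a single batched tensor (an assumption).
    return {
        "text_embeddings": torch.cat([b["text_embeddings"] for b in batch], dim=0),
        "video_latents": torch.cat([b["video_latents"] for b in batch], dim=0),
        "metadata": [b["metadata"] for b in batch],
        "file_info": [b["file_info"] for b in batch],
    }


batch = [
    {"text_embeddings": torch.randn(1, 77, 4096),
     "video_latents": torch.randn(1, 16, 16, 30, 52),
     "metadata": {}, "file_info": {"source": "mock"}}
    for _ in range(2)
]
out = collate_sketch(batch)
print(out["text_embeddings"].shape)  # torch.Size([2, 77, 4096])
print(out["video_latents"].shape)    # torch.Size([2, 16, 16, 30, 52])
```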

nemo_automodel.components.datasets.diffusion.mock_dataloader.build_mock_dataloader(
*,
dp_rank: int = 0,
dp_world_size: int = 1,
batch_size: int = 1,
num_workers: int = 0,
device: str = 'cpu',
length: int = 1024,
num_channels: int = 16,
num_frame_latents: int = 16,
spatial_h: int = 30,
spatial_w: int = 52,
text_seq_len: int = 77,
text_embed_dim: int = 4096,
shuffle: bool = True,
) → Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DistributedSampler]]#

Build a mock dataloader for WAN training tests.

This function follows the same interface as build_dataloader but generates random data instead of loading from .meta files.

Parameters:
  • dp_rank – Data parallel rank.

  • dp_world_size – Data parallel world size.

  • batch_size – Batch size per GPU.

  • num_workers – Number of dataloader workers.

  • device – Device to place tensors on.

  • length – Number of samples in mock dataset.

  • num_channels – Number of latent channels (default: 16).

  • num_frame_latents – Number of temporal latent frames.

  • spatial_h – Height of spatial latents.

  • spatial_w – Width of spatial latents.

  • text_seq_len – Length of text sequence.

  • text_embed_dim – Dimension of text embeddings.

  • shuffle – Whether to shuffle data.

Returns:

Tuple of (DataLoader, DistributedSampler or None).
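A minimal sketch of how such a builder typically wires a DataLoader to an optional DistributedSampler, using a simplified TensorDataset in place of MockWanDataset. The function name and the rule for when the sampler is created are assumptions, not the actual implementation.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def build_sketch(*, dp_rank=0, dp_world_size=1, batch_size=1,
                 num_workers=0, length=8, shuffle=True):
    # Sketch of the build_mock_dataloader wiring (dataset simplified
    # to a TensorDataset; the sampler rule below is an assumption).
    dataset = TensorDataset(torch.randn(length, 4))
    sampler = None
    if dp_world_size > 1:
        # Shard the dataset across data-parallel ranks.
        sampler = DistributedSampler(dataset, num_replicas=dp_world_size,
                                     rank=dp_rank, shuffle=shuffle)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                        shuffle=(shuffle and sampler is None),
                        num_workers=num_workers)
    return loader, sampler


loader, sampler = build_sketch(batch_size=2, length=8)
print(len(loader))      # 4
print(sampler is None)  # True (single data-parallel rank)
```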