nemo_automodel.components.datasets.diffusion.mock_dataloader

Mock dataloader for automodel WAN training tests.

This module provides a mock dataset and dataloader that generate random tensors with the correct shapes for WAN 2.1 training, so functional tests can run without real data.

Module Contents

Classes

| Name | Description |
| --- | --- |
| MockWanDataset | Mock dataset that generates random data matching WAN 2.1 expected format. |

Functions

| Name | Description |
| --- | --- |
| build_mock_dataloader | Build a mock dataloader for WAN training tests. |
| mock_collate_fn | Collate function for mock dataset, matching the real collate_fn behavior. |

API

class nemo_automodel.components.datasets.diffusion.mock_dataloader.MockWanDataset(
length: int = 1024,
num_channels: int = 16,
num_frame_latents: int = 16,
spatial_h: int = 30,
spatial_w: int = 52,
text_seq_len: int = 77,
text_embed_dim: int = 4096,
device: str = 'cpu'
)

Bases: Dataset

Mock dataset that generates random data matching WAN 2.1 expected format.

Parameters:

length
int, defaults to 1024

Number of samples in the dataset.

num_channels
int, defaults to 16

Number of latent channels (default: 16 for WAN).

num_frame_latents
int, defaults to 16

Number of temporal latent frames.

spatial_h
int, defaults to 30

Height of spatial latents.

spatial_w
int, defaults to 52

Width of spatial latents.

text_seq_len
int, defaults to 77

Length of text sequence.

text_embed_dim
int, defaults to 4096

Dimension of text embeddings (default: 4096 for UMT5).

device
str, defaults to 'cpu'

Device to place tensors on.

length = max(int(length), 1)

The stored length is the constructor argument coerced to int and clamped to a minimum of 1.

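
The clamping expression above can be illustrated in isolation. This is a minimal sketch that only mirrors the documented expression; `clamp_length` is a hypothetical helper name and is not part of the module.

```python
def clamp_length(length):
    """Mirror the documented expression max(int(length), 1).

    Hypothetical helper for illustration only; not part of mock_dataloader.
    """
    return max(int(length), 1)


# Zero and negative requests still yield a dataset with one sample,
# and non-integer inputs are truncated by int() before clamping.
for requested in (1024, 0, -5, 3.9):
    print(requested, "->", clamp_length(requested))
```

One likely motivation is that a dataset with zero samples would make a test loop exit without running a single step, though the source does not state the reason.
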
nemo_automodel.components.datasets.diffusion.mock_dataloader.MockWanDataset.__getitem__(
idx: int
) -> typing.Dict[str, torch.Tensor]

Generate a mock sample with random data.

Returns: Dict[str, torch.Tensor]

Dict containing:

  • text_embeddings: [1, text_seq_len, text_embed_dim]
  • video_latents: [1, num_channels, num_frame_latents, spatial_h, spatial_w]
  • metadata: empty dict
  • file_info: mock file info
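
To make the shapes listed above concrete, the following sketch computes them from the constructor defaults using plain Python tuples, without requiring torch. `mock_sample_shapes` is a hypothetical helper, and the float32 size estimate is an assumption, since the source does not state the tensor dtype.

```python
from math import prod


def mock_sample_shapes(
    num_channels=16,
    num_frame_latents=16,
    spatial_h=30,
    spatial_w=52,
    text_seq_len=77,
    text_embed_dim=4096,
):
    """Return the documented tensor shapes for one mock sample.

    Hypothetical helper; it only restates the shapes from the docstring.
    """
    return {
        "text_embeddings": (1, text_seq_len, text_embed_dim),
        "video_latents": (1, num_channels, num_frame_latents, spatial_h, spatial_w),
    }


shapes = mock_sample_shapes()
for name, shape in shapes.items():
    numel = prod(shape)
    # Assumption: 4 bytes per element (float32); the dtype is not documented.
    mib = numel * 4 / 2**20
    print(f"{name}: shape={shape}, elements={numel}, approx {mib:.2f} MiB")
```

A check like this can help when picking smaller values for quick tests, since the latent tensor grows with the product of all four latent dimensions.
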

nemo_automodel.components.datasets.diffusion.mock_dataloader.MockWanDataset.__len__() -> int

nemo_automodel.components.datasets.diffusion.mock_dataloader.build_mock_dataloader(
dp_rank: int = 0,
dp_world_size: int = 1,
batch_size: int = 1,
num_workers: int = 0,
device: str = 'cpu',
length: int = 1024,
num_channels: int = 16,
num_frame_latents: int = 16,
spatial_h: int = 30,
spatial_w: int = 52,
text_seq_len: int = 77,
text_embed_dim: int = 4096,
shuffle: bool = True
) -> typing.Tuple[torch.utils.data.DataLoader, typing.Optional[torch.utils.data.DistributedSampler]]

Build a mock dataloader for WAN training tests.

This function follows the same interface as build_dataloader but generates random data instead of loading from .meta files.

Parameters:

dp_rank
int, defaults to 0

Data parallel rank.

dp_world_size
int, defaults to 1

Data parallel world size.

batch_size
int, defaults to 1

Batch size per GPU.

num_workers
int, defaults to 0

Number of dataloader workers.

device
str, defaults to 'cpu'

Device to place tensors on.

length
int, defaults to 1024

Number of samples in mock dataset.

num_channels
int, defaults to 16

Number of latent channels (default: 16).

num_frame_latents
int, defaults to 16

Number of temporal latent frames.

spatial_h
int, defaults to 30

Height of spatial latents.

spatial_w
int, defaults to 52

Width of spatial latents.

text_seq_len
int, defaults to 77

Length of text sequence.

text_embed_dim
int, defaults to 4096

Dimension of text embeddings.

shuffle
bool, defaults to True

Whether to shuffle data.

Returns: Tuple[DataLoader, Optional[DistributedSampler]]

Tuple of (DataLoader, DistributedSampler or None).
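
Because the function takes dp_rank and dp_world_size and may return a DistributedSampler, it can help to see how such a sampler splits dataset indices across ranks. The sketch below reproduces, in plain Python, the index partitioning that torch's DistributedSampler performs when shuffling is off and drop_last is False. `shard_indices` is a hypothetical helper, and whether build_mock_dataloader configures its sampler exactly this way is an assumption.

```python
import math


def shard_indices(length, dp_rank, dp_world_size):
    """Indices one rank would see, without shuffling and without dropping samples.

    Hypothetical helper that mimics DistributedSampler partitioning:
    pad the index list so it divides evenly, then stride by world size.
    """
    num_samples = math.ceil(length / dp_world_size)
    total_size = num_samples * dp_world_size
    indices = list(range(length))
    padding = total_size - len(indices)
    # Pad by repeating indices from the start so every rank gets num_samples items.
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    return indices[dp_rank:total_size:dp_world_size]


# Ten samples over four ranks: each rank receives three indices,
# and the last two ranks reuse indices 0 and 1 as padding.
for rank in range(4):
    print(rank, shard_indices(10, rank, 4))
```

With shuffle enabled, the index list is permuted before this striding step, so each rank still receives the same number of samples.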

nemo_automodel.components.datasets.diffusion.mock_dataloader.mock_collate_fn(
batch
)

Collate function for mock dataset, matching the real collate_fn behavior.
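
The source does not describe what the collate step does. Since each sample already carries a leading dimension of 1, one plausible reading is that tensors are concatenated along that dimension. The sketch below works on shape tuples only and shows the batch shapes that would result under that assumption. `collated_shapes` is a hypothetical helper, not the module's implementation.

```python
def collated_shapes(sample_shapes):
    """Shapes after concatenating per-sample tensors along dimension 0.

    Assumption: collation concatenates along the leading dimension.
    Hypothetical helper operating on shape tuples, not real tensors.
    """
    out = {}
    for sample in sample_shapes:
        for key, shape in sample.items():
            if key not in out:
                out[key] = shape
                continue
            # All non-leading dimensions must agree for concatenation to work.
            if out[key][1:] != shape[1:]:
                raise ValueError(f"mismatched trailing dims for {key!r}")
            out[key] = (out[key][0] + shape[0],) + shape[1:]
    return out


sample = {
    "text_embeddings": (1, 77, 4096),
    "video_latents": (1, 16, 16, 30, 52),
}
batch = collated_shapes([sample] * 4)
print(batch)
```

Under this reading, a batch_size of 4 with default arguments would produce text embeddings of shape (4, 77, 4096) and video latents of shape (4, 16, 16, 30, 52).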