nemo_automodel.components.datasets.diffusion.mock_dataloader
Mock dataloader for automodel WAN training tests.
This module provides a mock dataset and dataloader that generate random tensors with the shapes WAN 2.1 training expects, so functional tests can run without real data.
Module Contents
Classes
MockWanDataset | Mock dataset that generates random data matching WAN 2.1 expected format.
Functions
mock_collate_fn | Collate function for mock dataset, matching the real collate_fn behavior.
build_mock_dataloader | Build a mock dataloader for WAN training tests.
API
- class nemo_automodel.components.datasets.diffusion.mock_dataloader.MockWanDataset(
    length: int = 1024,
    num_channels: int = 16,
    num_frame_latents: int = 16,
    spatial_h: int = 30,
    spatial_w: int = 52,
    text_seq_len: int = 77,
    text_embed_dim: int = 4096,
    device: str = 'cpu',
)
Bases: torch.utils.data.Dataset
Mock dataset that generates random data matching WAN 2.1 expected format.
- Parameters:
length – Number of samples in the dataset.
num_channels – Number of latent channels (default: 16 for WAN).
num_frame_latents – Number of temporal latent frames.
spatial_h – Height of spatial latents.
spatial_w – Width of spatial latents.
text_seq_len – Length of text sequence.
text_embed_dim – Dimension of text embeddings (default: 4096 for UMT5).
device – Device to place tensors on.
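As a rough sizing aid, the default arguments above imply the per-sample element counts worked out below. This is only arithmetic on the documented shapes. The 4-byte element size is an assumption (float32), since the dtype is not stated here.

```python
import math

# Default MockWanDataset arguments, as documented above.
num_channels, num_frame_latents = 16, 16
spatial_h, spatial_w = 30, 52
text_seq_len, text_embed_dim = 77, 4096

# ASSUMPTION: float32 tensors (4 bytes per element); the dtype is not documented.
BYTES_PER_ELEMENT = 4

# Shapes of the two tensors in one sample, including the leading dim of size 1.
video_shape = (1, num_channels, num_frame_latents, spatial_h, spatial_w)
text_shape = (1, text_seq_len, text_embed_dim)

video_numel = math.prod(video_shape)  # 399360
text_numel = math.prod(text_shape)    # 315392
sample_bytes = (video_numel + text_numel) * BYTES_PER_ELEMENT

print(video_numel, text_numel)
print(f"{sample_bytes / 2**20:.2f} MiB per sample")  # 2.73 MiB under the float32 assumption
```

Smaller values for spatial_h, spatial_w, or text_embed_dim shrink each sample proportionally, which can help keep CPU-only tests fast.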
- __len__() -> int
- __getitem__(idx: int) -> Dict[str, torch.Tensor]
Generate a mock sample with random data.
- Returns:
Dict containing:
text_embeddings: [1, text_seq_len, text_embed_dim]
video_latents: [1, num_channels, num_frame_latents, spatial_h, spatial_w]
metadata: empty dict
file_info: mock file info
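To make the sample layout concrete, here is a minimal stand-in dataset that produces the documented keys and shapes. SketchWanDataset is a hypothetical name and not the library's MockWanDataset. The use of torch.randn and the contents of file_info are assumptions, since the page only says "random data" and "mock file info".

```python
from typing import Any, Dict

import torch
from torch.utils.data import Dataset


class SketchWanDataset(Dataset):
    """Hypothetical stand-in that mirrors the documented sample format."""

    def __init__(self, length=1024, num_channels=16, num_frame_latents=16,
                 spatial_h=30, spatial_w=52, text_seq_len=77,
                 text_embed_dim=4096, device="cpu"):
        self.length = length
        self.latent_shape = (1, num_channels, num_frame_latents, spatial_h, spatial_w)
        self.text_shape = (1, text_seq_len, text_embed_dim)
        self.device = device

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return {
            # ASSUMPTION: standard-normal noise; the real distribution is not documented.
            "text_embeddings": torch.randn(self.text_shape, device=self.device),
            "video_latents": torch.randn(self.latent_shape, device=self.device),
            "metadata": {},
            # ASSUMPTION: placeholder content; the real mock file info fields are not documented.
            "file_info": {"index": idx},
        }


sample = SketchWanDataset(length=4)[0]
print(tuple(sample["text_embeddings"].shape))  # (1, 77, 4096)
print(tuple(sample["video_latents"].shape))    # (1, 16, 16, 30, 52)
```

Each tensor already carries a leading dimension of size 1, so a sample looks like a batch of one before any collation.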
- nemo_automodel.components.datasets.diffusion.mock_dataloader.mock_collate_fn(batch)
Collate function for mock dataset, matching the real collate_fn behavior.
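The page does not spell out what the real collate_fn does. One plausible reading, given that every sample tensor has a leading dimension of size 1, is that samples are concatenated along that dimension and the non-tensor fields are gathered into lists. The sketch below illustrates that assumption only. sketch_collate is a hypothetical name, and the shapes are shrunk to keep the example light.

```python
import torch


def sketch_collate(batch):
    """Hypothetical collate: join samples along their existing leading dim."""
    return {
        # ASSUMPTION: tensors are concatenated on dim 0, not stacked into a new dim.
        "text_embeddings": torch.cat([s["text_embeddings"] for s in batch], dim=0),
        "video_latents": torch.cat([s["video_latents"] for s in batch], dim=0),
        # ASSUMPTION: non-tensor fields are collected into plain lists.
        "metadata": [s["metadata"] for s in batch],
        "file_info": [s["file_info"] for s in batch],
    }


# Three tiny samples in the documented per-sample layout.
samples = [
    {
        "text_embeddings": torch.randn(1, 5, 8),
        "video_latents": torch.randn(1, 16, 2, 4, 4),
        "metadata": {},
        "file_info": {"index": i},
    }
    for i in range(3)
]

batch = sketch_collate(samples)
print(tuple(batch["video_latents"].shape))    # (3, 16, 2, 4, 4)
print(tuple(batch["text_embeddings"].shape))  # (3, 5, 8)
```

Under this assumption the batch dimension replaces the per-sample leading 1, so the model sees [batch_size, ...] tensors without an extra axis.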
- nemo_automodel.components.datasets.diffusion.mock_dataloader.build_mock_dataloader(
    *,
    dp_rank: int = 0,
    dp_world_size: int = 1,
    batch_size: int = 1,
    num_workers: int = 0,
    device: str = 'cpu',
    length: int = 1024,
    num_channels: int = 16,
    num_frame_latents: int = 16,
    spatial_h: int = 30,
    spatial_w: int = 52,
    text_seq_len: int = 77,
    text_embed_dim: int = 4096,
    shuffle: bool = True,
)
Build a mock dataloader for WAN training tests.
This function follows the same interface as build_dataloader but generates random data instead of loading from .meta files.
- Parameters:
dp_rank – Data parallel rank.
dp_world_size – Data parallel world size.
batch_size – Batch size per GPU.
num_workers – Number of dataloader workers.
device – Device to place tensors on.
length – Number of samples in mock dataset.
num_channels – Number of latent channels (default: 16).
num_frame_latents – Number of temporal latent frames.
spatial_h – Height of spatial latents.
spatial_w – Width of spatial latents.
text_seq_len – Length of text sequence.
text_embed_dim – Dimension of text embeddings.
shuffle – Whether to shuffle data.
- Returns:
Tuple of (DataLoader, DistributedSampler or None).
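The (DataLoader, DistributedSampler or None) return value suggests that a sampler is attached only when data is sharded across ranks. The sketch below shows one way dp_rank and dp_world_size could map onto torch's DistributedSampler. sketch_build_loader is a hypothetical helper, and the rule "no sampler when dp_world_size is 1" is an assumption rather than documented behavior.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def sketch_build_loader(dataset, *, dp_rank=0, dp_world_size=1,
                        batch_size=1, num_workers=0, shuffle=True):
    """Hypothetical helper returning (DataLoader, sampler or None)."""
    sampler = None
    if dp_world_size > 1:
        # Passing num_replicas and rank explicitly avoids needing an
        # initialized torch.distributed process group.
        sampler = DistributedSampler(
            dataset, num_replicas=dp_world_size, rank=dp_rank, shuffle=shuffle
        )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        # DataLoader rejects shuffle=True together with a sampler.
        shuffle=shuffle and sampler is None,
        num_workers=num_workers,
    )
    return loader, sampler


# Stand-in dataset with the documented default length of 1024.
dataset = TensorDataset(torch.arange(1024))

single_loader, single_sampler = sketch_build_loader(dataset, batch_size=4)
sharded_loader, sharded_sampler = sketch_build_loader(
    dataset, dp_rank=0, dp_world_size=2, batch_size=4
)

print(single_sampler is None, len(single_loader))  # True 256
print(len(sharded_sampler), len(sharded_loader))   # 512 128
```

When a DistributedSampler is returned and shuffling is enabled, calling sampler.set_epoch(epoch) at the start of each epoch changes the shuffle order between epochs.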