bridge.diffusion.data.wan.wan_mock_datamodule

Module Contents

Classes

_MockDataset

WanMockDataModuleConfig

Functions

mock_batch

_mock_collate_fn

Return a picklable collate function that calls mock_batch with fixed kwargs.

_collate_ignore_samples

Collate function that ignores samples and delegates to mock_batch.

API

class bridge.diffusion.data.wan.wan_mock_datamodule._MockDataset(length: int)

Bases: torch.utils.data.Dataset

Initialization

__len__() -> int
__getitem__(idx: int) -> dict
bridge.diffusion.data.wan.wan_mock_datamodule.mock_batch(
F_latents: int,
H_latents: int,
W_latents: int,
patch_temporal: int,
patch_spatial: int,
number_packed_samples: int,
context_seq_len: int,
context_embeddings_dim: int,
) -> dict
bridge.diffusion.data.wan.wan_mock_datamodule._mock_collate_fn(**kwargs)

Return a picklable collate function that calls mock_batch with fixed kwargs.

bridge.diffusion.data.wan.wan_mock_datamodule._collate_ignore_samples(_samples, **kwargs)

Collate function that ignores samples and delegates to mock_batch.

class bridge.diffusion.data.wan.wan_mock_datamodule.WanMockDataModuleConfig

Bases: megatron.bridge.data.utils.DatasetProvider

path: str = <Multiline-String>
seq_length: int

None

packing_buffer_size: int

None

micro_batch_size: int

None

global_batch_size: int

None

num_workers: int

None

dataloader_type: str

'external'

F_latents: int

24

H_latents: int

104

W_latents: int

60

patch_spatial: int

2

patch_temporal: int

1

number_packed_samples: int

1

context_seq_len: int

512

context_embeddings_dim: int

4096

__post_init__()
build_datasets(
_context: megatron.bridge.data.utils.DatasetBuildContext,
)