bridge.diffusion.models.wan.utils

Module Contents

Functions

grid_sizes_calculation

Compute the (f,h,w) output spatial/temporal dimensions of a Conv3d patch embedder.

patchify

Convert a list of reconstructed video tensors into patch embeddings. This method is the inverse of unpatchify.

unpatchify

Reconstruct a list of video tensors from patch embeddings.

thd_split_inputs_cp

Split a THD-packed tensor across CP ranks for inputs shaped [S, B, …].

API

bridge.diffusion.models.wan.utils.grid_sizes_calculation(
input_shape: Tuple[int, int, int],
patch_size: Tuple[int, int, int],
) → Tuple[int, int, int]

Compute the (f,h,w) output spatial/temporal dimensions of a Conv3d patch embedder.
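As a minimal sketch, assume the patch embedder is a Conv3d with kernel_size and stride both equal to patch_size and no padding. The signature above does not state this configuration, so treat it as an assumption. Under that assumption each output dimension follows the standard convolution length formula. The name `conv3d_patch_grid` is hypothetical and is not the library function.

```python
from typing import Tuple


def conv3d_patch_grid(
    input_shape: Tuple[int, int, int],
    patch_size: Tuple[int, int, int],
) -> Tuple[int, int, int]:
    """Hypothetical helper: (F, H, W) -> (f, h, w), assuming kernel_size == stride == patch_size."""
    f, h, w = (
        (n - p) // p + 1  # Conv3d output length with padding=0, dilation=1, stride=p
        for n, p in zip(input_shape, patch_size)
    )
    return f, h, w


print(conv3d_patch_grid((21, 60, 104), (1, 2, 2)))  # (21, 30, 52)
print(conv3d_patch_grid((13, 61, 104), (1, 2, 2)))  # (13, 30, 52)
```

When the stride equals the kernel size, the formula reduces to n // p. Any trailing frames or pixels that do not fill a whole patch are dropped, as the second call shows.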

bridge.diffusion.models.wan.utils.patchify(x, patch_size)

Convert a list of reconstructed video tensors into patch embeddings. This method is the inverse of unpatchify.

Parameters:
  • x (list[torch.Tensor]) – list of tensors, each with shape [c, F_patches * pF, H_patches * pH, W_patches * pW]

  • patch_size (tuple) – (pF, pH, pW)

Returns:

shape [F_patches * H_patches * W_patches, c * pF * pH * pW]

Return type:

torch.Tensor
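The reshape described above can be sketched in a few lines of PyTorch. This is a sketch under stated assumptions, not the library implementation. It assumes each patch is flattened in (pF, pH, pW, c) order with channels last, because the actual inner ordering is not documented here. It also assumes F, H and W are exact multiples of the patch size, as the documented input shape implies, and it returns one tensor per input video. `patchify_sketch` is a hypothetical name.

```python
import torch


def patchify_sketch(x, patch_size):
    """Hypothetical: list of [c, F, H, W] tensors -> list of [f*h*w, pF*pH*pW*c] tensors."""
    pF, pH, pW = patch_size
    out = []
    for u in x:
        c, F, H, W = u.shape
        f, h, w = F // pF, H // pH, W // pW
        # Split each axis into (num_patches, patch_len); requires exact divisibility.
        u = u.reshape(c, f, pF, h, pH, w, pW)
        # Grid axes first, then patch axes, channels last: [f, h, w, pF, pH, pW, c].
        u = u.permute(1, 3, 5, 2, 4, 6, 0)
        # One row per patch, one column per (pF, pH, pW, c) element.
        out.append(u.reshape(f * h * w, pF * pH * pW * c))
    return out


video = torch.randn(4, 2, 6, 8)  # c=4, F=2, H=6, W=8
tokens = patchify_sketch([video], (1, 2, 2))
print(tokens[0].shape)  # torch.Size([24, 16])
```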

bridge.diffusion.models.wan.utils.unpatchify(
x: list[torch.Tensor],
grid_sizes: list[Tuple[int, int, int]],
out_dim: int,
patch_size: Tuple[int, int, int],
) → list[torch.Tensor]

Reconstruct a list of video tensors from patch embeddings.

Parameters:
  • x (list[torch.Tensor]) – list of tensors, each with shape [seq_len, c * pF * pH * pW]

  • grid_sizes (list[Tuple[int, int, int]]) – list of (F_patches, H_patches, W_patches) tuples, one per video, giving the spatial-temporal grid dimensions before patching

  • out_dim (int) – number of output channels c

  • patch_size (Tuple[int, int, int]) – (pF, pH, pW)

Returns:

list of tensors, each with shape [c, F_latents, H_latents, W_latents], where F_latents = F_patches * pF, H_latents = H_patches * pH and W_latents = W_patches * pW

Return type:

list[torch.Tensor]
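A minimal sketch of the reconstruction follows. It makes two assumptions. First, each row is flattened in (pF, pH, pW, c) order. Second, rows beyond F_patches * H_patches * W_patches are padding tokens that can be discarded, since seq_len may exceed the real grid size. The name `unpatchify_sketch` is hypothetical and is not the library function.

```python
import torch


def unpatchify_sketch(x, grid_sizes, out_dim, patch_size):
    """Hypothetical: list of [seq_len, pF*pH*pW*c] tensors -> list of [c, f*pF, h*pH, w*pW] tensors."""
    pF, pH, pW = patch_size
    c = out_dim
    out = []
    for u, (f, h, w) in zip(x, grid_sizes):
        # Assumption: rows past f*h*w are padding tokens and are dropped.
        u = u[: f * h * w].reshape(f, h, w, pF, pH, pW, c)
        # Interleave grid and patch axes: [c, f, pF, h, pH, w, pW].
        u = torch.einsum("fhwpqrc->cfphqwr", u)
        # Merge each (grid, patch) pair back into one latent axis.
        out.append(u.reshape(c, f * pF, h * pH, w * pW))
    return out


tokens = torch.randn(27, 16)  # 24 real tokens + 3 padding rows
videos = unpatchify_sketch([tokens], [(2, 3, 4)], out_dim=4, patch_size=(1, 2, 2))
print(videos[0].shape)  # torch.Size([4, 2, 6, 8])
```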

bridge.diffusion.models.wan.utils.thd_split_inputs_cp(
x: torch.Tensor,
cu_seqlens_q_padded: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
) → torch.Tensor

Split a THD-packed tensor across CP ranks for inputs shaped [S, B, …].

Parameters:
  • x – [S, B, …] tensor (sequence first).

  • cu_seqlens_q_padded – 1D int32 tensor of padded cumulative sequence lengths (THD cu_seqlens) used for packing.

  • cp_group – context-parallel process group.

Returns:

x_local, the [S_local, B, …] shard for this CP rank.

Return type:

torch.Tensor
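The sketch below illustrates what a THD split across CP ranks commonly looks like. It uses the load-balanced scheme found in Transformer Engine style context parallelism, as we understand it. Each padded sequence is cut into 2 * cp_size equal chunks, and rank r keeps chunks r and 2 * cp_size - 1 - r, which balances causal-attention work across ranks. Whether thd_split_inputs_cp uses exactly this scheme is an assumption. The sketch takes cp_size and cp_rank as plain integers instead of a ProcessGroup so it runs without torch.distributed, and `thd_split_sketch` is a hypothetical name.

```python
import torch


def thd_split_sketch(
    x: torch.Tensor,
    cu_seqlens_padded: torch.Tensor,
    cp_size: int,
    cp_rank: int,
) -> torch.Tensor:
    """Hypothetical load-balanced THD split along dim 0 of an [S, B, ...] tensor."""
    bounds = cu_seqlens_padded.tolist()
    indices = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        seq_len = end - start
        # Assumption: every padded sequence length is divisible by 2 * cp_size.
        if seq_len % (2 * cp_size) != 0:
            raise ValueError(f"padded length {seq_len} not divisible by {2 * cp_size}")
        chunk = seq_len // (2 * cp_size)
        # Rank r keeps chunk r and its mirror chunk counted from the end of the sequence.
        for c in (cp_rank, 2 * cp_size - 1 - cp_rank):
            first = start + c * chunk
            indices.extend(range(first, first + chunk))
    idx = torch.tensor(indices, dtype=torch.long, device=x.device)
    return x.index_select(0, idx)


# Two packed sequences with padded lengths 4 and 8, batch size 1, split over 2 ranks.
x = torch.arange(12).view(12, 1)
cu = torch.tensor([0, 4, 12], dtype=torch.int32)
print(thd_split_sketch(x, cu, cp_size=2, cp_rank=0).flatten().tolist())  # [0, 3, 4, 5, 10, 11]
print(thd_split_sketch(x, cu, cp_size=2, cp_rank=1).flatten().tolist())  # [1, 2, 6, 7, 8, 9]
```

Every token lands on exactly one rank, and each rank receives S / cp_size rows. Pairing an early chunk with a late chunk evens out the causal-attention cost per rank.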