bridge.diffusion.models.wan.utils#
Module Contents#
Functions#
Compute the (f, h, w) output spatial/temporal dimensions of a Conv3d patch embedder.

Convert a list of video tensors into patch embeddings. This method is the inverse of unpatchify.

Reconstruct video tensors from patch embeddings into a list of video tensors.

Split a THD-packed tensor across CP ranks for inputs shaped [S, B, …].
API#
- bridge.diffusion.models.wan.utils.grid_sizes_calculation(
- input_shape: Tuple[int, int, int],
- patch_size: Tuple[int, int, int],
- )#
Compute the (f, h, w) output spatial/temporal dimensions of a Conv3d patch embedder.
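For a Conv3d patch embedder whose kernel size equals its stride with no padding, each output axis is the input axis integer-divided by the patch size. A minimal sketch (the floor-division arithmetic is an assumption about that embedder configuration, not a copy of the library code):

```python
from typing import Tuple

def grid_sizes_calculation(
    input_shape: Tuple[int, int, int],
    patch_size: Tuple[int, int, int],
) -> Tuple[int, int, int]:
    # (F, H, W) latent dims divided by (pF, pH, pW) gives the number of
    # patch positions per axis for a stride == kernel_size Conv3d.
    return tuple(d // p for d, p in zip(input_shape, patch_size))

print(grid_sizes_calculation((16, 60, 104), (1, 2, 2)))  # (16, 30, 52)
```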
- bridge.diffusion.models.wan.utils.patchify(x, patch_size)#
Convert a list of video tensors into patch embeddings. This method is the inverse of unpatchify.
- Parameters:
x (list[torch.Tensor]) – list of tensors, each with shape [c, F_patches * pF, H_patches * pH, W_patches * pW]
patch_size (tuple) – (pF, pH, pW)
- Returns:
tensor of shape [(F_patches * H_patches * W_patches), (c * pF * pH * pW)]
- Return type:
torch.Tensor
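A sketch of the patchify transform for a single tensor, using the common reshape-permute idiom; the exact axis ordering inside each flattened patch is an assumption and may differ from the real implementation:

```python
import torch

def patchify_one(u: torch.Tensor, patch_size) -> torch.Tensor:
    # u: [c, Fp*pF, Hp*pH, Wp*pW]  ->  [Fp*Hp*Wp, c*pF*pH*pW]
    c = u.shape[0]
    pF, pH, pW = patch_size
    Fp, Hp, Wp = u.shape[1] // pF, u.shape[2] // pH, u.shape[3] // pW
    u = u.view(c, Fp, pF, Hp, pH, Wp, pW)
    # group patch positions first, patch contents last (assumed order):
    # [Fp, Hp, Wp, c, pF, pH, pW]
    u = u.permute(1, 3, 5, 0, 2, 4, 6)
    return u.reshape(Fp * Hp * Wp, c * pF * pH * pW)

x = torch.randn(8, 4, 6, 6)
out = patchify_one(x, (1, 2, 2))
print(out.shape)  # torch.Size([36, 32])
```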
- bridge.diffusion.models.wan.utils.unpatchify(
- x: list[torch.Tensor],
- grid_sizes: list[Tuple[int, int, int]],
- out_dim: int,
- patch_size: Tuple[int, int, int],
- )#
Reconstruct video tensors from patch embeddings into a list of video tensors.
- Parameters:
x (list[torch.Tensor]) – list of tensors, each with shape [seq_len, c * pF * pH * pW]
grid_sizes (list[Tuple[int, int, int]]) – original spatial-temporal grid dimensions before patching, one (F_patches, H_patches, W_patches) tuple per tensor
out_dim (int) – number of output channels c of each reconstructed video tensor
patch_size (Tuple[int, int, int]) – (pF, pH, pW)
- Returns:
list of tensors, each with shape [c, F_latents, H_latents, W_latents]
- Return type:
list[torch.Tensor]
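The inverse transform for a single tensor can be sketched by undoing that reshape-permute; again the in-patch axis ordering is an assumption that must match whatever patchify used:

```python
import torch

def unpatchify_one(x: torch.Tensor, grid_size, out_dim: int, patch_size) -> torch.Tensor:
    # x: [seq_len, c*pF*pH*pW] with seq_len == Fp*Hp*Wp  ->  [c, Fp*pF, Hp*pH, Wp*pW]
    Fp, Hp, Wp = grid_size
    pF, pH, pW = patch_size
    u = x.view(Fp, Hp, Wp, out_dim, pF, pH, pW)
    # inverse of the assumed patchify permute: [c, Fp, pF, Hp, pH, Wp, pW]
    u = u.permute(3, 0, 4, 1, 5, 2, 6)
    return u.reshape(out_dim, Fp * pF, Hp * pH, Wp * pW)

emb = torch.randn(36, 32)
vid = unpatchify_one(emb, (4, 3, 3), out_dim=8, patch_size=(1, 2, 2))
print(vid.shape)  # torch.Size([8, 4, 6, 6])
```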
- bridge.diffusion.models.wan.utils.thd_split_inputs_cp(
- x: torch.Tensor,
- cu_seqlens_q_padded: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- )#
Split a THD-packed tensor across CP ranks for inputs shaped [S, B, …].
- Parameters:
x – [S, B, …] tensor (sequence first).
cu_seqlens_q_padded – 1D int32 THD cu_seqlens (padded) used for packing.
cp_group – context-parallel process group.
- Returns:
x_local – [S_local, B, …] shard for this CP rank.
- Return type:
torch.Tensor
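The splitting logic can be sketched without a live process group by passing the CP world size and rank explicitly (hypothetical stand-ins for what the real function derives from cp_group). The sketch assumes the load-balanced scheme Transformer Engine uses for causal attention, where each packed sequence is cut into 2 * cp_size chunks and rank r keeps chunks r and (2 * cp_size - 1 - r):

```python
import torch

def thd_split_inputs_cp_sketch(
    x: torch.Tensor,               # [S, B, ...] sequence-first, S == cu_seqlens_padded[-1]
    cu_seqlens_padded: torch.Tensor,
    cp_size: int,                  # stand-in for cp_group.size()
    cp_rank: int,                  # stand-in for this rank's index in cp_group
) -> torch.Tensor:
    shards = []
    starts = cu_seqlens_padded[:-1].tolist()
    ends = cu_seqlens_padded[1:].tolist()
    for start, end in zip(starts, ends):
        # cut each packed (padded) sequence into 2*cp_size equal chunks
        chunks = x[start:end].chunk(2 * cp_size, dim=0)
        # rank r keeps a chunk from the front and its mirror from the back
        shards.extend([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]])
    return torch.cat(shards, dim=0)  # [S_local, B, ...]

x = torch.arange(16).view(16, 1)          # two packed sequences of length 8
cu = torch.tensor([0, 8, 16], dtype=torch.int32)
shard = thd_split_inputs_cp_sketch(x, cu, cp_size=2, cp_rank=0)
print(shard[:, 0].tolist())  # [0, 1, 6, 7, 8, 9, 14, 15]
```

Pairing a front chunk with its mirrored back chunk keeps per-rank work roughly equal under a causal attention mask, which is why the shard is not a plain contiguous slice.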