bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan

Module Contents

Classes

WanAdapter

Model adapter for Wan model (Megatron version).

WanFlowMatchingPipeline

Wan-specific Flow Matching pipeline handling Context Parallelism and Custom Noise.

API

class bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan.WanAdapter

Bases: megatron.bridge.diffusion.common.flow_matching.adapters.base.ModelAdapter

Model adapter for Wan model (Megatron version).

Handles mapping of standard FlowMatchingContext to Wan specific inputs.

prepare_inputs(
context: megatron.bridge.diffusion.common.flow_matching.adapters.base.FlowMatchingContext,
) → Dict[str, Any]

forward(
model: torch.nn.Module,
inputs: Dict[str, Any],
) → torch.Tensor

Execute forward pass for Wan model.

Parameters:
  • model – Wan model

  • inputs – Dictionary from prepare_inputs()

Returns:

Model prediction tensor
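
The two-step contract described above (prepare_inputs() builds a dictionary, forward() consumes it) can be illustrated with a toy adapter. This is a minimal sketch, not the real WanAdapter: ToyContext, its field names, the dictionary keys, and toy_model are all hypothetical stand-ins, since the source does not list the fields of FlowMatchingContext or the keyword arguments the Wan model expects.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class ToyContext:
    """Hypothetical stand-in for FlowMatchingContext; field names are assumptions."""

    noisy_latents: List[float]
    timesteps: List[float]
    batch: Dict[str, Any]


class ToyAdapter:
    """Illustrates the prepare_inputs()/forward() split; not the real WanAdapter."""

    def prepare_inputs(self, context: ToyContext) -> Dict[str, Any]:
        # Translate generic pipeline state into the keyword arguments
        # that one particular model expects (key names are assumptions).
        return {
            "x": context.noisy_latents,
            "timesteps": context.timesteps,
            "context": context.batch["text_embeddings"],
        }

    def forward(self, model: Callable[..., Any], inputs: Dict[str, Any]) -> Any:
        # The dictionary from prepare_inputs() is unpacked into the model call.
        return model(**inputs)


def toy_model(x, timesteps, context):
    # Placeholder model: scales each latent by its timestep and ignores context.
    return [xi * ti for xi, ti in zip(x, timesteps)]


adapter = ToyAdapter()
ctx = ToyContext([1.0, 2.0], [0.5, 0.25], {"text_embeddings": [0.0]})
prediction = adapter.forward(toy_model, adapter.prepare_inputs(ctx))
```

Keeping the mapping in prepare_inputs() means the pipeline never needs to know a model's argument names; only the adapter does.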

class bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan.WanFlowMatchingPipeline

Bases: megatron.bridge.diffusion.common.flow_matching.flow_matching_pipeline.FlowMatchingPipeline

Wan-specific Flow Matching pipeline handling Context Parallelism and Custom Noise.

This pipeline extends the standard FlowMatchingPipeline to support:

  1. Wan-specific noise generation (patching + padding)

  2. Context Parallelism (CP) splitting of inputs

  3. Masked loss computation

determine_task_type(data_type: str) → str

Determine task type based on data type and randomization.
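
The source gives no detail on the task labels or on how the randomization works. Purely as an illustration of what "data type plus randomization" could look like, the sketch below picks between two hypothetical labels, "t2v" and "i2v", using a hypothetical probability i2v_prob. None of these names or values come from the source.

```python
import random


def determine_task_type_sketch(data_type: str, i2v_prob: float = 0.3, rng=random) -> str:
    """Hypothetical policy; the labels and the probability are assumptions."""
    # Only video samples are eligible for the randomized alternative task.
    if data_type == "video" and rng.random() < i2v_prob:
        return "i2v"
    return "t2v"


# Image samples never trigger the random branch in this sketch.
image_task = determine_task_type_sketch("image")
# Passing a seeded generator makes the choice reproducible.
video_task = determine_task_type_sketch("video", rng=random.Random(0))
```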

compute_loss(
model_pred: torch.Tensor,
target: torch.Tensor,
sigma: torch.Tensor,
batch: Dict[str, Any],
) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
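
The source does not document what the three returned tensors are, how sigma enters the loss, or where the mask comes from. A minimal sketch of a masked mean-squared-error loss, assuming the mask is passed in directly and assuming a return of (mean loss, masked per-element loss, mask), might look like this. The function name and the return layout are hypothetical, and sigma-based weighting is omitted.

```python
import torch


def masked_mse_sketch(model_pred: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor):
    """Masked MSE; the return layout is an assumption, not the documented API."""
    # Squared error computed in float32 for numerical stability.
    per_element = (model_pred.float() - target.float()) ** 2
    # Zero out positions that should not contribute (for example, padding).
    masked = per_element * loss_mask
    # Average over valid positions only; clamp avoids division by zero.
    denom = loss_mask.sum().clamp(min=1.0)
    mean_loss = masked.sum() / denom
    return mean_loss, masked, loss_mask


pred = torch.tensor([[1.0, 2.0, 3.0]])
target = torch.tensor([[0.0, 2.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])
loss, per_elem, used_mask = masked_mse_sketch(pred, target, mask)
```

Dividing by the number of valid positions rather than the full element count keeps the loss scale independent of how much padding a batch contains.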