bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan
Module Contents
Classes
WanAdapter | Model adapter for Wan model (Megatron version).
WanFlowMatchingPipeline | Wan-specific Flow Matching pipeline handling Context Parallelism and Custom Noise.
API
- class bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan.WanAdapter
Bases: megatron.bridge.diffusion.common.flow_matching.adapters.base.ModelAdapter
Model adapter for the Wan model (Megatron version).
Maps the standard FlowMatchingContext to Wan-specific inputs.
- prepare_inputs(context: megatron.bridge.diffusion.common.flow_matching.adapters.base.FlowMatchingContext)
- forward(model: torch.nn.Module, inputs: Dict[str, Any])
Executes the forward pass of the Wan model.
- Parameters:
model – Wan model
inputs – Dictionary from prepare_inputs()
- Returns:
Model prediction tensor
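The adapter contract above, where prepare_inputs() builds a dictionary that forward() then consumes, can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the Megatron Bridge implementation: `ToyContext`, its fields, the input key names and `ToyAdapter` are all hypothetical stand-ins, and a plain callable replaces the Wan torch.nn.Module.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class ToyContext:
    """Hypothetical stand-in for FlowMatchingContext (field names are assumptions)."""
    noisy_latents: List[float]
    timesteps: List[float]
    text_embeddings: List[float]


class ToyAdapter:
    """Hypothetical adapter: maps a generic context to model-specific keyword inputs."""

    def prepare_inputs(self, context: ToyContext) -> Dict[str, Any]:
        # Rename generic context fields to the keyword names the model expects.
        # These key names are illustrative, not Wan's real argument names.
        return {
            "x": context.noisy_latents,
            "t": context.timesteps,
            "context": context.text_embeddings,
        }

    def forward(self, model: Callable[..., Any], inputs: Dict[str, Any]) -> Any:
        # Unpack the prepared dictionary into the model call.
        return model(**inputs)


def toy_model(x: List[float], t: List[float], context: List[float]) -> List[float]:
    # Placeholder "prediction": elementwise x * t plus the sum of the context values.
    bias = sum(context)
    return [xi * ti + bias for xi, ti in zip(x, t)]


adapter = ToyAdapter()
ctx = ToyContext(noisy_latents=[1.0, 2.0], timesteps=[0.5, 0.5], text_embeddings=[1.0, 2.0])
inputs = adapter.prepare_inputs(ctx)
pred = adapter.forward(toy_model, inputs)
```

Keeping the model-specific renaming inside the adapter means a generic pipeline only ever handles the shared context object and an opaque dictionary, which is presumably why the base pipeline delegates to a ModelAdapter.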
- class bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan.WanFlowMatchingPipeline
Bases: megatron.bridge.diffusion.common.flow_matching.flow_matching_pipeline.FlowMatchingPipeline
Wan-specific Flow Matching pipeline handling Context Parallelism and Custom Noise.
This pipeline extends the standard FlowMatchingPipeline to support:
  - Wan-specific noise generation (patching + padding)
  - Context Parallelism (CP) splitting of inputs
  - Masked loss computation
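Padding, Context Parallelism splitting and a masked loss fit together roughly as follows. This is a minimal sketch under stated assumptions, not the pipeline's code: a per-token sequence is right-padded so its length divides evenly by the CP size, each rank takes one contiguous shard, and the loss counts squared error only at unpadded positions. The helpers `pad_to_multiple`, `cp_shard` and `masked_sq_error` are hypothetical, and real CP implementations may use a different (for example load-balanced) shard layout.

```python
from typing import List, Tuple


def pad_to_multiple(values: List[float], multiple: int, pad_value: float = 0.0) -> Tuple[List[float], List[int]]:
    """Right-pad `values` so its length is a multiple of `multiple`; also return a 1/0 validity mask."""
    pad = (-len(values)) % multiple
    return values + [pad_value] * pad, [1] * len(values) + [0] * pad


def cp_shard(seq: List, cp_size: int, cp_rank: int) -> List:
    """Return the contiguous slice of `seq` owned by `cp_rank` (simple layout, assumed)."""
    if len(seq) % cp_size != 0:
        raise ValueError("sequence length must be divisible by cp_size; pad first")
    chunk = len(seq) // cp_size
    return seq[cp_rank * chunk:(cp_rank + 1) * chunk]


def masked_sq_error(pred: List[float], target: List[float], mask: List[int]) -> Tuple[float, int]:
    """Sum of squared errors and number of valid positions; padded positions contribute nothing."""
    total, count = 0.0, 0
    for p, t, m in zip(pred, target, mask):
        if m:
            total += (p - t) ** 2
            count += 1
    return total, count


# Example: 5 tokens with CP size 2 -> one padding token, so each rank holds 3 positions.
cp_size = 2
pred, mask = pad_to_multiple([1.0, 2.0, 3.0, 4.0, 5.0], cp_size)
target, _ = pad_to_multiple([1.0, 2.0, 3.0, 4.0, 7.0], cp_size)

per_rank = [
    masked_sq_error(cp_shard(pred, cp_size, r), cp_shard(target, cp_size, r), cp_shard(mask, cp_size, r))
    for r in range(cp_size)
]
# Reduce sums and counts across ranks (an all-reduce in a real CP group) before dividing,
# so a rank holding more padding does not skew the mean.
global_loss = sum(s for s, _ in per_rank) / sum(c for _, c in per_rank)
```

Averaging each rank's own mean instead would give (0/3 + 4/2) / 2 = 1.0 here rather than 0.8, which is why the sketch reduces sums and counts before dividing.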
- determine_task_type(data_type: str) -> str
Determine task type based on data type and randomization.
- compute_loss(
    model_pred: torch.Tensor,
    target: torch.Tensor,
    sigma: torch.Tensor,
    batch: Dict[str, Any],