bridge.diffusion.models.wan.flow_matching.flow_inference_pipeline#

Module Contents#

Classes#

FlowInferencePipeline

Functions#

_encode_text

API#

bridge.diffusion.models.wan.flow_matching.flow_inference_pipeline._encode_text(
tokenizer: transformers.AutoTokenizer,
text_encoder: transformers.UMT5EncoderModel,
device: str,
caption: str,
) torch.Tensor#
class bridge.diffusion.models.wan.flow_matching.flow_inference_pipeline.FlowInferencePipeline(
inference_cfg,
model_id='Wan-AI/Wan2.1-T2V-14B-Diffusers',
checkpoint_dir=None,
checkpoint_step=None,
t5_checkpoint_dir=None,
vae_checkpoint_dir=None,
device_id=0,
rank=0,
t5_cpu=False,
tensor_parallel_size=1,
context_parallel_size=1,
pipeline_parallel_size=1,
sequence_parallel=False,
pipeline_dtype=torch.float32,
)#

Initialization

Initializes the FlowInferencePipeline with the given parameters.

Parameters:
  • inference_cfg (dict) – Dictionary containing the inference configuration.

  • checkpoint_dir (str) – Path to the directory containing model checkpoints.

  • t5_checkpoint_dir (str, optional, defaults to None) – Optional directory containing T5 checkpoint and tokenizer; falls back to checkpoint_dir if None.

  • vae_checkpoint_dir (str, optional, defaults to None) – Optional directory containing VAE checkpoint; falls back to checkpoint_dir if None.

  • device_id (int, optional, defaults to 0) – ID of the target GPU device.

  • rank (int, optional, defaults to 0) – Process rank for distributed execution.

  • t5_cpu (bool, optional, defaults to False) – Whether to place T5 model on CPU. Only works without t5_fsdp.
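The None-fallback for t5_checkpoint_dir and vae_checkpoint_dir described above can be sketched as a standalone helper (the name resolve_component_dirs is illustrative, not part of the class):

```python
def resolve_component_dirs(checkpoint_dir, t5_checkpoint_dir=None, vae_checkpoint_dir=None):
    """Apply the documented fallback: component checkpoint dirs default to checkpoint_dir."""
    t5_dir = t5_checkpoint_dir if t5_checkpoint_dir is not None else checkpoint_dir
    vae_dir = vae_checkpoint_dir if vae_checkpoint_dir is not None else checkpoint_dir
    return t5_dir, vae_dir
```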

setup_model_from_checkpoint(checkpoint_dir)#
_select_checkpoint_dir(base_dir: str, checkpoint_step) str#

Resolve checkpoint directory:

  • If checkpoint_step is provided, use base_dir/iter_{step:07d}

  • Otherwise, pick the largest iter_####### subdirectory under base_dir
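The resolution rules above can be sketched as a self-contained function, assuming checkpoint subdirectories follow the iter_{step:07d} naming shown above:

```python
import os
import re

def select_checkpoint_dir(base_dir: str, checkpoint_step=None) -> str:
    """Resolve a checkpoint directory (standalone sketch of the rules above)."""
    if checkpoint_step is not None:
        # Explicit step: use base_dir/iter_{step:07d} directly.
        return os.path.join(base_dir, f"iter_{checkpoint_step:07d}")
    # No step given: pick the iter_* subdirectory with the largest step number.
    steps = [
        int(m.group(1))
        for name in os.listdir(base_dir)
        if (m := re.fullmatch(r"iter_(\d+)", name))
        and os.path.isdir(os.path.join(base_dir, name))
    ]
    if not steps:
        raise FileNotFoundError(f"no iter_* checkpoint directories under {base_dir}")
    return os.path.join(base_dir, f"iter_{max(steps):07d}")
```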

forward_pp_step(
latent_model_input: torch.Tensor,
grid_sizes: list[Tuple[int, int, int]],
max_video_seq_len: int,
timestep: torch.Tensor,
arg_c: dict,
) torch.Tensor#

Forward pass supporting pipeline parallelism.
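In pipeline parallelism, the model's blocks are split across ranks: each rank applies its sub-module to activations received from the previous stage and sends the result onward, so only the last stage produces the final tensor. A minimal in-process simulation of that data flow (stage functions are stand-ins; a real implementation would use point-to-point sends and receives):

```python
def pipeline_forward(x, stages):
    """Simulate a pipeline-parallel forward pass in a single process.

    In a real run, each element of `stages` lives on a different rank:
    rank r receives activations from rank r - 1, applies its sub-module,
    and sends the output to rank r + 1. Here the sends/receives collapse
    into ordinary function calls, which preserves the data flow.
    """
    activation = x
    for stage in stages:      # rank r: recv -> compute -> send
        activation = stage(activation)
    return activation         # output held by the last pipeline stage
```

For example, pipeline_forward(3, [lambda v: v + 1, lambda v: v * 2]) chains two "stages" and yields 8.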

generate(
prompts,
sizes,
frame_nums,
shift=5.0,
sampling_steps=50,
guide_scale=5.0,
n_prompt='',
seed=-1,
offload_model=True,
)#

Generates video frames from text prompts using a diffusion process.

Parameters:
  • prompts (list[str]) – Text prompts for content generation

  • sizes (list[tuple[int, int]]) – Video resolution per prompt as (width, height).

  • frame_nums (list[int]) – Number of frames to sample per video. Each value must be of the form 4n + 1 (e.g. 81)

  • shift (float, optional, defaults to 5.0) – Noise schedule shift parameter. Affects temporal dynamics

  • sampling_steps (int, optional, defaults to 50) – Number of diffusion sampling steps. Higher values improve quality but slow generation

  • guide_scale (float, optional, defaults to 5.0) – Classifier-free guidance scale. Controls prompt adherence vs. creativity

  • n_prompt (str, optional, defaults to “”) – Negative prompt for content exclusion. If empty, config.sample_neg_prompt is used

  • seed (int, optional, defaults to -1) – Random seed for noise generation. If -1, use random seed.

  • offload_model (bool, optional, defaults to True) – If True, offloads models to CPU during generation to save VRAM

Returns:

Generated video frames tensor. Dimensions: (C, N, H, W), where:

  • C: Color channels (3 for RGB)

  • N: Number of frames (from frame_nums)

  • H: Frame height (from sizes)

  • W: Frame width (from sizes)

Return type:

torch.Tensor
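The 4n + 1 constraint on frame_nums noted above can be checked (and a requested count snapped to the nearest valid value) with a small helper; both function names are illustrative, not part of the pipeline API:

```python
def is_valid_frame_num(n: int) -> bool:
    """True if n has the required 4n + 1 form (1, 5, 9, ..., 81, ...)."""
    return n >= 1 and (n - 1) % 4 == 0

def nearest_valid_frame_num(n: int) -> int:
    """Snap a requested frame count to the closest value of the form 4k + 1."""
    k = max(0, round((n - 1) / 4))
    return 4 * k + 1
```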