bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline#

FLUX inference pipeline for text-to-image generation.

Module Contents#

Classes#

T5Config

T5 encoder configuration.

ClipConfig

CLIP encoder configuration.

FlowMatchEulerDiscreteScheduler

Euler scheduler.

FluxInferencePipeline

FLUX inference pipeline for text-to-image generation.

API#

class bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline.T5Config#

T5 encoder configuration.

version: Optional[str]#

‘field(…)’

max_length: Optional[int]#

‘field(…)’

load_config_only: bool#

False

device: str#

‘cuda’

class bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline.ClipConfig#

CLIP encoder configuration.

version: Optional[str]#

‘field(…)’

max_length: Optional[int]#

‘field(…)’

always_return_pooled: Optional[bool]#

‘field(…)’

device: str#

‘cuda’

class bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline.FlowMatchEulerDiscreteScheduler(
num_train_timesteps: int = 1000,
shift: float = 1.0,
use_dynamic_shifting=False,
base_shift: Optional[float] = 0.5,
max_shift: Optional[float] = 1.15,
base_image_seq_len: Optional[int] = 256,
max_image_seq_len: Optional[int] = 4096,
)#

Euler scheduler.

Parameters:
  • num_train_timesteps (int, defaults to 1000) – The number of diffusion steps to train the model.

  • timestep_spacing (str, defaults to "linspace") – How the timesteps should be scaled. Refer to Table 2 of the paper Common Diffusion Noise Schedules and Sample Steps are Flawed for more information.

  • shift (float, defaults to 1.0) – The shift value for the timestep schedule.

Initialization

_compatibles#

[]

order#

1

property step_index#

The index counter for the current timestep. It increases by 1 after each scheduler step.

property begin_index#

The index of the first timestep. It should be set from the pipeline with the set_begin_index method.

set_begin_index(begin_index: int = 0)#

Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

Parameters:

begin_index (int) – The begin index for the scheduler.

scale_noise(
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
noise: Optional[torch.FloatTensor] = None,
) torch.FloatTensor#

Forward process in flow matching.

Parameters:
  • sample (torch.FloatTensor) – The input sample.

  • timestep (float or torch.FloatTensor) – The current timestep in the diffusion chain.

Returns:

A scaled input sample.

Return type:

torch.FloatTensor
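The math behind scale_noise can be sketched in plain Python, assuming the standard flow-matching interpolant (a linear blend between the clean sample at sigma = 0 and pure noise at sigma = 1); scale_noise_sketch is a hypothetical stand-in, not the scheduler's implementation:

```python
def scale_noise_sketch(sample, sigma, noise):
    # Flow-matching forward process: linearly interpolate between the
    # clean sample (sigma = 0) and pure Gaussian noise (sigma = 1).
    return [(1.0 - sigma) * x + sigma * n for x, n in zip(sample, noise)]
```

At sigma = 0 the sample is returned unchanged; at sigma = 1 the result is pure noise.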

_sigma_to_t(sigma)#
time_shift(mu: float, sigma: float, t: torch.Tensor)#
set_timesteps(
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
)#

Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Parameters:
  • num_inference_steps (int) – The number of diffusion steps used when generating samples with a pre-trained model.

  • device (str or torch.device, optional) – The device to which the timesteps should be moved. If None, the timesteps are not moved.
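The sigma schedule built here can be pictured as evenly spaced values warped by the resolution-dependent shift mu. The sketch below assumes the FLUX-style exponential time shift; shifted_sigmas_sketch is a hypothetical helper, not the scheduler's actual code:

```python
import math

def shifted_sigmas_sketch(num_inference_steps, mu):
    # Evenly spaced sigmas in (0, 1], highest noise first.
    sigmas = [1.0 - i / num_inference_steps for i in range(num_inference_steps)]
    # Warp each sigma with the exponential time shift; mu = 0 is the identity.
    return [math.exp(mu) / (math.exp(mu) + (1.0 / s - 1.0)) for s in sigmas]
```

Larger mu pushes more of the schedule toward high noise levels, which is why longer image token sequences (higher resolutions) use a larger shift.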

index_for_timestep(timestep, schedule_timesteps=None)#
_init_step_index(timestep)#
step(
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float('inf'),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
) Tuple#

Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise).

Parameters:
  • model_output (torch.FloatTensor) – The direct output from the learned diffusion model.

  • timestep (float) – The current discrete timestep in the diffusion chain.

  • sample (torch.FloatTensor) – A current instance of a sample created by the diffusion process.

  • s_churn (float)

  • s_tmin (float)

  • s_tmax (float)

  • s_noise (float, defaults to 1.0) – Scaling factor for noise added to the sample.

  • generator (torch.Generator, optional) – A random number generator.

Returns:

A tuple is returned where the first element is the sample tensor.
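For the flow-matching case, a single Euler step amounts to moving the sample along the model-predicted velocity by the sigma decrement; a minimal sketch (euler_step_sketch is a hypothetical helper, not the scheduler's implementation):

```python
def euler_step_sketch(model_output, sample, sigma, sigma_next):
    # prev_sample = sample + (sigma_next - sigma) * velocity, where the
    # model output is interpreted as the flow velocity at this sigma.
    return [x + (sigma_next - sigma) * v for x, v in zip(sample, model_output)]
```

Since sigma_next < sigma during denoising, the update moves the sample against the predicted velocity, toward the clean data.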

__len__()#
class bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline.FluxInferencePipeline(
flux_checkpoint_dir: Optional[str] = None,
t5_checkpoint_dir: Optional[str] = None,
clip_checkpoint_dir: Optional[str] = None,
vae_checkpoint_dir: Optional[str] = None,
scheduler_steps: int = 1000,
)#

Bases: torch.nn.Module

FLUX inference pipeline for text-to-image generation.

This pipeline orchestrates the full inference process including:

  • Text encoding with T5 and CLIP

  • Latent preparation and denoising

  • VAE decoding to images

Parameters:
  • flux_checkpoint_dir – Directory containing the FLUX transformer checkpoint.

  • t5_checkpoint_dir – Directory containing the T5 text encoder checkpoint.

  • clip_checkpoint_dir – Directory containing the CLIP text encoder checkpoint.

  • vae_checkpoint_dir – Directory containing the VAE checkpoint.

  • scheduler_steps – Number of scheduler steps.

Example#

params = FluxModelParams()
pipeline = FluxInferencePipeline(params)
pipeline.load_from_pretrained("path/to/flux_ckpt")
images = pipeline(
    prompt=["A cat holding a sign that says hello world"],
    height=1024,
    width=1024,
    num_inference_steps=20,
)

Initialization

setup_model_from_checkpoint(checkpoint_dir)#
load_text_encoders(t5_version: str = None, clip_version: str = None)#

Load T5 and CLIP text encoders.

Parameters:
  • t5_version – HuggingFace model ID or path for T5.

  • clip_version – HuggingFace model ID or path for CLIP.

load_vae(vae_path: str)#

Load VAE from checkpoint.

Parameters:

vae_path – Path to VAE checkpoint (ae.safetensors).

encode_prompt(
prompt: Union[str, List[str]],
max_sequence_length: int = 512,
num_images_per_prompt: int = 1,
device: str = 'cuda',
dtype: torch.dtype = torch.float32,
)#

Encode text prompts using T5 and CLIP.

Returns:

Tuple of (prompt_embeds, pooled_prompt_embeds, text_ids).

static _prepare_latent_image_ids(
batch_size: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
)#

Prepare latent image IDs for position encoding.
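The IDs can be pictured as one 3-vector per 2x2 latent patch, with the row and column index in the last two slots. This is a hedged sketch assuming height and width are latent dimensions before patching; latent_image_ids_sketch is a hypothetical helper:

```python
def latent_image_ids_sketch(height, width):
    # One 3-vector ID per 2x2 latent patch: (0, row, col). The row/col
    # entries feed the model's positional encoding for image tokens.
    return [(0, r, c) for r in range(height // 2) for c in range(width // 2)]
```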

static _pack_latents(
latents,
batch_size,
num_channels_latents,
height,
width,
)#

Pack latents for FLUX processing.
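Packing groups each 2x2 spatial patch of the latent grid into one token. The resulting shape can be sketched as follows, assuming the usual FLUX patching (packed_shape_sketch is a hypothetical helper):

```python
def packed_shape_sketch(batch_size, num_channels_latents, height, width):
    # (B, C, H, W) latents become (B, (H/2)*(W/2), C*4): each token is one
    # 2x2 patch with its four positions' channels concatenated.
    return (batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
```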

static _unpack_latents(latents, height, width, vae_scale_factor)#

Unpack latents for VAE decoding.

static _calculate_shift(
image_seq_len,
base_seq_len=256,
max_seq_len=4096,
base_shift=0.5,
max_shift=1.16,
)#

Calculate timestep shift based on sequence length.
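The shift is a linear interpolation in sequence length between base_shift and max_shift; a sketch of that formula, mirroring the defaults shown above (calculate_shift_sketch is a hypothetical helper, not the pipeline's code):

```python
def calculate_shift_sketch(image_seq_len, base_seq_len=256, max_seq_len=4096,
                           base_shift=0.5, max_shift=1.16):
    # Line through (base_seq_len, base_shift) and (max_seq_len, max_shift):
    # longer image token sequences get a larger timestep shift.
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
```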

prepare_latents(
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator=None,
)#

Prepare random latents for generation.

__call__(
prompt: Union[str, List[str]],
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 10,
guidance_scale: float = 3.5,
num_images_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
max_sequence_length: int = 512,
output_type: str = 'pil',
output_path: Optional[str] = None,
dtype: torch.dtype = torch.bfloat16,
)#

Generate images from text prompts.

Parameters:
  • prompt – Text prompt(s) for image generation.

  • height – Output image height.

  • width – Output image width.

  • num_inference_steps – Number of denoising steps.

  • guidance_scale – Classifier-free guidance scale.

  • num_images_per_prompt – Number of images per prompt.

  • generator – Random number generator for reproducibility.

  • max_sequence_length – Maximum sequence length for text encoding.

  • output_type – “pil” for PIL images, “latent” for latent tensors.

  • output_path – Path to save generated images.

  • dtype – Data type for inference.

Returns:

List of PIL images or latent tensors.

static numpy_to_pil(images)#

Convert a numpy image or a batch of images to a PIL image.

static torch_to_numpy(images)#

Convert a torch image or a batch of images to a numpy image.

static denormalize(image)#
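These three helpers form the usual post-processing chain: denormalize from the VAE's [-1, 1] output range to [0, 1], convert to numpy, then build PIL images. The denormalization step can be sketched in plain Python (denormalize_sketch is a hypothetical stand-in, not the pipeline's code):

```python
def denormalize_sketch(values):
    # Map VAE output from [-1, 1] to [0, 1] and clamp out-of-range values.
    return [min(max(v / 2.0 + 0.5, 0.0), 1.0) for v in values]
```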