bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline#
FLUX inference pipeline for text-to-image generation.
Module Contents#
Classes#
| Class | Description |
| --- | --- |
| T5Config | T5 encoder configuration. |
| ClipConfig | CLIP encoder configuration. |
| FlowMatchEulerDiscreteScheduler | Euler scheduler. |
| FluxInferencePipeline | FLUX inference pipeline for text-to-image generation. |
API#
- class bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline.T5Config#
T5 encoder configuration.
- version: Optional[str]#
'field(…)'
- max_length: Optional[int]#
'field(…)'
- load_config_only: bool#
False
- device: str#
'cuda'
- class bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline.ClipConfig#
CLIP encoder configuration.
- version: Optional[str]#
'field(…)'
- max_length: Optional[int]#
'field(…)'
- always_return_pooled: Optional[bool]#
'field(…)'
- device: str#
'cuda'
- class bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline.FlowMatchEulerDiscreteScheduler(
- num_train_timesteps: int = 1000,
- shift: float = 1.0,
- use_dynamic_shifting=False,
- base_shift: Optional[float] = 0.5,
- max_shift: Optional[float] = 1.15,
- base_image_seq_len: Optional[int] = 256,
- max_image_seq_len: Optional[int] = 4096,
)#
Euler scheduler.
- Parameters:
num_train_timesteps (int, defaults to 1000) – The number of diffusion steps to train the model.
timestep_spacing (str, defaults to "linspace") – The way the timesteps should be scaled. Refer to Table 2 of the paper Common Diffusion Noise Schedules and Sample Steps are Flawed for more information.
shift (float, defaults to 1.0) – The shift value for the timestep schedule.
Initialization
- _compatibles#
[]
- order#
1
- property step_index#
The index counter for the current timestep. It increases by 1 after each scheduler step.
- property begin_index#
The index for the first timestep. It should be set from the pipeline with the set_begin_index method.
- set_begin_index(begin_index: int = 0)#
Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
- Parameters:
begin_index (int) – The begin index for the scheduler.
- scale_noise(
- sample: torch.FloatTensor,
- timestep: Union[float, torch.FloatTensor],
- noise: Optional[torch.FloatTensor] = None,
)#
Forward process in flow matching.
- Parameters:
sample (torch.FloatTensor) – The input sample.
timestep (int, optional) – The current timestep in the diffusion chain.
- Returns:
A scaled input sample.
- Return type:
torch.FloatTensor
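The flow-matching forward process is a linear interpolation between the clean sample and pure noise. A minimal scalar sketch, assuming the standard `x_t = (1 - sigma) * x_0 + sigma * noise` interpolation used by flow-matching schedulers (the `sigma` argument here stands in for the sigma associated with `timestep`; the real method operates on tensors):

```python
def scale_noise(sample: float, sigma: float, noise: float) -> float:
    # Flow-matching forward process: linearly interpolate between the
    # clean sample (sigma = 0) and pure noise (sigma = 1).
    return (1.0 - sigma) * sample + sigma * noise
```

At `sigma=0` the sample is returned unchanged; at `sigma=1` the result is pure noise.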
- _sigma_to_t(sigma)#
- time_shift(mu: float, sigma: float, t: torch.Tensor)#
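`time_shift` warps a timestep toward higher noise levels as a function of `mu`. A sketch of the exponential shift commonly used by flow-matching schedulers (an assumption about this implementation; `sigma` is the curvature exponent, typically 1.0):

```python
import math

def time_shift(mu: float, sigma: float, t: float) -> float:
    # Warp a timestep t in (0, 1] toward higher noise: larger mu pushes
    # the schedule so more steps are spent at high noise levels.
    return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0) ** sigma)
```

With `mu = 0` and `sigma = 1` the mapping is the identity (e.g. `t = 0.5` maps to `0.5`).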
- set_timesteps(
- num_inference_steps: int = None,
- device: Union[str, torch.device] = None,
- sigmas: Optional[List[float]] = None,
- mu: Optional[float] = None,
)#
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
- Parameters:
num_inference_steps (int) – The number of diffusion steps used when generating samples with a pre-trained model.
device (str or torch.device, optional) – The device to which the timesteps should be moved. If None, the timesteps are not moved.
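When `use_dynamic_shifting` is disabled, flow-match schedulers typically apply the constructor's static `shift` to the base sigma schedule. A scalar sketch of that remapping (an assumption about this implementation, based on the common flow-matching formula):

```python
def shift_sigma(sigma: float, shift: float) -> float:
    # Static timestep shift: shift = 1.0 leaves the schedule unchanged;
    # shift > 1.0 pushes sigmas toward 1 (more steps at high noise).
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)
```

For example, with `shift = 3.0` a mid-schedule sigma of 0.5 is remapped to 0.75.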
- index_for_timestep(timestep, schedule_timesteps=None)#
- _init_step_index(timestep)#
- step(
- model_output: torch.FloatTensor,
- timestep: Union[float, torch.FloatTensor],
- sample: torch.FloatTensor,
- s_churn: float = 0.0,
- s_tmin: float = 0.0,
- s_tmax: float = float('inf'),
- s_noise: float = 1.0,
- generator: Optional[torch.Generator] = None,
)#
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise).
- Parameters:
model_output (torch.FloatTensor) – The direct output from the learned diffusion model.
timestep (float) – The current discrete timestep in the diffusion chain.
sample (torch.FloatTensor) – A current instance of a sample created by the diffusion process.
s_churn (float)
s_tmin (float)
s_tmax (float)
s_noise (float, defaults to 1.0) – Scaling factor for noise added to the sample.
generator (torch.Generator, optional) – A random number generator.
- Returns:
A tuple is returned where the first element is the sample tensor.
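For flow matching, the Euler update is a single step along the predicted velocity, scaled by the sigma decrement between consecutive schedule entries. A scalar sketch of that core update (an assumption about this implementation; the real method operates on tensors and tracks `step_index` internally):

```python
def euler_step(sample: float, model_output: float,
               sigma: float, sigma_next: float) -> float:
    # One Euler step of the flow-matching ODE: move the sample along the
    # model's predicted velocity by the sigma decrement for this step.
    return sample + (sigma_next - sigma) * model_output
```

Since sigmas decrease during sampling, `(sigma_next - sigma)` is negative, so the step moves the sample against the noise direction.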
- __len__()#
- class bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline.FluxInferencePipeline(
- flux_checkpoint_dir: Optional[str] = None,
- t5_checkpoint_dir: Optional[str] = None,
- clip_checkpoint_dir: Optional[str] = None,
- vae_checkpoint_dir: Optional[str] = None,
- scheduler_steps: int = 1000,
)#
Bases: torch.nn.Module

FLUX inference pipeline for text-to-image generation.
This pipeline orchestrates the full inference process, including:
- Text encoding with T5 and CLIP
- Latent preparation and denoising
- VAE decoding to images
- Parameters:
params – FluxModelParams configuration.
flux – Optional pre-initialized Flux model.
scheduler_steps – Number of scheduler steps.
Example

>>> params = FluxModelParams()
>>> pipeline = FluxInferencePipeline(params)
>>> pipeline.load_from_pretrained("path/to/flux_ckpt")
>>> images = pipeline(
...     prompt=["A cat holding a sign that says hello world"],
...     height=1024,
...     width=1024,
...     num_inference_steps=20,
... )
Initialization
- setup_model_from_checkpoint(checkpoint_dir)#
- load_text_encoders(t5_version: str = None, clip_version: str = None)#
Load T5 and CLIP text encoders.
- Parameters:
t5_version – HuggingFace model ID or path for T5.
clip_version – HuggingFace model ID or path for CLIP.
- load_vae(vae_path: str)#
Load VAE from checkpoint.
- Parameters:
vae_path – Path to VAE checkpoint (ae.safetensors).
- encode_prompt(
- prompt: Union[str, List[str]],
- max_sequence_length: int = 512,
- num_images_per_prompt: int = 1,
- device: str = 'cuda',
- dtype: torch.dtype = torch.float32,
)#
Encode text prompts using T5 and CLIP.
- Returns:
Tuple of (prompt_embeds, pooled_prompt_embeds, text_ids).
- static _prepare_latent_image_ids(
- batch_size: int,
- height: int,
- width: int,
- device: torch.device,
- dtype: torch.dtype,
)#
Prepare latent image IDs for position encoding.
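FLUX models attach a positional id triple to every 2x2 latent patch so the transformer's rotary embeddings can encode spatial layout. A pure-python sketch of the typical id layout (an assumption about this implementation; the real method returns a tensor of shape `(height // 2 * width // 2, 3)` on the given device and dtype):

```python
def prepare_latent_image_ids(height: int, width: int):
    # One (text, row, col) id per 2x2 latent patch; the first channel is
    # 0 for image tokens (text tokens carry their own ids).
    ids = []
    for row in range(height // 2):
        for col in range(width // 2):
            ids.append((0, row, col))
    return ids
```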
- static _pack_latents(
- latents,
- batch_size,
- num_channels_latents,
- height,
- width,
)#
Pack latents for FLUX processing.
- static _unpack_latents(latents, height, width, vae_scale_factor)#
Unpack latents for VAE decoding.
- static _calculate_shift(
- image_seq_len,
- base_seq_len=256,
- max_seq_len=4096,
- base_shift=0.5,
- max_shift=1.16,
)#
Calculate timestep shift based on sequence length.
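The shift `mu` is typically a linear interpolation between `base_shift` (at `base_seq_len` tokens) and `max_shift` (at `max_seq_len` tokens), so larger images get a stronger timestep shift. A sketch under that assumption, mirroring this method's defaults:

```python
def calculate_shift(image_seq_len: int,
                    base_seq_len: int = 256,
                    max_seq_len: int = 4096,
                    base_shift: float = 0.5,
                    max_shift: float = 1.16) -> float:
    # Linearly interpolate mu between base_shift and max_shift as a
    # function of the latent token count.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b
```

At the two endpoints this recovers the configured shifts: 256 tokens gives 0.5 and 4096 tokens gives 1.16.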
- prepare_latents(
- batch_size,
- num_channels_latents,
- height,
- width,
- dtype,
- device,
- generator=None,
)#
Prepare random latents for generation.
- __call__(
- prompt: Union[str, List[str]],
- height: int = 1024,
- width: int = 1024,
- num_inference_steps: int = 10,
- guidance_scale: float = 3.5,
- num_images_per_prompt: int = 1,
- generator: Optional[torch.Generator] = None,
- max_sequence_length: int = 512,
- output_type: str = 'pil',
- output_path: Optional[str] = None,
- dtype: torch.dtype = torch.bfloat16,
)#
Generate images from text prompts.
- Parameters:
prompt – Text prompt(s) for image generation.
height – Output image height.
width – Output image width.
num_inference_steps – Number of denoising steps.
guidance_scale – Classifier-free guidance scale.
num_images_per_prompt – Number of images per prompt.
generator – Random number generator for reproducibility.
max_sequence_length – Maximum sequence length for text encoding.
output_type – "pil" for PIL images, "latent" for latent tensors.
output_path – Path to save generated images.
dtype – Data type for inference.
- Returns:
List of PIL images or latent tensors.
- static numpy_to_pil(images)#
Convert a numpy image or a batch of images to a PIL image.
- static torch_to_numpy(images)#
Convert a torch image or a batch of images to a numpy image.
- static denormalize(image)#
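`denormalize` is undocumented here; diffusion VAEs conventionally decode to the model's [-1, 1] range, which must be mapped back to [0, 1] before conversion to an image. A scalar sketch under that assumption (the real method operates on tensors):

```python
def denormalize(value: float) -> float:
    # Map a pixel from the model's [-1, 1] range back to [0, 1],
    # clamping out-of-range values.
    return min(max(value / 2.0 + 0.5, 0.0), 1.0)
```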