nemo_automodel.components.flow_matching.adapters.flux

Flux model adapter for the FlowMatching pipeline.

This adapter supports FLUX.1-style models with:

  • T5 text embeddings (text_embeddings)

  • CLIP pooled embeddings (pooled_prompt_embeds)

  • 2D image latents (treated as 1-frame video: [B, C, 1, H, W])
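Treating a 2D latent as a 1-frame video amounts to adding a singleton frame axis. The sketch below shows this in plain PyTorch. The helper names `to_video_layout` and `from_video_layout` are hypothetical and are not part of the adapter's API.

```python
import torch


def to_video_layout(image_latents: torch.Tensor) -> torch.Tensor:
    """Add a singleton frame axis: [B, C, H, W] -> [B, C, 1, H, W] (hypothetical helper)."""
    if image_latents.dim() != 4:
        raise ValueError(f"expected [B, C, H, W], got {tuple(image_latents.shape)}")
    return image_latents.unsqueeze(2)


def from_video_layout(video_latents: torch.Tensor) -> torch.Tensor:
    """Remove the singleton frame axis: [B, C, 1, H, W] -> [B, C, H, W] (hypothetical helper)."""
    if video_latents.dim() != 5 or video_latents.shape[2] != 1:
        raise ValueError(f"expected [B, C, 1, H, W], got {tuple(video_latents.shape)}")
    return video_latents.squeeze(2)


latents = torch.randn(2, 16, 64, 64)
video = to_video_layout(latents)
print(tuple(video.shape))  # (2, 16, 1, 64, 64)
```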

Module Contents

Classes

FluxAdapter

Model adapter for FLUX.1 image generation models.

API

class nemo_automodel.components.flow_matching.adapters.flux.FluxAdapter(
guidance_scale: float = 3.5,
use_guidance_embeds: bool = True,
)

Bases: nemo_automodel.components.flow_matching.adapters.base.ModelAdapter

Model adapter for FLUX.1 image generation models.

Supports the batch format produced by the multiresolution dataloader:

  • image_latents: [B, C, H, W] for images

  • text_embeddings: T5 embeddings [B, seq_len, 4096]

  • pooled_prompt_embeds: CLIP pooled [B, 768]
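For smoke-testing, you can build a random batch with the keys and shapes listed above. This is a sketch only. The helper `make_dummy_flux_batch` and its default sizes (16 latent channels, a 64x64 latent grid, 512 text tokens) are illustrative assumptions, not values read from the dataloader.

```python
import torch


def make_dummy_flux_batch(batch_size: int = 2, channels: int = 16, latent_h: int = 64,
                          latent_w: int = 64, seq_len: int = 512) -> dict:
    """Random tensors with the batch keys and shapes described above (hypothetical helper)."""
    return {
        "image_latents": torch.randn(batch_size, channels, latent_h, latent_w),
        "text_embeddings": torch.randn(batch_size, seq_len, 4096),  # T5 hidden states
        "pooled_prompt_embeds": torch.randn(batch_size, 768),        # CLIP pooled output
    }


batch = make_dummy_flux_batch()
for key, value in batch.items():
    print(key, tuple(value.shape))
```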

FLUX model forward interface:

  • hidden_states: Packed latents

  • encoder_hidden_states: T5 text embeddings

  • pooled_projections: CLIP pooled embeddings

  • timestep: Normalized timesteps [0, 1]

  • img_ids / txt_ids: Positional IDs for image and text tokens, from which the model computes rotary position embeddings
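The forward inputs above have tightly coupled shapes. The hypothetical helper below computes the expected shapes from the latent grid size, assuming 2x2 packing and 3-component IDs as described on this page. The per-sample `timestep` shape and the `txt_ids` layout are assumptions, since this page does not specify them.

```python
def flux_forward_shapes(batch_size: int, channels: int, latent_h: int, latent_w: int,
                        seq_len: int) -> dict:
    """Expected shapes of the FLUX forward inputs (hypothetical helper)."""
    if latent_h % 2 or latent_w % 2:
        raise ValueError("latent height and width must be even for 2x2 packing")
    num_patches = (latent_h // 2) * (latent_w // 2)
    return {
        "hidden_states": (batch_size, num_patches, channels * 4),  # packed latents
        "encoder_hidden_states": (batch_size, seq_len, 4096),      # T5 embeddings
        "pooled_projections": (batch_size, 768),                   # CLIP pooled embeddings
        "timestep": (batch_size,),                                 # ASSUMED: one value in [0, 1] per sample
        "img_ids": (batch_size, num_patches, 3),                   # one ID triple per patch
        "txt_ids": (batch_size, seq_len, 3),                       # ASSUMED layout, not documented here
    }


shapes = flux_forward_shapes(batch_size=2, channels=16, latent_h=64, latent_w=64, seq_len=512)
print(shapes["hidden_states"])  # (2, 1024, 64)
```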

Initialization

Initialize FluxAdapter.

Parameters:
  • guidance_scale – Guidance scale for classifier-free guidance

  • use_guidance_embeds – Whether to use guidance embeddings
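This page does not show how `guidance_scale` reaches the model. A common pattern for guidance-distilled FLUX.1 checkpoints is to pass a per-sample tensor filled with the scale when guidance embeddings are enabled. The FLUX pipelines in Hugging Face diffusers do this. The hypothetical helper below sketches that pattern. Treat it as an assumption about the adapter's behavior, not a description of it.

```python
from typing import Optional

import torch


def build_guidance(batch_size: int, guidance_scale: float = 3.5,
                   use_guidance_embeds: bool = True,
                   device: Optional[torch.device] = None) -> Optional[torch.Tensor]:
    """Per-sample guidance values for a guidance-embedding FLUX model (hypothetical helper)."""
    if not use_guidance_embeds:
        return None  # model has no guidance embedding: pass nothing
    return torch.full((batch_size,), guidance_scale, dtype=torch.float32, device=device)


print(build_guidance(4))  # tensor([3.5000, 3.5000, 3.5000, 3.5000])
```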

_pack_latents(latents: torch.Tensor) → torch.Tensor

Pack latents from [B, C, H, W] to Flux format [B, (H//2)*(W//2), C*4].

Flux operates on 2x2 latent patches, so each 2x2 spatial block is folded into a single token whose feature dimension is C*4. H and W must therefore be even.
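A minimal sketch of 2x2 packing in PyTorch follows. It mirrors the reshape/permute approach used by the FLUX pipelines in Hugging Face diffusers. The adapter's own `_pack_latents` may differ in detail, and `pack_latents` here is a standalone illustration.

```python
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    """[B, C, H, W] -> [B, (H//2)*(W//2), C*4] by folding each 2x2 patch into one token."""
    b, c, h, w = latents.shape
    if h % 2 or w % 2:
        raise ValueError(f"H and W must be even, got {h}x{w}")
    x = latents.reshape(b, c, h // 2, 2, w // 2, 2)  # split H and W into (patch, offset)
    x = x.permute(0, 2, 4, 1, 3, 5)                  # [B, H/2, W/2, C, 2, 2]
    return x.reshape(b, (h // 2) * (w // 2), c * 4)  # flatten patch grid and features


packed = pack_latents(torch.randn(2, 16, 64, 64))
print(tuple(packed.shape))  # (2, 1024, 64)
```

Each token's features are ordered channel first, then the row and column offset inside the patch. A single-channel 2x2 input therefore becomes one token [0, 1, 2, 3] in row-major order.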

static _unpack_latents(
latents: torch.Tensor,
height: int,
width: int,
vae_scale_factor: int = 8,
) → torch.Tensor

Unpack latents from Flux format back to [B, C, H, W].

Parameters:
  • latents – Packed latents of shape [B, num_patches, channels]

  • height – Original image height in pixels

  • width – Original image width in pixels

  • vae_scale_factor – VAE compression factor (default: 8)
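The inverse operation can be sketched as follows. It assumes the diffusers convention of converting pixel sizes to latent sizes and rounding down to an even number, so that 2x2 patches tile exactly. `unpack_latents` is a standalone illustration, not the adapter's method.

```python
import torch


def unpack_latents(latents: torch.Tensor, height: int, width: int,
                   vae_scale_factor: int = 8) -> torch.Tensor:
    """[B, num_patches, C*4] -> [B, C, H_latent, W_latent], given the pixel height/width."""
    b, num_patches, channels = latents.shape
    # Pixel size -> latent size, rounded down to a multiple of 2 for 2x2 patches.
    h = 2 * (height // (vae_scale_factor * 2))
    w = 2 * (width // (vae_scale_factor * 2))
    if num_patches != (h // 2) * (w // 2):
        raise ValueError(f"{num_patches} patches do not match a {h}x{w} latent grid")
    x = latents.reshape(b, h // 2, w // 2, channels // 4, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)  # [B, C, H/2, 2, W/2, 2]
    return x.reshape(b, channels // 4, h, w)


# Round trip for a 512x512 image (64x64 latents with vae_scale_factor=8).
latents = torch.randn(2, 16, 64, 64)
packed = latents.reshape(2, 16, 32, 2, 32, 2).permute(0, 2, 4, 1, 3, 5).reshape(2, 1024, 64)
restored = unpack_latents(packed, height=512, width=512)
print(torch.equal(restored, latents))  # True
```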

_prepare_latent_image_ids(
batch_size: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
) → torch.Tensor

Prepare positional IDs for image latents.

Returns tensor of shape [B, (H//2)*(W//2), 3] containing (batch_idx, y, x).
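The sketch below builds position IDs of that shape. It assumes `height` and `width` are latent (not pixel) sizes, and it takes the (batch_idx, y, x) description literally. Note that the reference FLUX pipelines in diffusers use a constant 0 in the first channel, so check the adapter source before relying on that channel's contents. Token order matches row-major 2x2 packing.

```python
import torch


def prepare_latent_image_ids(batch_size: int, height: int, width: int,
                             device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """[B, (H//2)*(W//2), 3] IDs of (batch_idx, y, x); height/width ASSUMED to be latent sizes."""
    h, w = height // 2, width // 2
    ids = torch.zeros(batch_size, h, w, 3, device=device, dtype=dtype)
    ids[..., 0] = torch.arange(batch_size, device=device, dtype=dtype)[:, None, None]  # batch index
    ids[..., 1] = torch.arange(h, device=device, dtype=dtype)[None, :, None]           # patch row
    ids[..., 2] = torch.arange(w, device=device, dtype=dtype)[None, None, :]           # patch column
    return ids.reshape(batch_size, h * w, 3)


ids = prepare_latent_image_ids(2, 4, 6, torch.device("cpu"), torch.float32)
print(tuple(ids.shape))  # (2, 6, 3)
```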

prepare_inputs(
context: nemo_automodel.components.flow_matching.adapters.base.FlowMatchingContext,
) → Dict[str, Any]

Prepare inputs for Flux model from FlowMatchingContext.

Expects 4D image latents: [B, C, H, W]
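A rough sketch of what preparing inputs from a context could look like follows. `ToyContext` is a hypothetical stand-in for FlowMatchingContext, and its field names are assumptions. The assumed 0..1000 timestep scale, the `txt_ids` layout, and the `guidance` key (named after the `guidance` argument of diffusers' FluxTransformer2DModel) are also assumptions.

```python
from dataclasses import dataclass
from typing import Any, Dict

import torch


@dataclass
class ToyContext:
    """Hypothetical stand-in for FlowMatchingContext; real field names may differ."""
    noisy_latents: torch.Tensor         # [B, C, H, W]
    timesteps: torch.Tensor             # [B], ASSUMED to be on a 0..1000 scale
    text_embeddings: torch.Tensor       # [B, seq_len, 4096]
    pooled_prompt_embeds: torch.Tensor  # [B, 768]


def prepare_inputs_sketch(ctx: ToyContext, guidance_scale: float = 3.5,
                          use_guidance_embeds: bool = True) -> Dict[str, Any]:
    x = ctx.noisy_latents
    if x.dim() != 4:
        raise ValueError(f"expected 4D latents [B, C, H, W], got {tuple(x.shape)}")
    b, c, h, w = x.shape
    # 2x2 patch packing: [B, C, H, W] -> [B, (H//2)*(W//2), C*4]
    packed = (x.reshape(b, c, h // 2, 2, w // 2, 2)
               .permute(0, 2, 4, 1, 3, 5)
               .reshape(b, (h // 2) * (w // 2), c * 4))
    # Position IDs (batch_idx, y, x) per patch, following this page's description
    ids = torch.zeros(b, h // 2, w // 2, 3, dtype=x.dtype, device=x.device)
    ids[..., 0] = torch.arange(b, device=x.device, dtype=x.dtype)[:, None, None]
    ids[..., 1] = torch.arange(h // 2, device=x.device, dtype=x.dtype)[None, :, None]
    ids[..., 2] = torch.arange(w // 2, device=x.device, dtype=x.dtype)[None, None, :]
    seq_len = ctx.text_embeddings.shape[1]
    inputs: Dict[str, Any] = {
        "hidden_states": packed,
        "encoder_hidden_states": ctx.text_embeddings,
        "pooled_projections": ctx.pooled_prompt_embeds,
        "timestep": ctx.timesteps.to(x.dtype) / 1000.0,  # normalize to [0, 1] (assumed scale)
        "img_ids": ids.reshape(b, -1, 3),
        "txt_ids": torch.zeros(b, seq_len, 3, dtype=x.dtype, device=x.device),  # ASSUMED layout
    }
    if use_guidance_embeds:
        # ASSUMED key name for the guidance embedding input
        inputs["guidance"] = torch.full((b,), guidance_scale, dtype=torch.float32, device=x.device)
    return inputs


ctx = ToyContext(torch.randn(2, 16, 8, 8), torch.tensor([500.0, 250.0]),
                 torch.randn(2, 7, 4096), torch.randn(2, 768))
inputs = prepare_inputs_sketch(ctx)
print(tuple(inputs["hidden_states"].shape))  # (2, 16, 64)
```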

forward(
model: torch.nn.Module,
inputs: Dict[str, Any],
) → torch.Tensor

Execute forward pass for Flux model.

Returns unpacked prediction in [B, C, H, W] format.
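The forward step can be sketched as calling the model on the packed inputs and unpacking the prediction. `EchoTransformer` is a hypothetical stand-in for a real FLUX transformer, and passing the latent grid size as explicit arguments is an assumption. A real diffusers model returns an output object by default, or a tuple when called with `return_dict=False`, which the sketch handles.

```python
import torch
from torch import nn


class EchoTransformer(nn.Module):
    """Hypothetical stand-in for a FLUX transformer: maps packed tokens to same-shaped output."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor, **unused_kwargs) -> torch.Tensor:
        return self.proj(hidden_states)


def forward_sketch(model, inputs: dict, latent_h: int, latent_w: int) -> torch.Tensor:
    """Run the model on packed inputs and unpack the prediction to [B, C, H, W]."""
    out = model(**inputs)
    if isinstance(out, tuple):  # e.g. diffusers models called with return_dict=False
        out = out[0]
    b, _, ch = out.shape
    c = ch // 4
    x = out.reshape(b, latent_h // 2, latent_w // 2, c, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)  # [B, C, H/2, 2, W/2, 2]
    return x.reshape(b, c, latent_h, latent_w)


model = EchoTransformer(dim=64)
pred = forward_sketch(model, {"hidden_states": torch.randn(2, 16, 64)}, latent_h=8, latent_w=8)
print(tuple(pred.shape))  # (2, 16, 8, 8)
```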