nemo_automodel.components.flow_matching.adapters.flux


Flux model adapter for the FlowMatching pipeline.

This adapter supports FLUX.1-style models with:

  • T5 text embeddings (text_embeddings)
  • CLIP pooled embeddings (pooled_prompt_embeds)
  • 2D image latents (treated as 1-frame video: [B, C, 1, H, W])

Module Contents

Classes

Name          Description
FluxAdapter   Model adapter for FLUX.1 image generation models.

API

class nemo_automodel.components.flow_matching.adapters.flux.FluxAdapter(
guidance_scale: float = 3.5,
use_guidance_embeds: bool = True
)

Bases: ModelAdapter

Model adapter for FLUX.1 image generation models.

Supports the batch format produced by the multiresolution dataloader:

  • image_latents: [B, C, H, W] for images
  • text_embeddings: T5 embeddings [B, seq_len, 4096]
  • pooled_prompt_embeds: CLIP pooled [B, 768]
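As a rough illustration of the batch format above, the sketch below builds a dummy batch with the documented shapes and validates it. The helper `check_flux_batch` is hypothetical and not part of the adapter's API; the sizes used (16 latent channels, a 64x64 latent grid, 8 text tokens) are arbitrary assumptions chosen only to match the documented ranks.

```python
import torch


def check_flux_batch(batch: dict) -> None:
    """Hypothetical validator for the documented batch shapes (not part of FluxAdapter)."""
    latents = batch["image_latents"]
    text = batch["text_embeddings"]
    pooled = batch["pooled_prompt_embeds"]
    if latents.dim() != 4:
        raise ValueError(f"image_latents must be [B, C, H, W], got {tuple(latents.shape)}")
    batch_size = latents.shape[0]
    # T5 embeddings: [B, seq_len, 4096]
    if text.dim() != 3 or text.shape[0] != batch_size or text.shape[-1] != 4096:
        raise ValueError(f"text_embeddings must be [B, seq_len, 4096], got {tuple(text.shape)}")
    # CLIP pooled embeddings: [B, 768]
    if tuple(pooled.shape) != (batch_size, 768):
        raise ValueError(f"pooled_prompt_embeds must be [B, 768], got {tuple(pooled.shape)}")


# Dummy batch with assumed sizes; real batches come from the dataloader.
batch = {
    "image_latents": torch.randn(2, 16, 64, 64),
    "text_embeddings": torch.randn(2, 8, 4096),
    "pooled_prompt_embeds": torch.randn(2, 768),
}
check_flux_batch(batch)  # passes without raising
```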

FLUX model forward interface:

  • hidden_states: Packed latents
  • encoder_hidden_states: T5 text embeddings
  • pooled_projections: CLIP pooled embeddings
  • timestep: Normalized timesteps [0, 1]
  • img_ids / txt_ids: Positional IDs for image and text tokens

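The constructor parameters `guidance_scale` and `use_guidance_embeds` are not described further on this page. In diffusers' FLUX pipeline, guidance-distilled checkpoints receive the guidance scale as a per-sample float tensor; the sketch below assumes the adapter does something similar. The function `make_guidance` is hypothetical, and the adapter's actual handling may differ.

```python
from typing import Optional

import torch


def make_guidance(
    batch_size: int,
    guidance_scale: float = 3.5,
    use_guidance_embeds: bool = True,
    device: Optional[torch.device] = None,
) -> Optional[torch.Tensor]:
    """Hypothetical: build a per-sample guidance tensor, or None when guidance embeds are off."""
    if not use_guidance_embeds:
        # Assumption: models without guidance embeddings receive no guidance input.
        return None
    # One float32 value per sample, matching the batch dimension of the latents.
    return torch.full((batch_size,), guidance_scale, dtype=torch.float32, device=device)


guidance = make_guidance(batch_size=4)
print(guidance)  # tensor([3.5000, 3.5000, 3.5000, 3.5000])
```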
nemo_automodel.components.flow_matching.adapters.flux.FluxAdapter._pack_latents(
latents: torch.Tensor
) -> torch.Tensor

Pack latents from [B, C, H, W] to the Flux format [B, (H//2)*(W//2), C*4].

Flux uses a 2x2 patch embedding, so each 2x2 patch of the latent grid is flattened into one token with C*4 features.
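The packing step can be sketched as follows. This is a minimal illustration that assumes patches are laid out row-major over the grid and flattened in (channel, row, column) order, as in diffusers' FLUX pipeline; `pack_latents` is a hypothetical standalone stand-in for the private `_pack_latents` method, and the adapter's exact ordering is not documented here.

```python
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: pack [B, C, H, W] into [B, (H//2)*(W//2), C*4] (H, W even)."""
    b, c, h, w = latents.shape
    if h % 2 or w % 2:
        raise ValueError(f"H and W must be even for 2x2 patching, got {h}x{w}")
    # Split each spatial axis into (num_patches, 2): [B, C, H//2, 2, W//2, 2].
    x = latents.reshape(b, c, h // 2, 2, w // 2, 2)
    # Move the patch-grid axes forward: [B, H//2, W//2, C, 2, 2].
    x = x.permute(0, 2, 4, 1, 3, 5)
    # Flatten the grid into a token axis and each patch into a C*4 feature vector.
    return x.reshape(b, (h // 2) * (w // 2), c * 4)


latents = torch.randn(2, 16, 8, 10)
packed = pack_latents(latents)
print(packed.shape)  # torch.Size([2, 20, 64])
```

Each token therefore carries the full channel stack of one 2x2 patch, which is what lets the transformer treat the image as a 1D token sequence.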

nemo_automodel.components.flow_matching.adapters.flux.FluxAdapter._prepare_latent_image_ids(
batch_size: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype
) -> torch.Tensor

Prepare positional IDs for image latents.

Returns a tensor of shape [B, (H//2)*(W//2), 3] containing (batch_idx, y, x).
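A sketch of these position IDs, following the docstring's (batch_idx, y, x) layout literally. It assumes `height` and `width` are the latent dimensions H and W (not pixels) and that tokens are ordered row-major like the 2x2 packing; `prepare_latent_image_ids` is a hypothetical standalone version of the private method, and the real implementation may fill the first channel differently.

```python
import torch


def prepare_latent_image_ids(
    batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    """Hypothetical sketch: [B, (H//2)*(W//2), 3] IDs holding (batch_idx, y, x) per 2x2 patch."""
    h, w = height // 2, width // 2
    # Row and column index of every patch, laid out row-major over the patch grid.
    ys = torch.arange(h, device=device).view(h, 1).expand(h, w)
    xs = torch.arange(w, device=device).view(1, w).expand(h, w)
    grid = torch.stack([ys, xs], dim=-1).reshape(1, h * w, 2).expand(batch_size, h * w, 2)
    # Batch index broadcast over all tokens of each sample.
    batch_idx = torch.arange(batch_size, device=device).view(batch_size, 1, 1)
    batch_idx = batch_idx.expand(batch_size, h * w, 1)
    return torch.cat([batch_idx, grid], dim=-1).to(dtype)


ids = prepare_latent_image_ids(2, 4, 6, torch.device("cpu"), torch.float32)
print(ids.shape)  # torch.Size([2, 6, 3])
print(ids[1, 4])  # tensor([1., 1., 1.])
```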

nemo_automodel.components.flow_matching.adapters.flux.FluxAdapter._unpack_latents(
latents: torch.Tensor,
height: int,
width: int,
vae_scale_factor: int = 8
) -> torch.Tensor
staticmethod

Unpack latents from Flux format back to [B, C, H, W].

Parameters:

  • latents (torch.Tensor): Packed latents of shape [B, num_patches, channels]
  • height (int): Original image height in pixels
  • width (int): Original image width in pixels
  • vae_scale_factor (int, default 8): VAE spatial compression factor
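The inverse operation can be sketched as follows. It assumes the latent grid is recovered from the pixel dimensions as `2 * (pixels // (vae_scale_factor * 2))`, which rounds down to an even number of latent rows and columns, and that patches use the same (channel, row, column) order as diffusers' FLUX pipeline. `unpack_latents` is a hypothetical standalone version of `_unpack_latents`.

```python
import torch


def unpack_latents(
    latents: torch.Tensor, height: int, width: int, vae_scale_factor: int = 8
) -> torch.Tensor:
    """Hypothetical sketch: unpack [B, num_patches, C*4] back to [B, C, H_latent, W_latent]."""
    b, num_patches, channels = latents.shape
    # Convert pixel size to an even latent size (each token covers a 2x2 latent patch).
    h = 2 * (height // (vae_scale_factor * 2))
    w = 2 * (width // (vae_scale_factor * 2))
    if num_patches != (h // 2) * (w // 2):
        raise ValueError(f"expected {(h // 2) * (w // 2)} patches, got {num_patches}")
    # Restore the patch grid and per-patch layout: [B, H//2, W//2, C, 2, 2].
    x = latents.reshape(b, h // 2, w // 2, channels // 4, 2, 2)
    # Interleave patch rows/cols back into spatial axes: [B, C, H//2, 2, W//2, 2].
    x = x.permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, channels // 4, h, w)


packed = torch.randn(2, 20, 64)  # 4x5 patch grid, 16 latent channels
out = unpack_latents(packed, height=64, width=80)
print(out.shape)  # torch.Size([2, 16, 8, 10])
```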

nemo_automodel.components.flow_matching.adapters.flux.FluxAdapter.forward(
model: torch.nn.Module,
inputs: typing.Dict[str, typing.Any]
) -> torch.Tensor

Execute the forward pass for the Flux model.

Returns the unpacked prediction in [B, C, H, W] format.

nemo_automodel.components.flow_matching.adapters.flux.FluxAdapter.prepare_inputs(
context: nemo_automodel.components.flow_matching.adapters.base.FlowMatchingContext
) -> typing.Dict[str, typing.Any]

Prepare inputs for the Flux model from a FlowMatchingContext.

Expects 4D image latents: [B, C, H, W].
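The module overview describes image latents as a 1-frame video ([B, C, 1, H, W]), while this method expects 4D latents. A caller holding 5D single-frame latents could drop the frame axis first. The helper `to_image_latents` below is a hypothetical sketch of that step, not part of the adapter.

```python
import torch


def to_image_latents(latents: torch.Tensor) -> torch.Tensor:
    """Hypothetical: return [B, C, H, W] latents, squeezing a singleton frame axis if present."""
    if latents.dim() == 5:
        if latents.shape[2] != 1:
            raise ValueError(f"expected a single frame, got {latents.shape[2]} frames")
        # [B, C, 1, H, W] -> [B, C, H, W]
        return latents.squeeze(2)
    if latents.dim() == 4:
        # Already in the expected image layout.
        return latents
    raise ValueError(f"expected 4D or 5D latents, got {latents.dim()}D")


video_style = torch.randn(2, 16, 1, 32, 32)
print(to_image_latents(video_style).shape)  # torch.Size([2, 16, 32, 32])
```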