nemo_automodel.components.flow_matching.adapters.flux2


Flux2 model adapter for the FlowMatchingPipeline.

Supports FLUX.2-dev with:

  • Mistral3 text embeddings (text_embeddings, shape [B, seq, 15360])
  • No CLIP pooled projections
  • Patchified + BN-normalized image latents ([B, 128, H/16, W/16]) stored by the preprocessor; noise is added in this space
  • 4D positional IDs (T, H, W, L) instead of Flux1’s 3D (batch, y, x)
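As a worked example of the shapes listed above, the sketch below derives every tensor shape from a pixel resolution. It assumes height and width are multiples of 16 (not stated above) and takes an arbitrary text length; `flux2_latent_shapes` is a hypothetical helper, not part of this module.

```python
def flux2_latent_shapes(height: int, width: int, batch_size: int, seq_len: int) -> dict:
    """Hypothetical helper: tensor shapes implied by the bullets above for one batch."""
    # Assumption: the patchified latent grid is exactly 1/16 of the pixel size per side.
    if height % 16 or width % 16:
        raise ValueError("height and width must be multiples of 16")
    h_p, w_p = height // 16, width // 16
    n_img = h_p * w_p  # number of image tokens after packing
    return {
        "image_latents": (batch_size, 128, h_p, w_p),
        "packed_hidden_states": (batch_size, n_img, 128),
        "text_embeddings": (batch_size, seq_len, 15360),
        "img_ids": (batch_size, n_img, 4),
        "txt_ids": (batch_size, seq_len, 4),
    }
```

For a 1024x1024 image the latent grid is 64x64, so the transformer sees 4096 image tokens alongside the text tokens.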

Module Contents

Classes

Name          Description
Flux2Adapter  Model adapter for FLUX.2-dev image generation models.

API

class nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter(
guidance_scale: float = 3.5,
use_guidance_embeds: bool = True
)

Bases: ModelAdapter

Model adapter for FLUX.2-dev image generation models.

The preprocessor stores latents already patchified (2x2 spatial patch → channel) and BN-normalized using vae.bn running statistics. The FlowMatchingPipeline adds noise in that space, so this adapter only needs to flatten the spatial dims (pack) before the transformer call and reshape back (unpack) after.
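The preprocessing described above can be sketched as follows. This is a minimal illustration, not the preprocessor's code: the channel ordering inside each 2x2 patch, the absence of an affine BN transform, and the `eps` value are assumptions, and `patchify_2x2` / `bn_normalize` are hypothetical names. The normalization uses PyTorch's `torch.nn.functional.batch_norm` in eval mode with the running statistics passed in.

```python
import torch
import torch.nn.functional as F


def patchify_2x2(latents: torch.Tensor) -> torch.Tensor:
    """Fold each 2x2 spatial patch into channels: [B, C, H, W] -> [B, 4C, H/2, W/2]."""
    b, c, h, w = latents.shape
    if h % 2 or w % 2:
        raise ValueError("H and W must be even to patchify 2x2")
    x = latents.reshape(b, c, h // 2, 2, w // 2, 2)  # split H and W into (grid, patch)
    x = x.permute(0, 1, 3, 5, 2, 4)                  # [B, C, ph, pw, H/2, W/2]
    return x.reshape(b, c * 4, h // 2, w // 2)       # channel index = c*4 + ph*2 + pw


def bn_normalize(x: torch.Tensor, running_mean: torch.Tensor,
                 running_var: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Per-channel normalization with frozen running stats (BatchNorm eval mode, no affine)."""
    return F.batch_norm(x, running_mean, running_var, training=False, eps=eps)


# Usage, assuming a 32-channel VAE latent for a 1024x1024 image:
vae_latents = torch.randn(1, 32, 128, 128)
patched = patchify_2x2(vae_latents)  # [1, 128, 64, 64]
normalized = bn_normalize(patched, torch.zeros(128), torch.ones(128))
```

Because noise is added after this step, the adapter never has to undo the patchify or the normalization around the transformer call.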

Batch format expected from the multiresolution dataloader:

  • image_latents: patchified + BN-normalized latents [B, 128, H_p, W_p]
  • text_embeddings: Mistral3 stacked embeddings [B, seq_len, 15360]

FLUX.2 transformer forward interface:

  • hidden_states: packed latents [B, H_p*W_p, 128]
  • encoder_hidden_states: text embeddings [B, seq_len, 15360]
  • timestep: normalized [0, 1]
  • img_ids: 4D spatial coords [B, H_p*W_p, 4]
  • txt_ids: 4D text coords [B, seq_len, 4]
  • guidance: guidance scale [B]

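The forward interface above is a shape contract, and a small checker makes it explicit. `check_flux2_forward_kwargs` is a hypothetical helper that only mirrors the listed shapes; treating `guidance` as optional is an assumption, and the timestep check only enforces the documented [0, 1] range, not a particular shape.

```python
import torch


def check_flux2_forward_kwargs(kw: dict) -> None:
    """Raise ValueError if kwargs do not match the documented FLUX.2 forward shapes."""
    hs, ehs = kw["hidden_states"], kw["encoder_hidden_states"]
    if hs.dim() != 3 or hs.shape[-1] != 128:
        raise ValueError(f"hidden_states must be [B, H_p*W_p, 128], got {tuple(hs.shape)}")
    b, n_img, _ = hs.shape
    if ehs.dim() != 3 or ehs.shape[0] != b or ehs.shape[-1] != 15360:
        raise ValueError(f"encoder_hidden_states must be [B, seq_len, 15360], got {tuple(ehs.shape)}")
    seq_len = ehs.shape[1]
    if tuple(kw["img_ids"].shape) != (b, n_img, 4):
        raise ValueError("img_ids must be [B, H_p*W_p, 4]")
    if tuple(kw["txt_ids"].shape) != (b, seq_len, 4):
        raise ValueError("txt_ids must be [B, seq_len, 4]")
    t = kw["timestep"]
    if t.numel() > 0 and (t.min() < 0 or t.max() > 1):
        raise ValueError("timestep must be normalized to [0, 1]")
    g = kw.get("guidance")  # assumption: may be absent/None when guidance embeds are disabled
    if g is not None and tuple(g.shape) != (b,):
        raise ValueError("guidance must be [B]")
```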
nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter._pack_latents(
latents: torch.Tensor
) -> torch.Tensor

Flatten spatial dims: [B, C, H, W] -> [B, H*W, C].
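A minimal sketch of this flattening, assuming row-major token order (token k sits at row k // W, column k % W), which is what a plain reshape produces. `pack_latents` is a hypothetical stand-in for the private method.

```python
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    """[B, C, H, W] -> [B, H*W, C] with row-major token order."""
    b, c, h, w = latents.shape
    # Merge H and W into one sequence axis, then move channels last.
    return latents.reshape(b, c, h * w).permute(0, 2, 1)
```

The token order matters: the image positional IDs must enumerate (h, w) in the same row-major order so that each token gets its own coordinates.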

nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter._prepare_latent_ids(
h_p: int,
w_p: int,
batch_size: int,
device: torch.device
) -> torch.Tensor

Build 4D image positional IDs with coords (T=0, h, w, L=0).

Returns [B, H_p*W_p, 4] long tensor.
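One way to build these IDs is sketched below, assuming the same row-major token order as the packing step. `prepare_latent_ids` is a hypothetical stand-in; only the (T=0, h, w, L=0) layout and the output shape come from the docstring.

```python
import torch


def prepare_latent_ids(h_p: int, w_p: int, batch_size: int, device=None) -> torch.Tensor:
    """4D image IDs (T=0, h, w, L=0) as a [B, H_p*W_p, 4] long tensor."""
    rows = torch.arange(h_p, device=device)  # int64 by default
    cols = torch.arange(w_p, device=device)
    hh, ww = torch.meshgrid(rows, cols, indexing="ij")  # each [H_p, W_p]
    zeros = torch.zeros_like(hh)
    ids = torch.stack([zeros, hh, ww, zeros], dim=-1)    # [H_p, W_p, 4]
    # Flatten row-major to match the packed token order, then copy per batch element.
    return ids.reshape(1, h_p * w_p, 4).repeat(batch_size, 1, 1)
```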

nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter._prepare_text_ids(
seq_len: int,
batch_size: int,
device: torch.device
) -> torch.Tensor

Build 4D text positional IDs with coords (T=0, H=0, W=0, position).

Returns [B, seq_len, 4] long tensor.
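A matching sketch for the text IDs: only the last coordinate varies, carrying the token position. `prepare_text_ids` is a hypothetical stand-in for the private method.

```python
import torch


def prepare_text_ids(seq_len: int, batch_size: int, device=None) -> torch.Tensor:
    """4D text IDs (T=0, H=0, W=0, position) as a [B, seq_len, 4] long tensor."""
    ids = torch.zeros(batch_size, seq_len, 4, dtype=torch.long, device=device)
    # Broadcast positions 0..seq_len-1 across the batch into the last coordinate.
    ids[..., 3] = torch.arange(seq_len, device=device)
    return ids
```

Keeping T, H and W at zero leaves text tokens distinguishable from image tokens only through the position axis, whereas image tokens vary in H and W.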

nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter._unpack_latents(
latents: torch.Tensor,
h: int,
w: int
) -> torch.Tensor

Restore spatial dims: [B, seq, C] -> [B, C, H, W].
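The inverse of the row-major packing can be sketched as below. `unpack_latents` is a hypothetical stand-in; the length check is an added safeguard, not documented behavior.

```python
import torch


def unpack_latents(latents: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """[B, H*W, C] -> [B, C, H, W], inverting a row-major pack."""
    b, seq, c = latents.shape
    if seq != h * w:
        raise ValueError(f"sequence length {seq} != h * w = {h * w}")
    # Move channels back to dim 1, then split the sequence axis into (H, W).
    return latents.permute(0, 2, 1).reshape(b, c, h, w)
```

Packing and then unpacking with the original H and W returns the input unchanged, which is what lets `forward` hand back a prediction in the stored latent layout.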

nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter.forward(
model: torch.nn.Module,
inputs: typing.Dict[str, typing.Any]
) -> torch.Tensor

Execute FLUX.2 transformer forward pass.

Returns the prediction in [B, 128, H_p, W_p], the same space as the stored patchified latents, so the MSE loss against the target (noise - latents) can be computed directly without unpatchifying.
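Why (noise - latents) is the target: under the common rectified-flow interpolation x_t = (1 - t) * x_0 + t * noise, the velocity dx_t/dt is exactly noise - x_0. That interpolation is an assumption here (the pipeline's schedule is not documented on this page), and `interpolate` / `flow_matching_loss` are hypothetical names for illustration.

```python
import torch
import torch.nn.functional as F


def interpolate(latents: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """x_t = (1 - t) * x_0 + t * noise, with t of shape [B] broadcast over [B, C, H, W]."""
    t = t.view(-1, 1, 1, 1)
    return (1.0 - t) * latents + t * noise


def flow_matching_loss(pred: torch.Tensor, latents: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """MSE between the predicted velocity and the target noise - latents (computed in fp32)."""
    target = noise - latents
    return F.mse_loss(pred.float(), target.float())
```

Since `pred`, `latents` and `noise` all live in [B, 128, H_p, W_p], the subtraction and the MSE are elementwise in one consistent space.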

nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter.prepare_inputs(
context: nemo_automodel.components.flow_matching.adapters.base.FlowMatchingContext
) -> typing.Dict[str, typing.Any]

Prepare inputs for FLUX.2 transformer from FlowMatchingContext.

Expects 4D patchified, BN-normalized latents [B, 128, H_p, W_p].
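The steps above can be combined into one sketch. The fields of FlowMatchingContext are not documented on this page, so this hypothetical `prepare_flux2_inputs` takes plain tensors instead; it assumes the timesteps are already normalized to [0, 1] and that `guidance` is set to None when `use_guidance_embeds` is False. Neither assumption is confirmed by the source.

```python
import torch


def prepare_flux2_inputs(
    noisy_latents: torch.Tensor,    # [B, 128, H_p, W_p], noised in the patchified space
    text_embeddings: torch.Tensor,  # [B, seq_len, 15360]
    timesteps: torch.Tensor,        # [B], assumed already normalized to [0, 1]
    guidance_scale: float = 3.5,
    use_guidance_embeds: bool = True,
) -> dict:
    if noisy_latents.dim() != 4:
        raise ValueError("expected 4D latents [B, 128, H_p, W_p]")
    b, _, h_p, w_p = noisy_latents.shape
    seq_len = text_embeddings.shape[1]
    device = noisy_latents.device

    # Pack: [B, C, H_p, W_p] -> [B, H_p*W_p, C] in row-major token order.
    hidden_states = noisy_latents.flatten(2).transpose(1, 2)

    # Image IDs (T=0, h, w, L=0), enumerated in the same row-major order.
    hh, ww = torch.meshgrid(torch.arange(h_p, device=device),
                            torch.arange(w_p, device=device), indexing="ij")
    zeros = torch.zeros_like(hh)
    img_ids = torch.stack([zeros, hh, ww, zeros], dim=-1).reshape(1, h_p * w_p, 4).repeat(b, 1, 1)

    # Text IDs (T=0, H=0, W=0, position).
    txt_ids = torch.zeros(b, seq_len, 4, dtype=torch.long, device=device)
    txt_ids[..., 3] = torch.arange(seq_len, device=device)

    # Assumption: no guidance tensor when guidance embeddings are disabled.
    guidance = torch.full((b,), guidance_scale, device=device) if use_guidance_embeds else None
    return {
        "hidden_states": hidden_states,
        "encoder_hidden_states": text_embeddings,
        "timestep": timesteps,
        "img_ids": img_ids,
        "txt_ids": txt_ids,
        "guidance": guidance,
    }
```

The caller still holds H_p and W_p from `noisy_latents.shape`, which is all that is needed to unpack the transformer output back to [B, 128, H_p, W_p].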