
# nemo_automodel.components.flow_matching.adapters.flux2

Flux2 model adapter for the FlowMatchingPipeline.

Supports FLUX.2-dev with:

* Mistral3 text embeddings (text\_embeddings, shape \[B, seq, 15360])
* No CLIP pooled projections
* Patchified + BN-normalized image latents (\[B, 128, H/16, W/16])
  stored by the preprocessor; noise is added in this space
* 4D positional IDs (T, H, W, L) instead of Flux1's 3D (batch, y, x)

## Module Contents

### Classes

| Name                                                                                   | Description                                           |
| -------------------------------------------------------------------------------------- | ----------------------------------------------------- |
| [`Flux2Adapter`](#nemo_automodel-components-flow_matching-adapters-flux2-Flux2Adapter) | Model adapter for FLUX.2-dev image generation models. |

### API

```python
class nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter(
    guidance_scale: float = 3.5,
    use_guidance_embeds: bool = True
)
```

**Bases:** [ModelAdapter](/nemo-automodel/nemo_automodel/components/flow_matching/adapters/base#nemo_automodel-components-flow_matching-adapters-base-ModelAdapter)

Model adapter for FLUX.2-dev image generation models.

The preprocessor stores latents already patchified (2x2 spatial patch → channel)
and BN-normalized using vae.bn running statistics.  The FlowMatchingPipeline adds
noise in that space, so this adapter only needs to flatten the spatial dims (pack)
before the transformer call and reshape back (unpack) after.
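The preprocessing described above can be sketched as follows. This is a minimal illustration, not the preprocessor's actual code: the helper names `patchify_2x2` and `bn_normalize`, the channel ordering inside each patch, the `eps` value, and the placeholder running statistics are all assumptions. The 32-channel input is inferred from the 128-channel output and the 2x2 patch size.

```python
import torch


def patchify_2x2(latents: torch.Tensor) -> torch.Tensor:
    """Fold each 2x2 spatial patch into channels: [B, C, H, W] -> [B, 4C, H/2, W/2]."""
    b, c, h, w = latents.shape
    if h % 2 or w % 2:
        raise ValueError("height and width must be even for 2x2 patchification")
    x = latents.reshape(b, c, h // 2, 2, w // 2, 2)
    # Move the two in-patch offsets next to the channel dim (assumed ordering).
    x = x.permute(0, 1, 3, 5, 2, 4)  # [B, C, 2, 2, H/2, W/2]
    return x.reshape(b, c * 4, h // 2, w // 2)


def bn_normalize(
    latents: torch.Tensor,
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float = 1e-5,  # assumed value, not taken from the source
) -> torch.Tensor:
    """Normalize per channel with batch-norm running statistics."""
    mean = running_mean.view(1, -1, 1, 1)
    std = torch.sqrt(running_var.view(1, -1, 1, 1) + eps)
    return (latents - mean) / std


# Hypothetical VAE output: 32 channels at 1/8 resolution.
vae_latents = torch.randn(2, 32, 16, 12)
patched = patchify_2x2(vae_latents)  # [2, 128, 8, 6]

# Placeholder statistics standing in for vae.bn.running_mean / running_var.
running_mean = torch.zeros(128)
running_var = torch.ones(128)
stored = bn_normalize(patched, running_mean, running_var)
```

Because noise is added to `stored` directly, the adapter never has to undo or redo either step during training.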

Batch format expected from multiresolution dataloader:

* image\_latents: patchified + BN-normalized latents \[B, 128, H\_p, W\_p]
* text\_embeddings: Mistral3 stacked embeddings \[B, seq\_len, 15360]

FLUX.2 transformer forward interface:

* hidden\_states: packed latents \[B, H\_p\*W\_p, 128]
* encoder\_hidden\_states: text embeddings \[B, seq\_len, 15360]
* timestep: normalized to the range \[0, 1]
* img\_ids: 4D spatial coords \[B, H\_p\*W\_p, 4]
* txt\_ids: 4D text coords \[B, seq\_len, 4]
* guidance: guidance scale \[B]
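The shape contract listed above can be checked with a small validator. This is a sketch only: `check_flux2_inputs` is a hypothetical helper, not part of the adapter, and the per-sample `[B]` shape for `timestep` is an assumption, since the list gives only its value range.

```python
import torch


def check_flux2_inputs(kwargs: dict) -> None:
    """Raise ValueError if the tensors disagree with the documented shapes."""
    hidden = kwargs["hidden_states"]        # [B, H_p*W_p, 128]
    text = kwargs["encoder_hidden_states"]  # [B, seq_len, 15360]
    b, n_img, _ = hidden.shape
    seq_len = text.shape[1]

    expected = {
        "img_ids": (b, n_img, 4),
        "txt_ids": (b, seq_len, 4),
        "timestep": (b,),  # assumed per-sample shape
        "guidance": (b,),
    }
    if text.shape[0] != b:
        raise ValueError("text and image batch sizes differ")
    for name, shape in expected.items():
        if tuple(kwargs[name].shape) != shape:
            raise ValueError(f"{name}: expected {shape}, got {tuple(kwargs[name].shape)}")
    t = kwargs["timestep"]
    if t.min() < 0 or t.max() > 1:
        raise ValueError("timestep must lie in [0, 1]")


# Dummy inputs with B=2, H_p=4, W_p=3, seq_len=8.
example = {
    "hidden_states": torch.randn(2, 12, 128),
    "encoder_hidden_states": torch.randn(2, 8, 15360),
    "timestep": torch.rand(2),
    "img_ids": torch.zeros(2, 12, 4, dtype=torch.long),
    "txt_ids": torch.zeros(2, 8, 4, dtype=torch.long),
    "guidance": torch.full((2,), 3.5),
}
check_flux2_inputs(example)  # passes silently
```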

```python
nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter._pack_latents(
    latents: torch.Tensor
) -> torch.Tensor
```

Flatten spatial dims: \[B, C, H, W] -> \[B, H\*W, C].
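A minimal sketch of this flattening, assuming row-major token order (the standalone function `pack_latents` is illustrative and may differ from the method's actual body):

```python
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    """[B, C, H, W] -> [B, H*W, C], one token per spatial position."""
    b, c, h, w = latents.shape
    # Merge H and W into one axis, then move channels last.
    return latents.reshape(b, c, h * w).permute(0, 2, 1)


latents = torch.randn(2, 128, 4, 3)
packed = pack_latents(latents)  # [2, 12, 128]
```

With this ordering, token `i` corresponds to row `i // W` and column `i % W` of the latent grid.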

```python
nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter._prepare_latent_ids(
    h_p: int,
    w_p: int,
    batch_size: int,
    device: torch.device
) -> torch.Tensor
```

Build 4D image positional IDs with coords (T=0, h, w, L=0).

Returns \[B, H\_p\*W\_p, 4] long tensor.
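One way to build such IDs is sketched below. The standalone function `prepare_latent_ids` is illustrative; the row-major ordering is an assumption chosen so that the IDs line up with tokens flattened as `[B, H*W, C]`.

```python
import torch


def prepare_latent_ids(
    h_p: int, w_p: int, batch_size: int, device: torch.device
) -> torch.Tensor:
    """Return [B, H_p*W_p, 4] long tensor with columns (T=0, h, w, L=0)."""
    rows, cols = torch.meshgrid(
        torch.arange(h_p, device=device),
        torch.arange(w_p, device=device),
        indexing="ij",
    )
    ids = torch.zeros(h_p * w_p, 4, dtype=torch.long, device=device)
    ids[:, 1] = rows.reshape(-1)  # h coordinate
    ids[:, 2] = cols.reshape(-1)  # w coordinate
    # Columns 0 (T) and 3 (L) stay zero for image tokens.
    return ids.unsqueeze(0).repeat(batch_size, 1, 1)


img_ids = prepare_latent_ids(2, 3, batch_size=2, device=torch.device("cpu"))
```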

```python
nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter._prepare_text_ids(
    seq_len: int,
    batch_size: int,
    device: torch.device
) -> torch.Tensor
```

Build 4D text positional IDs with coords (T=0, H=0, W=0, position).

Returns \[B, seq\_len, 4] long tensor.
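The text IDs follow the same layout, with only the last column populated. A minimal sketch (the standalone function `prepare_text_ids` is illustrative, not the method's actual body):

```python
import torch


def prepare_text_ids(
    seq_len: int, batch_size: int, device: torch.device
) -> torch.Tensor:
    """Return [B, seq_len, 4] long tensor with columns (T=0, H=0, W=0, position)."""
    ids = torch.zeros(seq_len, 4, dtype=torch.long, device=device)
    ids[:, 3] = torch.arange(seq_len, device=device)  # token position
    return ids.unsqueeze(0).repeat(batch_size, 1, 1)


txt_ids = prepare_text_ids(5, batch_size=2, device=torch.device("cpu"))
```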

```python
nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter._unpack_latents(
    latents: torch.Tensor,
    h: int,
    w: int
) -> torch.Tensor
```

Restore spatial dims: \[B, seq, C] -> \[B, C, H, W].
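A sketch of the inverse reshape, assuming the tokens were flattened in row-major order (the standalone function `unpack_latents` is illustrative):

```python
import torch


def unpack_latents(latents: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """[B, H*W, C] -> [B, C, H, W]."""
    b, seq, c = latents.shape
    if seq != h * w:
        raise ValueError(f"sequence length {seq} does not match {h} x {w}")
    # Move channels first, then split the token axis back into H and W.
    return latents.permute(0, 2, 1).reshape(b, c, h, w)


# Round trip: flatten a spatial tensor to tokens, then restore it.
spatial = torch.randn(2, 128, 4, 3)
tokens = spatial.flatten(2).transpose(1, 2)  # [2, 12, 128]
restored = unpack_latents(tokens, h=4, w=3)
```

The round trip is lossless, so `restored` equals `spatial` element for element.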

```python
nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter.forward(
    model: torch.nn.Module,
    inputs: typing.Dict[str, typing.Any]
) -> torch.Tensor
```

Execute FLUX.2 transformer forward pass.

Returns the prediction in \[B, 128, H\_p, W\_p], the same space as the stored
patchified latents, so an MSE loss against the (noise - latents) target needs no further reshaping.
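The loss computation this implies can be sketched as follows. The helper `flow_matching_mse` is hypothetical, and the linear interpolation used to form the noisy sample is an assumption made for illustration; the source states only the (noise - latents) target.

```python
import torch
import torch.nn.functional as F


def flow_matching_mse(
    prediction: torch.Tensor, latents: torch.Tensor, noise: torch.Tensor
) -> torch.Tensor:
    """MSE between the model output and the velocity target (noise - latents)."""
    target = noise - latents
    return F.mse_loss(prediction, target)


latents = torch.randn(2, 128, 4, 3)  # stored patchified + BN-normalized latents
noise = torch.randn_like(latents)

# Assumed noising scheme: linear interpolation between latents and noise.
t = torch.rand(2).view(-1, 1, 1, 1)
noisy_latents = (1 - t) * latents + t * noise

# A perfect model would output exactly noise - latents.
perfect_prediction = noise - latents
loss = flow_matching_mse(perfect_prediction, latents, noise)
```

Under that interpolation, the derivative of `noisy_latents` with respect to `t` is `noise - latents`, which is why it serves as the regression target.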

```python
nemo_automodel.components.flow_matching.adapters.flux2.Flux2Adapter.prepare_inputs(
    context: nemo_automodel.components.flow_matching.adapters.base.FlowMatchingContext
) -> typing.Dict[str, typing.Any]
```

Prepare inputs for FLUX.2 transformer from FlowMatchingContext.

Expects 4D patchified + BN-normalized latents \[B, 128, H\_p, W\_p].