nemo_automodel.components.flow_matching.adapters.flux#
Flux model adapter for FlowMatching Pipeline.
This adapter supports FLUX.1 style models with:
T5 text embeddings (text_embeddings)
CLIP pooled embeddings (pooled_prompt_embeds)
2D image latents (treated as 1-frame video: [B, C, 1, H, W])
Module Contents#
Classes#
FluxAdapter: Model adapter for FLUX.1 image generation models.
API#
- class nemo_automodel.components.flow_matching.adapters.flux.FluxAdapter(
- guidance_scale: float = 3.5,
- use_guidance_embeds: bool = True,
- )#
Bases:
nemo_automodel.components.flow_matching.adapters.base.ModelAdapter

Model adapter for FLUX.1 image generation models.
Supports batch format from multiresolution dataloader:
image_latents: [B, C, H, W] for images
text_embeddings: T5 embeddings [B, seq_len, 4096]
pooled_prompt_embeds: CLIP pooled [B, 768]
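The expected batch layout can be sketched with dummy tensors. The batch size, sequence length, and spatial dimensions below are illustrative; only the embedding widths (4096 for T5, 768 for CLIP) come from the documented format:

```python
import torch

# Illustrative batch in the multiresolution dataloader format.
# B=2, seq_len=512, and 64x64 latents are arbitrary example sizes.
batch = {
    "image_latents": torch.randn(2, 16, 64, 64),     # [B, C, H, W]
    "text_embeddings": torch.randn(2, 512, 4096),    # T5: [B, seq_len, 4096]
    "pooled_prompt_embeds": torch.randn(2, 768),     # CLIP pooled: [B, 768]
}
assert batch["text_embeddings"].shape[-1] == 4096
assert batch["pooled_prompt_embeds"].shape[-1] == 768
```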
FLUX model forward interface:
hidden_states: Packed latents
encoder_hidden_states: T5 text embeddings
pooled_projections: CLIP pooled embeddings
timestep: Normalized timesteps [0, 1]
img_ids / txt_ids: Positional embeddings
Initialization
Initialize FluxAdapter.
- Parameters:
guidance_scale – Guidance scale for classifier-free guidance
use_guidance_embeds – Whether to use guidance embeddings
- _pack_latents(latents: torch.Tensor) torch.Tensor#
Pack latents from [B, C, H, W] to Flux format [B, (H//2)*(W//2), C*4].
Flux uses a 2x2 patch embedding, so latents are reshaped accordingly.
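A minimal sketch of the 2x2 packing, assuming a plain view/permute implementation (the standalone function name here is illustrative; the adapter implements this as a method):

```python
import torch

def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # 2x2 patchify: [B, C, H, W] -> [B, (H//2)*(W//2), C*4]
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H//2, W//2, C, 2, 2]
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

latents = torch.randn(2, 16, 64, 64)
packed = pack_latents(latents)
print(packed.shape)  # torch.Size([2, 1024, 64])
```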
- static _unpack_latents(
- latents: torch.Tensor,
- height: int,
- width: int,
- vae_scale_factor: int = 8,
- ) torch.Tensor#
Unpack latents from Flux format back to [B, C, H, W].
- Parameters:
latents – Packed latents of shape [B, num_patches, channels]
height – Original image height in pixels
width – Original image width in pixels
vae_scale_factor – VAE compression factor (default: 8)
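The inverse operation can be sketched as follows, assuming the same view/permute layout as the packing step (the standalone function name is illustrative):

```python
import torch

def unpack_latents(latents, height, width, vae_scale_factor=8):
    # Inverse of the 2x2 packing: [B, num_patches, channels] -> [B, C, H', W'],
    # where H' = height // vae_scale_factor, W' = width // vae_scale_factor.
    b, _, channels = latents.shape
    h = height // vae_scale_factor
    w = width // vae_scale_factor
    x = latents.view(b, h // 2, w // 2, channels // 4, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)  # [B, C, H'//2, 2, W'//2, 2]
    return x.reshape(b, channels // 4, h, w)

packed = torch.randn(2, 1024, 64)  # e.g. a 512x512 image at vae_scale_factor=8
unpacked = unpack_latents(packed, height=512, width=512)
print(unpacked.shape)  # torch.Size([2, 16, 64, 64])
```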
- _prepare_latent_image_ids(
- batch_size: int,
- height: int,
- width: int,
- device: torch.device,
- dtype: torch.dtype,
- ) torch.Tensor#
Prepare positional IDs for image latents.
Returns tensor of shape [B, (H//2)*(W//2), 3] containing (batch_idx, y, x).
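A sketch of building these positional IDs, assuming `height` and `width` are latent-space dimensions and channel 0 holds the batch index (the function name is illustrative; the adapter implements this as a method):

```python
import torch

def prepare_latent_image_ids(batch_size, height, width, device, dtype):
    # One ID triple per 2x2 patch: channel 0 = batch index, 1 = y, 2 = x.
    h2, w2 = height // 2, width // 2
    ids = torch.zeros(batch_size, h2 * w2, 3, device=device, dtype=dtype)
    ids[..., 0] = torch.arange(batch_size, device=device, dtype=dtype)[:, None]
    grid_y, grid_x = torch.meshgrid(
        torch.arange(h2, device=device, dtype=dtype),
        torch.arange(w2, device=device, dtype=dtype),
        indexing="ij",
    )
    ids[..., 1] = grid_y.flatten()
    ids[..., 2] = grid_x.flatten()
    return ids

ids = prepare_latent_image_ids(2, 64, 64, torch.device("cpu"), torch.float32)
print(ids.shape)  # torch.Size([2, 1024, 3])
```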
- prepare_inputs( ) Dict[str, Any]#
Prepare inputs for Flux model from FlowMatchingContext.
Expects 4D image latents: [B, C, H, W]
- forward(
- model: torch.nn.Module,
- inputs: Dict[str, Any],
- ) torch.Tensor#
Execute forward pass for Flux model.
Returns unpacked prediction in [B, C, H, W] format.
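The shape contract of the forward pass can be sketched end to end, with an identity function standing in for the transformer (the helper names are illustrative, matching the pack/unpack sketches above under the same layout assumptions):

```python
import torch

def pack(latents):
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2).permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

def unpack(latents, height, width, vae_scale_factor=8):
    b, _, ch = latents.shape
    h, w = height // vae_scale_factor, width // vae_scale_factor
    x = latents.view(b, h // 2, w // 2, ch // 4, 2, 2).permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, ch // 4, h, w)

latents = torch.randn(2, 16, 64, 64)
model = lambda hidden_states, **_: hidden_states  # identity stand-in for the transformer
pred = unpack(model(hidden_states=pack(latents)), height=512, width=512)
assert pred.shape == latents.shape           # prediction comes back in [B, C, H, W]
assert torch.equal(pred, latents)            # pack/unpack are exact inverses
```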