bridge.diffusion.models.flux.flow_matching.flux_adapter#

Megatron-specific adapter for FLUX models using the automodel FlowMatching pipeline.

Module Contents#

Classes#

MegatronFluxAdapter

Adapter for FLUX models in Megatron training framework.

API#

class bridge.diffusion.models.flux.flow_matching.flux_adapter.MegatronFluxAdapter(guidance_scale: float = 3.5)#

Bases: megatron.bridge.diffusion.common.flow_matching.adapters.base.ModelAdapter

Adapter for FLUX models in Megatron training framework.

  • Handles sequence-first tensor layout [S, B, …] required by Megatron

  • Integrates with pipeline parallelism

  • Maps Megatron batch keys to expected format

  • Handles guidance embedding for FLUX-dev models
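The sequence-first requirement in the first bullet amounts to a transpose of the usual batch-first layout. A minimal sketch (the function name is an illustration, not part of the adapter's API):

```python
import torch

def to_sequence_first(x: torch.Tensor) -> torch.Tensor:
    """Convert a batch-first tensor [B, S, D] to Megatron's [S, B, D] layout."""
    return x.transpose(0, 1).contiguous()

batch_first = torch.randn(2, 16, 64)    # [B=2, S=16, D=64]
seq_first = to_sequence_first(batch_first)
print(seq_first.shape)                  # torch.Size([16, 2, 64])
```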

Initialization

Initialize MegatronFluxAdapter.

Parameters:

guidance_scale – Guidance scale for classifier-free guidance
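FLUX-dev conditions on a per-sample guidance value, so a scale like the default `3.5` is typically broadcast into a tensor of shape `[B]`. A hedged sketch of that expansion (an assumption about how the adapter consumes `guidance_scale`, not its actual code):

```python
import torch

guidance_scale = 3.5  # default from the constructor above
batch_size = 4

# Broadcast the scalar into one guidance value per sample; FLUX-dev
# embeds this tensor alongside the timestep (assumed behavior).
guidance = torch.full((batch_size,), guidance_scale, dtype=torch.float32)
print(guidance)  # tensor([3.5000, 3.5000, 3.5000, 3.5000])
```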

_pack_latents(latents: torch.Tensor) torch.Tensor#

Pack latents from [B, C, H, W] to Flux format [B, (H//2)*(W//2), C*4].

Flux uses a 2x2 patch embedding, so latents are reshaped accordingly.
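The 2x2 patch packing can be expressed as a view/permute/reshape chain. A minimal sketch of the transformation described above (the actual `_pack_latents` implementation may differ in details):

```python
import torch

def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    """[B, C, H, W] -> [B, (H//2)*(W//2), C*4] via 2x2 patches."""
    b, c, h, w = latents.shape
    # Split H and W into (H//2, 2) and (W//2, 2) patch grids.
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    # Move the patch grid in front of the channel/patch-pixel dims.
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    # Flatten patches into a sequence, and each 2x2xC patch into C*4 features.
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

x = torch.randn(1, 16, 64, 64)
print(pack_latents(x).shape)  # torch.Size([1, 1024, 64])
```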

_unpack_latents(
latents: torch.Tensor,
height: int,
width: int,
) torch.Tensor#

Unpack latents from Flux format [B, num_patches, C*4] back to [B, C, H, W].

Parameters:
  • latents – Packed latents of shape [B, num_patches, channels]

  • height – Target latent height

  • width – Target latent width
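Unpacking inverts the chain above: reshape the patch sequence back onto the (H//2, W//2) grid, then interleave the 2x2 patch pixels into full spatial dimensions. A sketch under the same assumptions as the packing example:

```python
import torch

def unpack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """[B, num_patches, C*4] -> [B, C, H, W], inverse of the 2x2 packing."""
    b, _, channels = latents.shape
    c = channels // 4
    # Restore the (H//2, W//2) patch grid and the 2x2 patch pixels.
    latents = latents.view(b, height // 2, width // 2, c, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(b, c, height, width)
```

Packing followed by unpacking with the matching height and width is a lossless round trip.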

_prepare_latent_image_ids(
batch_size: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
) torch.Tensor#

Prepare positional IDs for image latents.

Returns tensor of shape [B, (H//2)*(W//2), 3] containing (batch_idx, y, x).
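A sketch of building those (batch_idx, y, x) IDs for the packed patch grid; the real method may construct them differently, but the returned shape and semantics match the docstring above:

```python
import torch

def prepare_latent_image_ids(
    batch_size: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Return [B, (H//2)*(W//2), 3] positional IDs holding (batch_idx, y, x)."""
    # One ID triple per 2x2 patch on the (H//2, W//2) grid.
    ids = torch.zeros(height // 2, width // 2, 3, device=device, dtype=dtype)
    ids[..., 1] = torch.arange(height // 2, device=device)[:, None]  # y coordinate
    ids[..., 2] = torch.arange(width // 2, device=device)[None, :]   # x coordinate
    # Flatten the grid and replicate per batch element, filling batch_idx.
    ids = ids.reshape(1, -1, 3).expand(batch_size, -1, -1).clone()
    ids[..., 0] = torch.arange(batch_size, device=device)[:, None]
    return ids
```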

prepare_inputs(
context: megatron.bridge.diffusion.common.flow_matching.adapters.base.FlowMatchingContext,
) Dict[str, Any]#

Prepare inputs for Megatron Flux model from FlowMatchingContext.

Handles batch key mapping:

  • Megatron uses: latents, prompt_embeds, pooled_prompt_embeds, text_ids

  • Automodel expects: image_latents, text_embeddings, pooled_prompt_embeds
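The key mapping listed above can be sketched as a simple rename over the batch dict (the mapping table is inferred from the bullets; `remap_batch` is an illustrative name, not the adapter's API):

```python
# Megatron batch key -> name expected by the automodel FlowMatching pipeline.
KEY_MAP = {
    "latents": "image_latents",
    "prompt_embeds": "text_embeddings",
    "pooled_prompt_embeds": "pooled_prompt_embeds",  # already matches
}

def remap_batch(batch: dict) -> dict:
    """Rename Megatron batch keys; keys without a mapping pass through unchanged."""
    return {KEY_MAP.get(k, k): v for k, v in batch.items()}

remapped = remap_batch({"latents": 1, "prompt_embeds": 2, "text_ids": 3})
print(remapped)  # {'image_latents': 1, 'text_embeddings': 2, 'text_ids': 3}
```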

forward(
model: megatron.core.models.common.vision_module.vision_module.VisionModule,
inputs: Dict[str, Any],
) torch.Tensor#

Execute forward pass for Megatron Flux model.

Returns unpacked prediction in [B, C, H, W] format.