bridge.diffusion.models.flux.flow_matching.flux_adapter#
Megatron-specific adapter for FLUX models using the automodel FlowMatching pipeline.
Module Contents#
Classes#
MegatronFluxAdapter | Adapter for FLUX models in the Megatron training framework.
API#
- class bridge.diffusion.models.flux.flow_matching.flux_adapter.MegatronFluxAdapter(guidance_scale: float = 3.5)#
Bases:
megatron.bridge.diffusion.common.flow_matching.adapters.base.ModelAdapter
Adapter for FLUX models in the Megatron training framework.
- Handles the sequence-first tensor layout [S, B, …] required by Megatron
- Integrates with pipeline parallelism
- Maps Megatron batch keys to the expected format
- Handles guidance embedding for FLUX-dev models
Initialization
Initialize MegatronFluxAdapter.
- Parameters:
guidance_scale – Guidance scale for classifier-free guidance
- _pack_latents(latents: torch.Tensor) → torch.Tensor#
Pack latents from [B, C, H, W] to Flux format [B, (H//2)*(W//2), C*4].
Flux uses a 2x2 patch embedding, so latents are reshaped accordingly.
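The 2x2 packing can be sketched with plain tensor ops. This is a minimal sketch of the shape transformation described above, not the adapter's actual implementation (the function name and shapes are illustrative):

```python
import torch

def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    """Pack [B, C, H, W] latents into Flux's [B, (H//2)*(W//2), C*4] layout."""
    b, c, h, w = latents.shape
    # Split H and W into 2x2 patches: [B, C, H//2, 2, W//2, 2]
    x = latents.view(b, c, h // 2, 2, w // 2, 2)
    # Move patch dims next to channels: [B, H//2, W//2, C, 2, 2]
    x = x.permute(0, 2, 4, 1, 3, 5)
    # Flatten the patch grid into the sequence dim: [B, (H//2)*(W//2), C*4]
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

latents = torch.randn(2, 16, 64, 64)
packed = pack_latents(latents)
print(packed.shape)  # torch.Size([2, 1024, 64])
```

Each 2x2 spatial patch becomes one sequence element, so a 64x64 latent yields 32*32 = 1024 tokens of width C*4.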
- _unpack_latents(
- latents: torch.Tensor,
- height: int,
- width: int,
- ) → torch.Tensor#
Unpack latents from Flux format [B, num_patches, C*4] back to [B, C, H, W].
- Parameters:
latents – Packed latents of shape [B, num_patches, channels]
height – Target latent height
width – Target latent width
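Unpacking is the exact inverse of the 2x2 packing. A minimal sketch under the same shape assumptions as above (illustrative code, not the adapter's source):

```python
import torch

def unpack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Invert the 2x2 packing: [B, (H//2)*(W//2), C*4] -> [B, C, H, W]."""
    b, num_patches, packed_c = latents.shape
    c = packed_c // 4
    # Recover the patch grid: [B, H//2, W//2, C, 2, 2]
    x = latents.view(b, height // 2, width // 2, c, 2, 2)
    # Move channels first: [B, C, H//2, 2, W//2, 2]
    x = x.permute(0, 3, 1, 4, 2, 5)
    # Merge the patch dims back into H and W
    return x.reshape(b, c, height, width)

packed = torch.randn(2, 1024, 64)
restored = unpack_latents(packed, height=64, width=64)
print(restored.shape)  # torch.Size([2, 16, 64, 64])
```

Note that `height` and `width` here are the target latent dimensions, since the patch count alone does not determine the aspect ratio.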
- _prepare_latent_image_ids(
- batch_size: int,
- height: int,
- width: int,
- device: torch.device,
- dtype: torch.dtype,
- ) → torch.Tensor#
Prepare positional IDs for image latents.
Returns tensor of shape [B, (H//2)*(W//2), 3] containing (batch_idx, y, x).
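A hedged sketch of how such positional IDs can be built: one (batch_idx, y, x) triple per 2x2 patch, matching the [B, (H//2)*(W//2), 3] shape stated above (the function body is illustrative, not the adapter's source):

```python
import torch

def prepare_latent_image_ids(batch_size, height, width, device, dtype):
    """Build per-patch (batch_idx, y, x) positional IDs of shape
    [B, (H//2)*(W//2), 3] for the packed latent sequence."""
    h, w = height // 2, width // 2
    ids = torch.zeros(h, w, 3, device=device, dtype=dtype)
    ids[..., 1] = torch.arange(h, device=device, dtype=dtype)[:, None]  # y (row)
    ids[..., 2] = torch.arange(w, device=device, dtype=dtype)[None, :]  # x (col)
    ids = ids.reshape(1, h * w, 3).expand(batch_size, -1, -1).clone()
    # Column 0 carries the batch index
    ids[..., 0] = torch.arange(batch_size, device=device, dtype=dtype)[:, None]
    return ids

ids = prepare_latent_image_ids(2, 64, 64, torch.device("cpu"), torch.float32)
print(ids.shape)  # torch.Size([2, 1024, 3])
```

The row-major flattening means patch index `i` corresponds to `y = i // (W//2)`, `x = i % (W//2)`.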
- prepare_inputs(
- context: megatron.bridge.diffusion.common.flow_matching.adapters.base.FlowMatchingContext,
- )#
Prepare inputs for Megatron Flux model from FlowMatchingContext.
Handles batch key mapping:
- Megatron uses: latents, prompt_embeds, pooled_prompt_embeds, text_ids
- Automodel expects: image_latents, text_embeddings, pooled_prompt_embeds
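The key mapping above amounts to a dict rename. A hypothetical sketch (the mapping table and function name are illustrative; only the key names come from the lists above):

```python
# Hypothetical rename table based on the documented key mapping.
MEGATRON_TO_AUTOMODEL = {
    "latents": "image_latents",
    "prompt_embeds": "text_embeddings",
    "pooled_prompt_embeds": "pooled_prompt_embeds",
}

def map_batch_keys(batch: dict) -> dict:
    """Rename Megatron batch keys to the names the automodel pipeline expects,
    passing through keys (e.g. text_ids) that have no mapping."""
    return {MEGATRON_TO_AUTOMODEL.get(k, k): v for k, v in batch.items()}

batch = {"latents": 0, "prompt_embeds": 1, "pooled_prompt_embeds": 2, "text_ids": 3}
print(sorted(map_batch_keys(batch)))
# ['image_latents', 'pooled_prompt_embeds', 'text_embeddings', 'text_ids']
```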
- forward(
- model: megatron.core.models.common.vision_module.vision_module.VisionModule,
- inputs: Dict[str, Any],
- ) → torch.Tensor#
Execute forward pass for Megatron Flux model.
Returns unpacked prediction in [B, C, H, W] format.
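The sequence-first layout mentioned earlier can be illustrated with a transpose. A minimal sketch, assuming packed latents are batch-first [B, S, D] and the Megatron model consumes sequence-first [S, B, D] (the stand-in forward pass is hypothetical):

```python
import torch

packed = torch.randn(2, 1024, 64)                # [B, S, D], batch-first
seq_first = packed.transpose(0, 1).contiguous()  # [S, B, D] for Megatron
out = seq_first                                  # stand-in for the model forward
batch_first = out.transpose(0, 1)                # back to [B, S, D] before unpacking
print(batch_first.shape)  # torch.Size([2, 1024, 64])
```

After this transpose back to batch-first, the prediction can be unpacked to [B, C, H, W] as described above.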