bridge.diffusion.models.flux.flux_model#

FLUX diffusion model implementation with Megatron Core.

Module Contents#

Classes#

Flux

FLUX diffusion model implementation with Megatron Core.

API#

class bridge.diffusion.models.flux.flux_model.Flux(
config: megatron.core.transformer.transformer_config.TransformerConfig,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
**kwargs,
)#

Bases: megatron.core.models.common.vision_module.vision_module.VisionModule

FLUX diffusion model implementation with Megatron Core.

FLUX is a state-of-the-art text-to-image diffusion model that uses a combination of double (MMDiT-style) and single transformer blocks.

Parameters:

config – FluxProvider (a TransformerConfig subclass) containing model hyperparameters.

.. attribute:: out_channels

Number of output channels.

.. attribute:: hidden_size

Hidden dimension size.

.. attribute:: num_attention_heads

Number of attention heads.

.. attribute:: patch_size

Patch size for image embedding.

.. attribute:: in_channels

Number of input channels.

.. attribute:: guidance_embed

Whether guidance embedding is used.

.. attribute:: pos_embed

N-dimensional position embedding module.

.. attribute:: img_embed

Image embedding linear layer.

.. attribute:: txt_embed

Text embedding linear layer.

.. attribute:: timestep_embedding

Timestep embedding module.

.. attribute:: vector_embedding

Vector (CLIP pooled) embedding module.

.. attribute:: guidance_embedding

Guidance embedding module (if guidance_embed=True).

.. attribute:: double_blocks

List of MMDiT layers for double blocks.

.. attribute:: single_blocks

List of single transformer blocks.

.. attribute:: norm_out

Output normalization layer.

.. attribute:: proj_out

Output projection layer.

Initialization
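As a concrete illustration of how `patch_size`, `in_channels`, and the packed sequence relate, the sketch below computes the `[B, S, C]` shape that `forward` expects for `img`. It assumes FLUX's standard 2×2 latent packing; the helper name `packed_shape` is hypothetical, not part of this API.

```python
def packed_shape(batch, latent_channels, height, width, patch_size=2):
    """Compute the packed sequence shape [B, S, C] for the `img` input,
    assuming a patchified latent layout (hypothetical helper)."""
    assert height % patch_size == 0 and width % patch_size == 0
    # each patch_size x patch_size block of the latent becomes one token
    seq_len = (height // patch_size) * (width // patch_size)
    # the patch's pixels are folded into the channel dimension
    channels = latent_channels * patch_size ** 2
    return (batch, seq_len, channels)

print(packed_shape(1, 16, 64, 64))  # (1, 1024, 64)
```

For example, a 64×64 latent with 16 channels packs to 1024 tokens of 64 channels each, matching the typical `in_channels=64` of FLUX.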

get_fp8_context()#

Get FP8 autocast context if FP8 is enabled.
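A minimal sketch of the pattern this method implements: return a real FP8 autocast context when FP8 is enabled, and a no-op context otherwise, so callers can always wrap the forward pass in `with model.get_fp8_context():`. The `_fp8_autocast_stub` below stands in for the actual FP8 autocast (e.g. Transformer Engine's) and is purely illustrative.

```python
from contextlib import contextmanager, nullcontext


@contextmanager
def _fp8_autocast_stub():
    # stand-in for a real FP8 autocast context manager (hypothetical)
    yield


def get_fp8_context(fp8_enabled: bool):
    """Return an FP8 autocast context if FP8 is enabled, else a no-op."""
    return _fp8_autocast_stub() if fp8_enabled else nullcontext()


# callers can use the same `with` statement in both modes
with get_fp8_context(fp8_enabled=False):
    pass  # runs without FP8 casting
```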

forward(
img: torch.Tensor,
txt: torch.Tensor = None,
y: torch.Tensor = None,
timesteps: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
controlnet_double_block_samples: torch.Tensor = None,
controlnet_single_block_samples: torch.Tensor = None,
)#

Forward pass through the FLUX model.

Parameters:
  • img – Image input tensor (latents) [B, S, C].

  • txt – Text input tensor (text embeddings) [B, S, D].

  • y – Vector input for embedding (CLIP pooled output) [B, D].

  • timesteps – Timestep input tensor [B].

  • img_ids – Image position IDs for rotary embedding [B, S, 3].

  • txt_ids – Text position IDs for rotary embedding [B, S, 3].

  • guidance – Guidance input for conditioning (FLUX-dev) [B].

  • controlnet_double_block_samples – Optional controlnet samples for double blocks.

  • controlnet_single_block_samples – Optional controlnet samples for single blocks.

Returns:

Output tensor of shape [B, S, out_channels].
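To make the documented `[B, S, 3]` shape of `img_ids` concrete, the sketch below builds one `(0, row, col)` position triple per latent patch for the rotary embedding. The layout (a zero first component plus row/column indices) follows the common FLUX convention; the helper name `make_img_ids` is hypothetical.

```python
def make_img_ids(height, width, patch_size=2):
    """Build position IDs [S, 3] for the rotary embedding: one
    (0, row, col) triple per latent patch (hypothetical helper)."""
    ids = []
    for row in range(height // patch_size):
        for col in range(width // patch_size):
            ids.append((0, row, col))
    return ids


# a 64x64 latent with patch_size=2 yields 32 * 32 = 1024 positions;
# broadcast over the batch dimension to obtain [B, S, 3]
ids = make_img_ids(64, 64)
```

`txt_ids` typically follows the same `[B, S, 3]` shape, with all-zero triples for text tokens.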

sharded_state_dict(
prefix='',
sharded_offsets: tuple = (),
metadata: dict = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Get sharded state dict for distributed checkpointing.

Parameters:
  • prefix – Prefix for state dict keys.

  • sharded_offsets – Sharded offsets tuple.

  • metadata – Additional metadata.

Returns:

ShardedStateDict for the model.

_set_embedder_weights_replica_id(
tensor: torch.Tensor,
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
embedder_weight_key: str,
) → None#

Set replica IDs of the weights in embedding layers for sharded state dict.

Parameters:
  • tensor – The parameter tensor to set replica ID for.

  • sharded_state_dict – State dict with the weight to tie.

  • embedder_weight_key – Key of the weight in the state dict.

Returns:

None; the state dict is modified in place.

set_input_tensor(input_tensor)#

Set input tensor for pipeline parallelism.