bridge.diffusion.models.flux.flux_model#
FLUX diffusion model implementation with Megatron Core.
Module Contents#
Classes#
Flux | FLUX diffusion model implementation with Megatron Core.
API#
- class bridge.diffusion.models.flux.flux_model.Flux(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- pre_process: bool = True,
- post_process: bool = True,
- fp16_lm_cross_entropy: bool = False,
- parallel_output: bool = True,
- **kwargs,
- )
Bases:
megatron.core.models.common.vision_module.vision_module.VisionModule
FLUX diffusion model implementation with Megatron Core.
FLUX is a state-of-the-art text-to-image diffusion model that uses a combination of double (MMDiT-style) and single transformer blocks.
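As a rough illustration of the two block types, the sketch below tracks only sequence shapes (the sizes used, 3072 hidden dim, 19 double and 38 single blocks, 512-token text sequence, are assumptions based on the public FLUX.1 release, not read from this module): double blocks keep separate image and text streams that interact through joint attention, while single blocks operate on one fused sequence.

```python
# Shape-only sketch of the FLUX dataflow. Block counts and sizes are
# assumptions matching the public FLUX.1 release, not this module's code.
HIDDEN = 3072

def double_block(img_shape, txt_shape):
    """MMDiT-style block: img and txt keep separate weights/streams,
    but attention runs over the concatenation of both sequences."""
    joint_len = img_shape[1] + txt_shape[1]  # attention context length
    # Streams remain separate on output: shapes are unchanged.
    return img_shape, txt_shape, joint_len

def single_block(seq_shape):
    """Single-stream block: one fused sequence, shape preserved."""
    return seq_shape

img = (1, 4096, HIDDEN)  # e.g. 1024x1024 image -> 64x64 latent patches
txt = (1, 512, HIDDEN)   # projected text embeddings

for _ in range(19):      # double (MMDiT) blocks
    img, txt, ctx = double_block(img, txt)

fused = (1, txt[1] + img[1], HIDDEN)  # concat txt and img for single blocks
for _ in range(38):
    fused = single_block(fused)

print(fused)  # (1, 4608, 3072)
```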
- Parameters:
config – FluxProvider containing model hyperparameters.
- Attributes:
out_channels – Number of output channels.
hidden_size – Hidden dimension size.
num_attention_heads – Number of attention heads.
patch_size – Patch size for image embedding.
in_channels – Number of input channels.
guidance_embed – Whether guidance embedding is used.
pos_embed – N-dimensional position embedding module.
img_embed – Image embedding linear layer.
txt_embed – Text embedding linear layer.
timestep_embedding – Timestep embedding module.
vector_embedding – Vector (CLIP pooled) embedding module.
guidance_embedding – Guidance embedding module (used when guidance_embed=True).
double_blocks – List of MMDiT layers for double blocks.
single_blocks – List of single transformer blocks.
norm_out – Output normalization layer.
proj_out – Output projection layer.
Initialization
- get_fp8_context()#
Get FP8 autocast context if FP8 is enabled.
- forward(
- img: torch.Tensor,
- txt: torch.Tensor = None,
- y: torch.Tensor = None,
- timesteps: torch.LongTensor = None,
- img_ids: torch.Tensor = None,
- txt_ids: torch.Tensor = None,
- guidance: torch.Tensor = None,
- controlnet_double_block_samples: torch.Tensor = None,
- controlnet_single_block_samples: torch.Tensor = None,
- )
Forward pass through the FLUX model.
- Parameters:
img – Image input tensor (latents) [B, S, C].
txt – Text input tensor (text embeddings) [B, S, D].
y – Vector input for embedding (CLIP pooled output) [B, D].
timesteps – Timestep input tensor [B].
img_ids – Image position IDs for rotary embedding [B, S, 3].
txt_ids – Text position IDs for rotary embedding [B, S, 3].
guidance – Guidance input for conditioning (FLUX-dev) [B].
controlnet_double_block_samples – Optional controlnet samples for double blocks.
controlnet_single_block_samples – Optional controlnet samples for single blocks.
- Returns:
Output tensor of shape [B, S, out_channels].
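The latent sequence length S follows from the VAE downsampling and 2×2 patching. The arithmetic below uses the public FLUX.1 sizes as assumptions (16 latent channels, 8× VAE downsample, patch size 2, T5 dim 4096, CLIP pooled dim 768), so the exact dims may differ from this module's configuration:

```python
# Input-shape arithmetic for forward(). All sizes are assumptions
# matching the public FLUX.1 release, not read from this module.
def flux_input_shapes(batch, height, width, txt_len=512,
                      latent_channels=16, vae_down=8, patch=2):
    h = height // vae_down // patch      # patches along height
    w = width // vae_down // patch       # patches along width
    s_img = h * w                        # image sequence length S
    c = latent_channels * patch * patch  # channels per patched token
    return {
        "img": (batch, s_img, c),        # [B, S, C] latent patches
        "txt": (batch, txt_len, 4096),   # [B, S, D] text embeddings
        "y": (batch, 768),               # [B, D] CLIP pooled vector
        "img_ids": (batch, s_img, 3),    # rotary position ids
        "txt_ids": (batch, txt_len, 3),
    }

shapes = flux_input_shapes(batch=1, height=1024, width=1024)
print(shapes["img"])  # (1, 4096, 64)
```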
- sharded_state_dict(
- prefix='',
- sharded_offsets: tuple = (),
- metadata: dict = None,
- )
Get sharded state dict for distributed checkpointing.
- Parameters:
prefix – Prefix for state dict keys.
sharded_offsets – Sharded offsets tuple.
metadata – Additional metadata.
- Returns:
ShardedStateDict for the model.
- _set_embedder_weights_replica_id(
- tensor: torch.Tensor,
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- embedder_weight_key: str,
- )
Set replica IDs of the weights in embedding layers for sharded state dict.
- Parameters:
tensor – The parameter tensor to set replica ID for.
sharded_state_dict – State dict with the weight to tie.
embedder_weight_key – Key of the weight in the state dict.
- Returns:
None, acts in-place.
- set_input_tensor(input_tensor)#
Set input tensor for pipeline parallelism.
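In Megatron-style pipeline parallelism, stages other than the first receive activations from the previous stage rather than embedding raw inputs; the pipeline schedule hands those activations over via `set_input_tensor`. A hedged sketch of that pattern with illustrative names (not this module's code):

```python
# Illustrative sketch of the set_input_tensor handoff used by
# Megatron-style pipeline stages; names are hypothetical.
class PipelineStage:
    def __init__(self, pre_process):
        self.pre_process = pre_process  # True only on the first stage
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Called by the pipeline schedule with the previous stage's output.
        self.input_tensor = input_tensor

    def forward(self, raw_inputs):
        if self.pre_process:
            return f"embed({raw_inputs})"  # first stage embeds raw inputs
        return self.input_tensor           # later stages reuse the handoff

first, second = PipelineStage(True), PipelineStage(False)
h = first.forward("img")
second.set_input_tensor(h)
print(second.forward(None))  # embed(img)
```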