bridge.diffusion.models.flux.flux_layer_spec#

FLUX layer specifications and transformer blocks.

Module Contents#

Classes#

AdaLN

Adaptive Layer Normalization Module for DiT/FLUX models.

AdaLNContinuous

A variant of AdaLN used for FLUX models.

MMDiTLayer

Multi-modal transformer layer for FLUX double blocks.

FluxSingleTransformerBlock

FLUX Single Transformer Block.

Functions#

get_flux_double_transformer_engine_spec

Get the module specification for FLUX double transformer blocks.

get_flux_single_transformer_engine_spec

Get the module specification for FLUX single transformer blocks.

API#

class bridge.diffusion.models.flux.flux_layer_spec.AdaLN(
config: megatron.core.transformer.transformer_config.TransformerConfig,
n_adaln_chunks: int = 9,
norm=nn.LayerNorm,
modulation_bias: bool = False,
use_second_norm: bool = False,
)#

Bases: megatron.core.transformer.module.MegatronModule

Adaptive Layer Normalization Module for DiT/FLUX models.

Implements adaptive layer normalization that conditions on timestep embeddings.

Parameters:
  • config – Transformer configuration.

  • n_adaln_chunks – Number of adaptive LN chunks for modulation outputs.

  • norm – Normalization type to use.

  • modulation_bias – Whether to use bias in modulation layers.

  • use_second_norm – Whether to use a second layer norm.

Initialization

forward(timestep_emb)#

Apply adaptive layer normalization modulation.

modulate(x, shift, scale)#

Apply modulation with shift and scale.

scale_add(residual, x, gate)#

Add gated output to residual.

modulated_layernorm(x, shift, scale, layernorm_idx=0)#

Apply layer norm followed by modulation.

scaled_modulated_layernorm(
residual,
x,
gate,
shift,
scale,
layernorm_idx=0,
)#

Apply scale, add, and modulated layer norm.
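
Conceptually, these methods implement the standard DiT-style modulation arithmetic. Below is a minimal PyTorch sketch of the same operations under that assumption; the exact projection and activation inside AdaLN.forward are not documented here, so the SiLU-plus-linear pattern is the conventional choice, not a confirmed detail.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    hidden, seq, batch = 64, 16, 2

    # Modulation parameters come from a projection of the timestep embedding,
    # chunked n_adaln_chunks ways; three chunks (shift, scale, gate) shown here.
    adaln_linear = nn.Linear(hidden, 3 * hidden, bias=False)  # modulation_bias=False
    ln = nn.LayerNorm(hidden, elementwise_affine=False)

    timestep_emb = torch.randn(batch, hidden)
    shift, scale, gate = adaln_linear(F.silu(timestep_emb)).chunk(3, dim=-1)

    x = torch.randn(seq, batch, hidden)  # [s, b, h]
    residual = x

    # modulate(x, shift, scale): conventional DiT formulation x * (1 + scale) + shift.
    modulated = ln(x) * (1 + scale) + shift  # modulated_layernorm
    out = residual + gate * modulated        # scale_add (gated residual)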

class bridge.diffusion.models.flux.flux_layer_spec.AdaLNContinuous(
config: megatron.core.transformer.transformer_config.TransformerConfig,
conditioning_embedding_dim: int,
modulation_bias: bool = True,
norm_type: str = 'layer_norm',
)#

Bases: megatron.core.transformer.module.MegatronModule

A variant of AdaLN used for FLUX models.

Continuous adaptive layer normalization that outputs scale and shift directly from conditioning embeddings.

Parameters:
  • config – Transformer configuration.

  • conditioning_embedding_dim – Dimension of the conditioning embedding.

  • modulation_bias – Whether to use bias in modulation layer.

  • norm_type – Type of normalization ("layer_norm" or "rms_norm").

Initialization

forward(
x: torch.Tensor,
conditioning_embedding: torch.Tensor,
) → torch.Tensor#

Apply continuous adaptive layer normalization.
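
A minimal sketch of what the continuous variant computes, assuming the conventional AdaLayerNormContinuous formulation (SiLU plus linear over the conditioning embedding, producing scale and shift directly; the chunk order below is illustrative):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    hidden, cond_dim, seq, batch = 64, 128, 16, 2

    linear = nn.Linear(cond_dim, 2 * hidden, bias=True)    # modulation_bias=True
    norm = nn.LayerNorm(hidden, elementwise_affine=False)  # norm_type='layer_norm'

    x = torch.randn(seq, batch, hidden)                    # [s, b, h]
    conditioning_embedding = torch.randn(batch, cond_dim)

    scale, shift = linear(F.silu(conditioning_embedding)).chunk(2, dim=-1)
    out = norm(x) * (1 + scale) + shift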

class bridge.diffusion.models.flux.flux_layer_spec.MMDiTLayer(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules,
layer_number: int = 1,
context_pre_only: bool = False,
)#

Bases: megatron.core.transformer.transformer_layer.TransformerLayer

Multi-modal transformer layer for FLUX double blocks.

The transformer layer takes input of size [s, b, h] (sequence, batch, hidden) and returns an output of the same size.

MMDiT layer implementation from https://arxiv.org/pdf/2403.03206.

Parameters:
  • config – Transformer configuration.

  • submodules – Transformer layer submodules.

  • layer_number – Layer index.

  • context_pre_only – Whether the context stream is used for pre-processing only.

Initialization

forward(
hidden_states,
encoder_hidden_states,
attention_mask=None,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
emb=None,
)#

Forward pass for MMDiT layer.

Parameters:
  • hidden_states – Image hidden states.

  • encoder_hidden_states – Text encoder hidden states.

  • attention_mask – Attention mask.

  • context – Context tensor (unused).

  • context_mask – Context mask (unused).

  • rotary_pos_emb – Rotary position embeddings.

  • inference_params – Inference parameters.

  • packed_seq_params – Packed sequence parameters.

  • emb – Timestep/conditioning embedding.

Returns:

Tuple of (hidden_states, encoder_hidden_states).
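
A hedged usage sketch for the forward pass. The config values are illustrative; the sketch assumes a Megatron Core model-parallel environment is already initialized, and it omits rotary embeddings, which FLUX supplies in practice.

    import torch
    from megatron.core.transformer.transformer_config import TransformerConfig
    from megatron.core.transformer.spec_utils import build_module
    from bridge.diffusion.models.flux.flux_layer_spec import (
        get_flux_double_transformer_engine_spec,
    )

    config = TransformerConfig(
        num_layers=1, hidden_size=3072, num_attention_heads=24,  # illustrative
    )
    layer = build_module(
        get_flux_double_transformer_engine_spec(), config=config, layer_number=1
    )

    s_img, s_txt, b, h = 1024, 512, 1, config.hidden_size
    hidden_states = torch.randn(s_img, b, h)          # image stream, [s, b, h]
    encoder_hidden_states = torch.randn(s_txt, b, h)  # text stream
    emb = torch.randn(b, h)                           # timestep/conditioning embedding

    hidden_states, encoder_hidden_states = layer(
        hidden_states, encoder_hidden_states, emb=emb
    )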

__call__(*args, **kwargs)#

class bridge.diffusion.models.flux.flux_layer_spec.FluxSingleTransformerBlock(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules,
layer_number: int = 1,
mlp_ratio: int = 4,
n_adaln_chunks: int = 3,
modulation_bias: bool = True,
)#

Bases: megatron.core.transformer.transformer_layer.TransformerLayer

FLUX Single Transformer Block.

Single transformer layer mathematically equivalent to the original FLUX single transformer block. The layer is re-implemented with Megatron Core and restructured for better performance.

Parameters:
  • config – Transformer configuration.

  • submodules – Transformer layer submodules.

  • layer_number – Layer index.

  • mlp_ratio – MLP hidden size ratio.

  • n_adaln_chunks – Number of adaptive LN chunks.

  • modulation_bias – Whether to use bias in modulation.

Initialization

forward(
hidden_states,
attention_mask=None,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
emb=None,
)#

Forward pass for FLUX single transformer block.

Parameters:
  • hidden_states – Input hidden states.

  • attention_mask – Attention mask.

  • context – Context tensor (unused).

  • context_mask – Context mask (unused).

  • rotary_pos_emb – Rotary position embeddings.

  • inference_params – Inference parameters.

  • packed_seq_params – Packed sequence parameters.

  • emb – Timestep/conditioning embedding.

Returns:

Tuple of (hidden_states, None).
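
The single block processes one fused stream; in FLUX, text and image tokens are typically concatenated along the sequence dimension before entering the single blocks. A brief sketch, reusing the (illustrative) config from the double-block example above:

    import torch
    from megatron.core.transformer.spec_utils import build_module
    from bridge.diffusion.models.flux.flux_layer_spec import (
        get_flux_single_transformer_engine_spec,
    )

    # `config` as in the double-block sketch above (illustrative values).
    layer = build_module(
        get_flux_single_transformer_engine_spec(), config=config, layer_number=1
    )

    s, b = 1536, 1  # concatenated text+image tokens
    hidden_states = torch.randn(s, b, config.hidden_size)
    emb = torch.randn(b, config.hidden_size)

    out, extra = layer(hidden_states, emb=emb)  # extra is None per the return contract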

__call__(*args, **kwargs)#
bridge.diffusion.models.flux.flux_layer_spec.get_flux_double_transformer_engine_spec() → megatron.core.transformer.spec_utils.ModuleSpec#

Get the module specification for FLUX double transformer blocks.

Returns:

ModuleSpec for MMDiTLayer with JointSelfAttention.

bridge.diffusion.models.flux.flux_layer_spec.get_flux_single_transformer_engine_spec() → megatron.core.transformer.spec_utils.ModuleSpec#

Get the module specification for FLUX single transformer blocks.

Returns:

ModuleSpec for FluxSingleTransformerBlock with FluxSingleAttention.
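
Both getters return ModuleSpec objects pairing a layer class with its submodule wiring; they are instantiated through Megatron Core's build_module, as sketched above. A quick inspection:

    from megatron.core.transformer.spec_utils import ModuleSpec
    from bridge.diffusion.models.flux.flux_layer_spec import (
        get_flux_double_transformer_engine_spec,
        get_flux_single_transformer_engine_spec,
    )

    double_spec = get_flux_double_transformer_engine_spec()
    single_spec = get_flux_single_transformer_engine_spec()

    assert isinstance(double_spec, ModuleSpec)
    assert isinstance(single_spec, ModuleSpec)
    print(double_spec.module.__name__)  # MMDiTLayer
    print(single_spec.module.__name__)  # FluxSingleTransformerBlock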