bridge.diffusion.models.flux.flux_layer_spec#
FLUX layer specifications and transformer blocks.
Module Contents#
Classes#
- AdaLN – Adaptive Layer Normalization Module for DiT/FLUX models.
- AdaLNContinuous – A variant of AdaLN used for FLUX models.
- MMDiTLayer – Multi-modal transformer layer for FLUX double blocks.
- FluxSingleTransformerBlock – FLUX Single Transformer Block.
Functions#
- get_flux_double_transformer_engine_spec – Get the module specification for FLUX double transformer blocks.
- get_flux_single_transformer_engine_spec – Get the module specification for FLUX single transformer blocks.
API#
- class bridge.diffusion.models.flux.flux_layer_spec.AdaLN(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- n_adaln_chunks: int = 9,
- norm=nn.LayerNorm,
- modulation_bias: bool = False,
- use_second_norm: bool = False,
Bases: megatron.core.transformer.module.MegatronModule
Adaptive Layer Normalization Module for DiT/FLUX models.
Implements adaptive layer normalization that conditions on timestep embeddings.
- Parameters:
config – Transformer configuration.
n_adaln_chunks – Number of adaptive LN chunks for modulation outputs.
norm – Normalization class to use.
modulation_bias – Whether to use bias in modulation layers.
use_second_norm – Whether to use a second layer norm.
Initialization
- forward(timestep_emb)#
Apply adaptive layer normalization modulation.
- modulate(x, shift, scale)#
Apply modulation with shift and scale.
- scale_add(residual, x, gate)#
Add gated output to residual.
- modulated_layernorm(x, shift, scale, layernorm_idx=0)#
Apply layer norm followed by modulation.
- scaled_modulated_layernorm(
- residual,
- x,
- gate,
- shift,
- scale,
- layernorm_idx=0,
Apply scale, add, and modulated layer norm.
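The helpers above implement the standard DiT-style adaLN arithmetic. Below is a minimal pure-PyTorch sketch of that math, assuming the common pattern of projecting the timestep embedding through SiLU + Linear into n_adaln_chunks modulation tensors; all sizes and names are illustrative, not the actual module's internals.

```python
import torch
import torch.nn as nn

hidden = 512        # hypothetical hidden size
n_chunks = 9        # matches the default n_adaln_chunks

# Assumed pattern: project the timestep embedding into n_chunks modulation
# tensors (shifts, scales, and gates for the layer's sub-blocks).
ada_ln = nn.Sequential(nn.SiLU(), nn.Linear(hidden, n_chunks * hidden, bias=False))
norm = nn.LayerNorm(hidden, elementwise_affine=False, eps=1e-6)

def modulate(x, shift, scale):
    # Standard DiT modulation: scale and shift the (normalized) activations.
    return x * (1 + scale) + shift

def scale_add(residual, x, gate):
    # Gated residual connection.
    return residual + gate * x

timestep_emb = torch.randn(2, hidden)                  # [b, h]
chunks = ada_ln(timestep_emb).chunk(n_chunks, dim=-1)  # n_chunks tensors of [b, h]
shift, scale, gate = chunks[0], chunks[1], chunks[2]

x = torch.randn(16, 2, hidden)                         # [s, b, h]
y = modulate(norm(x), shift, scale)                    # modulated_layernorm
out = scale_add(x, y, gate)                            # gated residual update
```

The real module draws its dimensions from the TransformerConfig and selects the normalization class via the norm argument.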
- class bridge.diffusion.models.flux.flux_layer_spec.AdaLNContinuous(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- conditioning_embedding_dim: int,
- modulation_bias: bool = True,
- norm_type: str = 'layer_norm',
Bases: megatron.core.transformer.module.MegatronModule
A variant of AdaLN used for FLUX models.
Continuous adaptive layer normalization that outputs scale and shift directly from conditioning embeddings.
- Parameters:
config – Transformer configuration.
conditioning_embedding_dim – Dimension of the conditioning embedding.
modulation_bias – Whether to use bias in the modulation layer.
norm_type – Type of normalization ('layer_norm' or 'rms_norm').
Initialization
- forward(
- x: torch.Tensor,
- conditioning_embedding: torch.Tensor,
Apply continuous adaptive layer normalization.
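A minimal sketch of the continuous variant, assuming the conditioning embedding is projected through SiLU + Linear into a shift and a scale that modulate the normalized input; the chunk order and normalization settings here are illustrative assumptions.

```python
import torch
import torch.nn as nn

hidden, cond_dim = 512, 512   # hypothetical dimensions

# Assumed pattern: conditioning embedding -> (shift, scale) for the normalized input.
to_shift_scale = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 2 * hidden, bias=True))
norm = nn.LayerNorm(hidden, elementwise_affine=False, eps=1e-6)

x = torch.randn(16, 2, hidden)       # [s, b, h]
cond = torch.randn(2, cond_dim)      # [b, conditioning_embedding_dim]

shift, scale = to_shift_scale(cond).chunk(2, dim=-1)  # chunk order is an assumption
out = norm(x) * (1 + scale) + shift
```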
- class bridge.diffusion.models.flux.flux_layer_spec.MMDiTLayer(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules,
- layer_number: int = 1,
- context_pre_only: bool = False,
Bases: megatron.core.transformer.transformer_layer.TransformerLayer
Multi-modal transformer layer for FLUX double blocks.
The transformer layer takes input of size [s, b, h] and returns an output of the same size.
MMDiT layer implementation from [https://arxiv.org/pdf/2403.03206].
- Parameters:
config – Transformer configuration.
submodules – Transformer layer submodules.
layer_number – Layer index.
context_pre_only – Whether to only use context for pre-processing.
Initialization
- forward(
- hidden_states,
- encoder_hidden_states,
- attention_mask=None,
- context=None,
- context_mask=None,
- rotary_pos_emb=None,
- inference_params=None,
- packed_seq_params=None,
- emb=None,
Forward pass for MMDiT layer.
- Parameters:
hidden_states – Image hidden states.
encoder_hidden_states – Text encoder hidden states.
attention_mask – Attention mask.
context – Context tensor (unused).
context_mask – Context mask (unused).
rotary_pos_emb – Rotary position embeddings.
inference_params – Inference parameters.
packed_seq_params – Packed sequence parameters.
emb – Timestep/conditioning embedding.
- Returns:
Tuple of (hidden_states, encoder_hidden_states).
- __call__(*args, **kwargs)#
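For intuition, the schematic below mirrors the double-block data flow described above (two streams modulated by per-stream adaLN, joint self-attention over the concatenated text and image tokens, separate gated MLPs) in plain PyTorch. ToyDoubleBlock, its sizes, and its stream ordering are hypothetical stand-ins, not the Megatron Core implementation.

```python
import torch
import torch.nn as nn

h, heads = 512, 8   # hypothetical sizes; the real layer reads them from TransformerConfig

class ToyDoubleBlock(nn.Module):
    """Schematic MM-DiT double block: two streams, joint self-attention, separate MLPs."""

    def __init__(self):
        super().__init__()
        self.ada_img = nn.Sequential(nn.SiLU(), nn.Linear(h, 6 * h))
        self.ada_txt = nn.Sequential(nn.SiLU(), nn.Linear(h, 6 * h))
        self.norm = nn.LayerNorm(h, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(h, heads)  # stand-in for joint self-attention
        self.mlp_img = nn.Sequential(nn.Linear(h, 4 * h), nn.GELU(), nn.Linear(4 * h, h))
        self.mlp_txt = nn.Sequential(nn.Linear(h, 4 * h), nn.GELU(), nn.Linear(4 * h, h))

    def forward(self, img, txt, emb):
        # Per-stream adaLN: shift/scale/gate for the attention and MLP sub-blocks.
        i_s1, i_c1, i_g1, i_s2, i_c2, i_g2 = self.ada_img(emb).chunk(6, dim=-1)
        t_s1, t_c1, t_g1, t_s2, t_c2, t_g2 = self.ada_txt(emb).chunk(6, dim=-1)
        # Joint self-attention over the concatenated [text; image] sequence.
        joint = torch.cat(
            [self.norm(txt) * (1 + t_c1) + t_s1, self.norm(img) * (1 + i_c1) + i_s1], dim=0
        )
        attn_out, _ = self.attn(joint, joint, joint)
        txt_attn, img_attn = attn_out[: txt.shape[0]], attn_out[txt.shape[0]:]
        img = img + i_g1 * img_attn
        txt = txt + t_g1 * txt_attn
        # Separate gated MLPs per stream.
        img = img + i_g2 * self.mlp_img(self.norm(img) * (1 + i_c2) + i_s2)
        txt = txt + t_g2 * self.mlp_txt(self.norm(txt) * (1 + t_c2) + t_s2)
        return img, txt

img = torch.randn(64, 2, h)   # image hidden states [s_img, b, h]
txt = torch.randn(16, 2, h)   # text encoder hidden states [s_txt, b, h]
emb = torch.randn(2, h)       # timestep/conditioning embedding [b, h]
img_out, txt_out = ToyDoubleBlock()(img, txt, emb)
```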
- class bridge.diffusion.models.flux.flux_layer_spec.FluxSingleTransformerBlock(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: megatron.core.transformer.transformer_layer.TransformerLayerSubmodules,
- layer_number: int = 1,
- mlp_ratio: int = 4,
- n_adaln_chunks: int = 3,
- modulation_bias: bool = True,
Bases: megatron.core.transformer.transformer_layer.TransformerLayer
FLUX Single Transformer Block.
A single transformer layer that is mathematically equivalent to the original FLUX single transformer block, re-implemented with Megatron Core and restructured for better performance.
- Parameters:
config – Transformer configuration.
submodules – Transformer layer submodules.
layer_number – Layer index.
mlp_ratio – MLP hidden size ratio.
n_adaln_chunks – Number of adaptive LN chunks.
modulation_bias – Whether to use bias in modulation.
Initialization
- forward(
- hidden_states,
- attention_mask=None,
- context=None,
- context_mask=None,
- rotary_pos_emb=None,
- inference_params=None,
- packed_seq_params=None,
- emb=None,
Forward pass for FLUX single transformer block.
- Parameters:
hidden_states – Input hidden states.
attention_mask – Attention mask.
context – Context tensor (unused).
context_mask – Context mask (unused).
rotary_pos_emb – Rotary position embeddings.
inference_params – Inference parameters.
packed_seq_params – Packed sequence parameters.
emb – Timestep/conditioning embedding.
- Returns:
Tuple of (hidden_states, None).
- __call__(*args, **kwargs)#
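The schematic below follows the original FLUX single-block design (one adaLN producing shift/scale/gate, attention and MLP branches computed in parallel on the same modulated input, a shared output projection and gated residual). ToySingleBlock and its sizes are hypothetical; the Megatron Core implementation is restructured for performance, so only the math is meant to correspond.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

h, heads, mlp_ratio = 512, 8, 4   # hypothetical sizes

class ToySingleBlock(nn.Module):
    """Schematic FLUX single block: one adaLN (3 chunks), parallel attention and
    MLP branches on the same modulated input, shared output projection."""

    def __init__(self):
        super().__init__()
        self.ada = nn.Sequential(nn.SiLU(), nn.Linear(h, 3 * h))
        self.norm = nn.LayerNorm(h, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(h, heads)
        self.mlp_in = nn.Linear(h, mlp_ratio * h)
        self.proj_out = nn.Linear(h + mlp_ratio * h, h)

    def forward(self, x, emb):
        shift, scale, gate = self.ada(emb).chunk(3, dim=-1)   # n_adaln_chunks = 3
        y = self.norm(x) * (1 + scale) + shift                # modulated layer norm
        attn_out, _ = self.attn(y, y, y)                      # attention branch
        mlp_out = F.gelu(self.mlp_in(y))                      # parallel MLP branch
        y = self.proj_out(torch.cat([attn_out, mlp_out], dim=-1))
        return x + gate * y                                   # gated residual

x = torch.randn(80, 2, h)    # [s, b, h] (concatenated text + image tokens)
emb = torch.randn(2, h)      # timestep/conditioning embedding
out = ToySingleBlock()(x, emb)
```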
- bridge.diffusion.models.flux.flux_layer_spec.get_flux_double_transformer_engine_spec() → megatron.core.transformer.spec_utils.ModuleSpec#
Get the module specification for FLUX double transformer blocks.
- Returns:
ModuleSpec for MMDiTLayer with JointSelfAttention.
- bridge.diffusion.models.flux.flux_layer_spec.get_flux_single_transformer_engine_spec() → megatron.core.transformer.spec_utils.ModuleSpec#
Get the module specification for FLUX single transformer blocks.
- Returns:
ModuleSpec for FluxSingleTransformerBlock with FluxSingleAttention.
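A hypothetical usage sketch: both functions return ordinary megatron-core ModuleSpec objects, so they can be instantiated with build_module. The snippet assumes a populated TransformerConfig named config and an already-initialized Megatron model-parallel environment with Transformer Engine available.

```python
from megatron.core.transformer.spec_utils import build_module

from bridge.diffusion.models.flux.flux_layer_spec import (
    get_flux_double_transformer_engine_spec,
    get_flux_single_transformer_engine_spec,
)

# `config` is a populated TransformerConfig; model-parallel state and
# Transformer Engine are assumed to be set up already.
double_layer = build_module(
    get_flux_double_transformer_engine_spec(), config=config, layer_number=1
)
single_layer = build_module(
    get_flux_single_transformer_engine_spec(), config=config, layer_number=1
)
```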