bridge.diffusion.models.flux.flux_attention#

FLUX attention modules for diffusion models.

Module Contents#

Classes#

JointSelfAttentionSubmodules

Submodules for Joint Self-attention layer.

JointSelfAttention

Joint Self-attention layer class.

FluxSingleAttention

Self-attention layer class for FLUX single transformer blocks.

API#

class bridge.diffusion.models.flux.flux_attention.JointSelfAttentionSubmodules#

Submodules for Joint Self-attention layer.

Used for MMDiT-like transformer blocks in FLUX.

linear_qkv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

added_linear_qkv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

added_q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

added_k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None
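
These fields are typically populated through a ModuleSpec, following the same pattern as Megatron Core's SelfAttentionSubmodules. A minimal sketch, assuming plain Megatron Core building blocks (the actual FLUX layer spec may instead use Transformer Engine modules and QK normalization layers):

```python
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec

from bridge.diffusion.models.flux.flux_attention import (
    JointSelfAttention,
    JointSelfAttentionSubmodules,
)

# Sketch of a joint-attention spec: separate fused QKV projections for the
# image stream (linear_qkv) and the text/context stream (added_linear_qkv),
# a shared core attention, and a single output projection.
joint_attention_spec = ModuleSpec(
    module=JointSelfAttention,
    params={"attn_mask_type": AttnMaskType.no_mask},
    submodules=JointSelfAttentionSubmodules(
        linear_qkv=ColumnParallelLinear,
        added_linear_qkv=ColumnParallelLinear,
        core_attention=DotProductAttention,
        linear_proj=RowParallelLinear,
        q_layernorm=IdentityOp,        # swap in a norm spec here if per-head
        k_layernorm=IdentityOp,        # QK normalization is wanted
        added_q_layernorm=IdentityOp,
        added_k_layernorm=IdentityOp,
    ),
)
```

Passing classes rather than instances lets the attention layer build each submodule lazily against the shared TransformerConfig.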

class bridge.diffusion.models.flux.flux_attention.JointSelfAttention(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: bridge.diffusion.models.flux.flux_attention.JointSelfAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
context_pre_only: bool = False,
**kwargs,
)#

Bases: megatron.core.transformer.attention.Attention

Joint Self-attention layer class.

Used for the MMDiT-style double blocks in FLUX. This attention layer processes the image hidden states and the text encoder hidden states jointly.

Parameters:
  • config – Transformer configuration.

  • submodules – Joint self-attention submodules specification.

  • layer_number – Layer index in the transformer.

  • attn_mask_type – Type of attention mask to use.

  • context_pre_only – Whether the context (text encoder) stream is only used for the pre-attention computation.

Initialization

_split_qkv(mixed_qkv)#

Split mixed QKV tensor into separate Q, K, V tensors.
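
As an illustration of what the fused tensor looks like, the following self-contained sketch splits a toy mixed QKV tensor into per-head query, key, and value chunks. The shapes are hypothetical and the layout is simplified; Megatron's actual layout groups query, key, and value per attention group to support grouped-query attention.

```python
import torch

# Hypothetical toy shapes: 2 heads of dim 8, so the fused projection emits
# 3 * 2 * 8 = 48 features per token.
sq, b, num_heads, head_dim = 16, 2, 2, 8
mixed_qkv = torch.randn(sq, b, num_heads * 3 * head_dim)

# Regroup the fused features per head, then split the last dimension into
# query, key, and value chunks of head_dim each.
mixed_qkv = mixed_qkv.view(sq, b, num_heads, 3 * head_dim)
query, key, value = torch.split(mixed_qkv, head_dim, dim=-1)

assert query.shape == key.shape == value.shape == (sq, b, num_heads, head_dim)
```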

get_query_key_value_tensors(hidden_states, key_value_states=None)#

Derives query, key and value tensors from hidden_states.

get_added_query_key_value_tensors(
added_hidden_states,
key_value_states=None,
)#

Derives query, key and value tensors from added_hidden_states.

forward(
hidden_states,
attention_mask,
key_value_states=None,
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
additional_hidden_states=None,
)#

Forward pass for joint self-attention.

Parameters:
  • hidden_states – Image hidden states [sq, b, h].

  • attention_mask – Attention mask.

  • key_value_states – Optional key-value states.

  • inference_params – Inference parameters.

  • rotary_pos_emb – Rotary position embeddings.

  • packed_seq_params – Packed sequence parameters.

  • additional_hidden_states – Text encoder hidden states.

Returns:

Tuple of (image_attention_output, encoder_attention_output).
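
To make the joint processing concrete, the sketch below shows the overall mechanism in plain PyTorch: text and image tokens are concatenated along the sequence dimension, attended over jointly, and the result is split back into the two streams. This is an illustration of the idea only, not the actual implementation (which uses the Megatron submodules above, rotary embeddings, and tensor-parallel projections); all shapes are hypothetical.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy shapes: [seq, batch, heads, head_dim] per stream.
img_len, txt_len, b, heads, head_dim = 32, 8, 2, 4, 16


def toy_qkv(seq_len):
    """Stand-in for the projected query/key/value of one stream."""
    return tuple(torch.randn(seq_len, b, heads, head_dim) for _ in range(3))


q_img, k_img, v_img = toy_qkv(img_len)  # from hidden_states
q_txt, k_txt, v_txt = toy_qkv(txt_len)  # from additional_hidden_states

# Joint attention: concatenate text and image tokens along the sequence
# dimension so every token attends over both modalities at once.
q = torch.cat([q_txt, q_img], dim=0)
k = torch.cat([k_txt, k_img], dim=0)
v = torch.cat([v_txt, v_img], dim=0)


def to_bhsd(x):
    # scaled_dot_product_attention expects [batch, heads, seq, head_dim].
    return x.permute(1, 2, 0, 3)


out = F.scaled_dot_product_attention(to_bhsd(q), to_bhsd(k), to_bhsd(v))
out = out.permute(2, 0, 1, 3)  # back to [seq, batch, heads, head_dim]

# Split the joint result back into its two streams, mirroring the
# (image_attention_output, encoder_attention_output) return value.
encoder_attention_output = out[:txt_len]
image_attention_output = out[txt_len:]
```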

class bridge.diffusion.models.flux.flux_attention.FluxSingleAttention(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.attention.SelfAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
cp_comm_type: str = None,
**kwargs,
)#

Bases: megatron.core.transformer.attention.SelfAttention

Self-attention layer class for FLUX single transformer blocks.

The self-attention layer takes input of size [s, b, h] and returns output of the same size.

Parameters:
  • config – Transformer configuration.

  • submodules – Self-attention submodules specification.

  • layer_number – Layer index in the transformer.

  • attn_mask_type – Type of attention mask to use.

  • cp_comm_type – Context parallel communication type.

Initialization

forward(
hidden_states,
attention_mask,
key_value_states=None,
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
)#

Forward pass for FLUX single attention.

Parameters:
  • hidden_states – Input hidden states [sq, b, h].

  • attention_mask – Attention mask.

  • key_value_states – Optional key-value states.

  • inference_params – Inference parameters.

  • rotary_pos_emb – Rotary position embeddings.

  • packed_seq_params – Packed sequence parameters.

Returns:

Attention output tensor.
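
FluxSingleAttention is wired the same way as the joint layer above, but with Megatron Core's standard SelfAttentionSubmodules, since there is only a single stream and no added_* projections. A minimal sketch under the same assumptions (plain Megatron Core building blocks rather than the modules the actual FLUX single-block spec may use):

```python
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec

from bridge.diffusion.models.flux.flux_attention import FluxSingleAttention

single_attention_spec = ModuleSpec(
    module=FluxSingleAttention,
    params={"attn_mask_type": AttnMaskType.no_mask},
    submodules=SelfAttentionSubmodules(
        linear_qkv=ColumnParallelLinear,    # fused QKV projection
        core_attention=DotProductAttention,
        linear_proj=RowParallelLinear,      # output projection back to [s, b, h]
        q_layernorm=IdentityOp,
        k_layernorm=IdentityOp,
    ),
)
```

As documented above, the forward pass preserves the [s, b, h] shape of its input.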