bridge.diffusion.models.flux.flux_attention#
FLUX attention modules for diffusion models.
Module Contents#
Classes#
JointSelfAttentionSubmodules – Submodules for Joint Self-attention layer.
JointSelfAttention – Joint Self-attention layer class.
FluxSingleAttention – Self-attention layer class for FLUX single transformer blocks.
API#
- class bridge.diffusion.models.flux.flux_attention.JointSelfAttentionSubmodules#
Submodules for Joint Self-attention layer.
Used for MMDIT-like transformer blocks in FLUX.
- linear_qkv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- added_linear_qkv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- added_q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- added_k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- class bridge.diffusion.models.flux.flux_attention.JointSelfAttention(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: bridge.diffusion.models.flux.flux_attention.JointSelfAttentionSubmodules,
- layer_number: int,
- attn_mask_type=AttnMaskType.padding,
- context_pre_only: bool = False,
- **kwargs,
Bases: megatron.core.transformer.attention.Attention
Joint Self-attention layer class.
Used for MMDIT-like transformer blocks in FLUX double blocks. This attention layer processes both image hidden states and text encoder hidden states jointly.
- Parameters:
config – Transformer configuration.
submodules – Joint self-attention submodules specification.
layer_number – Layer index in the transformer.
attn_mask_type – Type of attention mask to use.
context_pre_only – Whether to only use context for pre-processing.
Initialization
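A minimal construction sketch, assuming the joint_attention_spec built above, a small illustrative TransformerConfig, and an already initialized Megatron model-parallel state; real FLUX configurations use different sizes.

```python
# Minimal construction sketch (illustrative sizes only). Assumes
# megatron.core.parallel_state has already been initialized.
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=1,
    hidden_size=3072,
    num_attention_heads=24,
    use_cpu_initialization=True,
)

attention = JointSelfAttention(
    config=config,
    submodules=joint_attention_spec.submodules,
    layer_number=1,
    context_pre_only=False,
)
```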
- _split_qkv(mixed_qkv)#
Split mixed QKV tensor into separate Q, K, V tensors.
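The general pattern behind splitting a fused QKV tensor is sketched below; the shapes and the exact split performed by _split_qkv (for example with grouped-query attention) may differ.

```python
# Illustrative only: splitting a fused QKV projection into Q, K and V.
import torch

sq, b, heads, head_dim = 128, 2, 24, 128          # made-up shapes
mixed_qkv = torch.randn(sq, b, heads * 3 * head_dim)

# View as [sq, b, heads, 3 * head_dim], then split the last dimension.
mixed_qkv = mixed_qkv.view(sq, b, heads, 3 * head_dim)
query, key, value = torch.split(mixed_qkv, head_dim, dim=-1)
```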
- get_query_key_value_tensors(hidden_states, key_value_states=None)#
Derives query, key and value tensors from hidden_states.
- get_added_query_key_value_tensors(
- added_hidden_states,
- key_value_states=None,
Derives query, key and value tensors from added_hidden_states.
- forward(
- hidden_states,
- attention_mask,
- key_value_states=None,
- inference_params=None,
- rotary_pos_emb=None,
- packed_seq_params=None,
- additional_hidden_states=None,
Forward pass for joint self-attention.
- Parameters:
hidden_states – Image hidden states [sq, b, h].
attention_mask – Attention mask.
key_value_states – Optional key-value states.
inference_params – Inference parameters.
rotary_pos_emb – Rotary position embeddings.
packed_seq_params – Packed sequence parameters.
additional_hidden_states – Text encoder hidden states.
- Returns:
Tuple of (image_attention_output, encoder_attention_output).
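An illustrative forward call, assuming the attention instance constructed above and the [sq, b, h] layout from the docstring; the sequence lengths and hidden size are made up.

```python
# Hypothetical shapes: 4096 image tokens, 512 text tokens, batch 1, hidden 3072.
import torch

image_hidden = torch.randn(4096, 1, 3072)
text_hidden = torch.randn(512, 1, 3072)

image_out, encoder_out = attention(
    hidden_states=image_hidden,
    attention_mask=None,
    rotary_pos_emb=None,
    additional_hidden_states=text_hidden,
)
```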
- class bridge.diffusion.models.flux.flux_attention.FluxSingleAttention(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: megatron.core.transformer.attention.SelfAttentionSubmodules,
- layer_number: int,
- attn_mask_type=AttnMaskType.padding,
- cp_comm_type: str = None,
- **kwargs,
Bases: megatron.core.transformer.attention.SelfAttention
Self-attention layer class for FLUX single transformer blocks.
Self-attention layer takes input with size [s, b, h] and returns output of the same size.
- Parameters:
config – Transformer configuration.
submodules – Self-attention submodules specification.
layer_number – Layer index in the transformer.
attn_mask_type – Type of attention mask to use.
cp_comm_type – Context parallel communication type.
Initialization
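A hypothetical spec for the single-block attention, assuming standard Megatron submodules via SelfAttentionSubmodules; the repository's actual layer spec may substitute Transformer Engine variants.

```python
# Hypothetical example: a ModuleSpec for FluxSingleAttention. Submodule choices
# are assumptions for illustration.
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.spec_utils import ModuleSpec

from bridge.diffusion.models.flux.flux_attention import FluxSingleAttention

single_attention_spec = ModuleSpec(
    module=FluxSingleAttention,
    params={"attn_mask_type": AttnMaskType.no_mask},
    submodules=SelfAttentionSubmodules(
        linear_qkv=ColumnParallelLinear,
        core_attention=DotProductAttention,
        linear_proj=RowParallelLinear,
    ),
)
```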
- forward(
- hidden_states,
- attention_mask,
- key_value_states=None,
- inference_params=None,
- rotary_pos_emb=None,
- packed_seq_params=None,
Forward pass for FLUX single attention.
- Parameters:
hidden_states – Input hidden states [sq, b, h].
attention_mask – Attention mask.
key_value_states – Optional key-value states.
inference_params – Inference parameters.
rotary_pos_emb – Rotary position embeddings.
packed_seq_params – Packed sequence parameters.
- Returns:
Attention output tensor.
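An illustrative call, assuming a constructed FluxSingleAttention instance (for example built from the spec above with the same TransformerConfig pattern shown earlier) and the [s, b, h] layout from the docstring.

```python
# Hypothetical usage; `single_attention` is assumed to be a constructed
# FluxSingleAttention instance.
import torch

hidden = torch.randn(4096, 1, 3072)   # [s, b, h], illustrative sizes
output = single_attention(hidden_states=hidden, attention_mask=None)
# Per the docstring, the attention output keeps the same [s, b, h] shape.
```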