bridge.diffusion.models.common.dit_attention#
Module Contents#
Classes#
- DiTCrossAttentionSubmodules: Configuration class for specifying the submodules of a cross-attention.
- DiTSelfAttention
- DiTCrossAttention
API#
- class bridge.diffusion.models.common.dit_attention.DiTCrossAttentionSubmodules#
Configuration class for specifying the submodules of a cross-attention.
- linear_q: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_kv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
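The submodules config is typically consumed through a `ModuleSpec` that names a concrete class for each slot. Below is a minimal sketch; the specific submodule classes chosen here (`ColumnParallelLinear`, `DotProductAttention`, `IdentityOp`) are illustrative stand-ins from Megatron core, not necessarily the spec the bridge ships with:

```python
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.identity_op import IdentityOp

from bridge.diffusion.models.common.dit_attention import (
    DiTCrossAttention,
    DiTCrossAttentionSubmodules,
)

# Each field names the class that implements that piece of the attention block.
# The choices below are illustrative, not canonical.
cross_attn_spec = ModuleSpec(
    module=DiTCrossAttention,
    submodules=DiTCrossAttentionSubmodules(
        linear_q=ColumnParallelLinear,       # projects hidden_states -> Q
        linear_kv=ColumnParallelLinear,      # projects key_value_states -> fused KV
        core_attention=DotProductAttention,  # the attention computation itself
        linear_proj=RowParallelLinear,       # output projection back to hidden size
        q_layernorm=IdentityOp,              # optional Q norm; IdentityOp disables it
        k_layernorm=IdentityOp,              # optional K norm; IdentityOp disables it
    ),
)
```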
- class bridge.diffusion.models.common.dit_attention.DiTSelfAttention(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: megatron.core.transformer.attention.SelfAttentionSubmodules,
- layer_number: int,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
- cp_comm_type: str = None,
- pg_collection=None)#

Bases: megatron.core.transformer.attention.SelfAttention
- get_query_key_value_tensors(
- hidden_states,
- key_value_states=None,
- output_gate=None,
- split_qkv=True)#

Derives `query`, `key`, and `value` tensors from `hidden_states`.
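For orientation, a hedged usage sketch: `attn` is an assumed, already-constructed `DiTSelfAttention` instance, the `[seq, batch, hidden]` layout follows Megatron's convention, and the separate `(query, key, value)` return with `split_qkv=True` is an assumption based on the docstring:

```python
import torch

# Illustrative sizes only: [sequence, batch, hidden] is Megatron's tensor layout.
hidden_states = torch.randn(1024, 2, 3072)

# With split_qkv=True the fused QKV projection of hidden_states is split into
# separate query/key/value tensors (assumed return convention).
query, key, value = attn.get_query_key_value_tensors(hidden_states, split_qkv=True)
```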
- class bridge.diffusion.models.common.dit_attention.DiTCrossAttention(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: bridge.diffusion.models.common.dit_attention.DiTCrossAttentionSubmodules,
- layer_number: int,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
- cp_comm_type: str = None,
- pg_collection=None)#

Bases: megatron.core.transformer.attention.CrossAttention
- get_query_key_value_tensors(
- hidden_states,
- key_value_states,
- output_gate=None,
- split_qkv=True)#

Derives the `query` tensor from `hidden_states`, and the `key`/`value` tensors from `key_value_states`.
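A corresponding sketch for the cross-attention case, where the conditioning stream supplies keys and values. The `xattn` instance, the tensor shapes, and the `(query, key, value)` return convention are assumptions for illustration:

```python
import torch

hidden_states = torch.randn(1024, 2, 3072)   # [seq, batch, hidden]: image/video tokens
key_value_states = torch.randn(77, 2, 3072)  # [kv_seq, batch, hidden]: e.g. text embeddings

# Q is projected from hidden_states; K/V are projected from key_value_states.
query, key, value = xattn.get_query_key_value_tensors(hidden_states, key_value_states)
```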