bridge.diffusion.models.common.dit_attention

Module Contents

Classes

DiTCrossAttentionSubmodules

Configuration class for specifying the submodules of a cross-attention layer.

DiTSelfAttention

Self-attention module derived from megatron.core.transformer.attention.SelfAttention.

DiTCrossAttention

Cross-attention module derived from megatron.core.transformer.attention.CrossAttention.

API

class bridge.diffusion.models.common.dit_attention.DiTCrossAttentionSubmodules

Configuration class for specifying the submodules of a cross-attention layer.

linear_q: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

linear_kv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type] = None

class bridge.diffusion.models.common.dit_attention.DiTSelfAttention(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.attention.SelfAttentionSubmodules,
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
cp_comm_type: str = None,
pg_collection=None,
)

Bases: megatron.core.transformer.attention.SelfAttention

Initialization

get_query_key_value_tensors(
hidden_states,
key_value_states=None,
output_gate=None,
split_qkv=True,
)

Derives query, key and value tensors from hidden_states.

class bridge.diffusion.models.common.dit_attention.DiTCrossAttention(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: bridge.diffusion.models.common.dit_attention.DiTCrossAttentionSubmodules,
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
cp_comm_type: str = None,
pg_collection=None,
)

Bases: megatron.core.transformer.attention.CrossAttention

Initialization

get_query_key_value_tensors(
hidden_states,
key_value_states,
output_gate=None,
split_qkv=True,
)

Derives query tensor from hidden_states, and key/value tensors from key_value_states.