attention_pair_bias

cuequivariance_torch.attention_pair_bias(s, q, k, v, z, mask, num_heads, w_proj_z, w_proj_g, w_proj_o, w_ln_z=None, b_ln_z=None, b_proj_z=None, b_proj_g=None, b_proj_o=None, inf=1000000.0, eps=1e-05, attn_scale=None, return_z_proj=True, is_cached_z_proj=False)
Compute attention with pairwise bias for diffusion models.
This function implements attention with a pairwise bias term, as commonly used in diffusion models. It automatically chooses between optimized Triton kernels (for long sequences) and a PyTorch fallback (for short sequences) based on sequence length.
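For orientation, the following is a minimal plain-PyTorch sketch of the kind of computation described above, written against the parameter shapes documented below: layer normalization and per-head projection of z into a pairwise bias, masked scaled-dot-product attention, sigmoid gating from s, and an output projection. It is illustrative only, not the library's implementation; the weight-layout convention (transposed or not), the ordering of the fused B * M batch dimension, and the exact handling of the mask and of the returned bias are assumptions.

import math
import torch
import torch.nn.functional as F

def reference_attention_pair_bias(
    s, q, k, v, z, mask, num_heads,
    w_proj_z, w_proj_g, w_proj_o,
    w_ln_z=None, b_ln_z=None, b_proj_z=None, b_proj_g=None, b_proj_o=None,
    inf=1e6, eps=1e-5, attn_scale=None,
):
    # Illustrative sketch only; conventions below are assumptions.
    BM, H, U, DH = q.shape
    assert H == num_heads
    B = z.shape[0]
    M = BM // B  # multiplicity, assumed to be the ratio of the batch dims
    # LayerNorm + per-head projection of the pair tensor: (B, U, V, z_dim) -> (B, H, U, V)
    z_ln = F.layer_norm(z, (z.shape[-1],), w_ln_z, b_ln_z, eps)
    proj_z = torch.einsum("buvc,hc->bhuv", z_ln, w_proj_z)
    if b_proj_z is not None:
        proj_z = proj_z + b_proj_z.view(1, H, 1, 1)
    # Large negative bias at masked key positions (0 = masked); assumes mask of shape (B, V).
    bias = proj_z + (mask.view(B, 1, 1, -1).to(proj_z.dtype) - 1.0) * inf
    # Scaled dot-product attention; the pair bias is shared across the M copies
    # (assumed ordering of the fused B * M dimension).
    scale = attn_scale if attn_scale is not None else 1.0 / math.sqrt(DH)
    logits = torch.einsum("bhud,bhvd->bhuv", q, k) * scale
    logits = logits + bias.repeat_interleave(M, dim=0)
    attn = torch.softmax(logits.float(), dim=-1).to(q.dtype)
    o = torch.einsum("bhuv,bhvd->bhud", attn, v)
    o = o.transpose(1, 2).reshape(BM, U, H * DH)  # back to (B * M, S, D), assuming U == S
    # Sigmoid gating computed from s, followed by the output projection.
    g = torch.sigmoid(s @ w_proj_g.T + (0 if b_proj_g is None else b_proj_g))
    out = (g * o) @ w_proj_o.T + (0 if b_proj_o is None else b_proj_o)
    return out, bias  # second value mirrors the documented proj_z (mask applied)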
- Parameters:
s (Tensor) – Input sequence tensor of shape (B * M, S, D) where B is batch size, M is multiplicity (diffusion steps), S is sequence length, and D is feature dimension.
q (Tensor) – Query tensor of shape (B * M, H, U, DH) where H is number of heads, U is query sequence length, and DH is head dimension.
k (Tensor) – Key tensor of shape (B * M, H, V, DH) where V is key sequence length.
v (Tensor) – Value tensor of shape (B * M, H, V, DH).
z (Tensor) – Pairwise tensor of shape (B, U, V, z_dim) containing pairwise interactions, where z_dim can be arbitrary. This is the main input for the pairwise bias computation. If is_cached_z_proj is True, z should be of shape (B, H, U, V).
mask (Tensor) – Attention mask of shape (B, V) or (B * M, V) indicating which positions should be masked (0 = masked, 1 = unmasked).
num_heads (int) – Number of attention heads.
w_proj_z (Tensor | None) – Weight matrix for z projection of shape (H, z_dim).
w_proj_g (Tensor) – Weight matrix for gating projection of shape (D, D).
w_proj_o (Tensor) – Weight matrix for output projection of shape (D, D).
w_ln_z (Tensor | None) – Weight for layer normalization of z tensor of shape (z_dim,).
b_ln_z (Tensor | None) – Bias for layer normalization of z tensor of shape (z_dim,).
b_proj_z (Tensor | None) – Bias for z projection of shape (H,). Defaults to None.
b_proj_g (Tensor | None) – Bias for gating projection of shape (D,). Defaults to None.
b_proj_o (Tensor | None) – Bias for output projection of shape (D,). Defaults to None.
inf (float) – Large value used for masking invalid attention positions. Defaults to 1e6.
eps (float) – Epsilon value for layer normalization. Defaults to 1e-5.
attn_scale (float | None) – Scaling factor for attention scores. If None, uses 1/sqrt(head_dim). Defaults to None.
return_z_proj (bool) – Whether to return the projected z tensor as the second output. Defaults to True.
is_cached_z_proj (bool) – Whether the z tensor is already projected and cached. If True, z should be of shape (B, H, U, V). Defaults to False.
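To make the batching convention above concrete, here is a small, hypothetical shape walk-through (sizes chosen arbitrarily). It assumes D = H * DH, that U and V equal S, and that the multiplicity is recovered from the ratio of the fused batch dimension of s to the batch dimension of z; this is consistent with, but not stated by, the note on automatic multiplicity below.

import torch

# Hypothetical sizes: B batches, M diffusion copies, S tokens, H heads.
B, M, S, H, DH, z_dim = 2, 3, 32, 2, 32, 16
D = H * DH                              # assumed feature dim of s

s = torch.empty(B * M, S, D)            # (6, 32, 64)
q = torch.empty(B * M, H, S, DH)        # (6, 2, 32, 32), U = S here
k = torch.empty(B * M, H, S, DH)        # (6, 2, 32, 32), V = S here
v = torch.empty(B * M, H, S, DH)
z = torch.empty(B, S, S, z_dim)         # (2, 32, 32, 16), shared across the M copies
mask = torch.ones(B, S)                 # (2, 32), 1 = unmasked

M_inferred = s.shape[0] // z.shape[0]   # assumption: multiplicity from the batch ratio
print(M_inferred)                       # 3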
- Returns:
output (torch.Tensor) – Attention output of shape (B * M, S, D) with pairwise bias applied.
proj_z (torch.Tensor) – Projected z tensor of shape (B, H, U, V) containing the pairwise bias tensor with mask applied.
- Return type:
output (torch.Tensor), paired with proj_z (torch.Tensor) when return_z_proj is True.
Notes
For short sequences (length ≤ CUEQ_ATTENTION_PAIR_BIAS_FALLBACK_THRESHOLD), the PyTorch fallback implementation is used.
For long sequences, optimized Triton kernels are used with automatic backend selection (cuDNN, Flash Attention, Efficient Attention).
Multiplicity (M) is computed automatically from tensor shapes to allow processing multiple diffusion timesteps in a single forward pass.
The proj_z output is experimental; it is exposed to prevent breakage when caching of the pair bias tensor is enabled in the next release (see the sketch after these notes).
Tested for bf16, fp16, fp32, and tf32. torch.set_float32_matmul_precision may be used to toggle between fp32 and tf32.
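As noted above, the cached-projection path is experimental. The sketch below shows one plausible way the return_z_proj / is_cached_z_proj flags could be combined: project z once, then feed the returned (B, H, U, V) projection back on later calls. The workflow itself, whether w_proj_z and mask are still consumed on the cached path, and the structure of the return value there are assumptions and may change.

import torch
from cuequivariance_torch import attention_pair_bias

# Hypothetical caching pattern; parameter names are real, the workflow is assumed.
if torch.cuda.is_available():
    device, dtype = torch.device("cuda"), torch.bfloat16
    B, M, S, H, DH, D, z_dim = 1, 1, 32, 2, 32, 64, 16
    mk = lambda *shape: torch.randn(*shape, device=device, dtype=dtype)
    s, z = mk(B * M, S, D), mk(B, S, S, z_dim)
    q, k, v = (mk(B * M, H, S, DH) for _ in range(3))
    mask = torch.ones(B, S, device=device)
    w_proj_z, w_proj_g, w_proj_o = mk(H, z_dim), mk(D, D), mk(D, D)
    w_ln_z, b_ln_z = mk(z_dim), mk(z_dim)

    # First call: compute attention and also obtain the projected pair bias.
    out, cached_proj_z = attention_pair_bias(
        s=s, q=q, k=k, v=v, z=z, mask=mask, num_heads=H,
        w_proj_z=w_proj_z, w_proj_g=w_proj_g, w_proj_o=w_proj_o,
        w_ln_z=w_ln_z, b_ln_z=b_ln_z, return_z_proj=True,
    )
    # Later calls: pass the cached (B, H, U, V) projection back instead of z.
    # Assumption: with return_z_proj at its default, this still returns a pair,
    # so the result is kept unpacked here.
    result = attention_pair_bias(
        s=s, q=q, k=k, v=v, z=cached_proj_z, mask=mask, num_heads=H,
        w_proj_z=w_proj_z, w_proj_g=w_proj_g, w_proj_o=w_proj_o,
        is_cached_z_proj=True,
    )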
Examples
>>> import torch
>>> from cuequivariance_torch import attention_pair_bias
>>> if torch.cuda.is_available():
...     device = torch.device("cuda")
...     batch_size, seq_len, num_heads, heads_dim, hidden_dim = 1, 32, 2, 32, 64
...     query_len, key_len, z_dim = 32, 32, 16
...     # Create input tensors on GPU
...     s = torch.randn(batch_size, seq_len, hidden_dim,
...                     device=device, dtype=torch.bfloat16)
...     q = torch.randn(batch_size, num_heads, query_len, heads_dim,
...                     device=device, dtype=torch.bfloat16)
...     k = torch.randn(batch_size, num_heads, key_len, heads_dim,
...                     device=device, dtype=torch.bfloat16)
...     v = torch.randn(batch_size, num_heads, key_len, heads_dim,
...                     device=device, dtype=torch.bfloat16)
...     z = torch.randn(batch_size, query_len, key_len, z_dim,
...                     device=device, dtype=torch.bfloat16)
...     mask = torch.rand(batch_size, key_len, device=device) < 0.5
...     w_proj_z = torch.randn(num_heads, z_dim, device=device, dtype=torch.bfloat16)
...     w_proj_g = torch.randn(hidden_dim, hidden_dim, device=device, dtype=torch.bfloat16)
...     w_proj_o = torch.randn(hidden_dim, hidden_dim, device=device, dtype=torch.bfloat16)
...     w_ln_z = torch.randn(z_dim, device=device, dtype=torch.bfloat16)
...     b_ln_z = torch.randn(z_dim, device=device, dtype=torch.bfloat16)
...     # Perform operation
...     output, proj_z = attention_pair_bias(
...         s=s,
...         q=q,
...         k=k,
...         v=v,
...         z=z,
...         mask=mask,
...         num_heads=num_heads,
...         w_proj_z=w_proj_z,
...         w_proj_g=w_proj_g,
...         w_proj_o=w_proj_o,
...         w_ln_z=w_ln_z,
...         b_ln_z=b_ln_z,
...     )
...     print(output.shape)  # torch.Size([1, 32, 64])
torch.Size([1, 32, 64])