attention_pair_bias

cuequivariance_torch.attention_pair_bias(
s,
q,
k,
v,
z,
mask,
num_heads,
w_proj_z,
w_proj_g,
w_proj_o,
w_ln_z,
b_ln_z,
b_proj_z=None,
b_proj_g=None,
b_proj_o=None,
inf=1000000.0,
eps=1e-05,
attn_scale=None,
compute_pair_bias=True,
multiplicity=1,
)

Compute attention with pairwise bias for diffusion models.

This function implements attention with a pairwise bias term, as commonly used in diffusion models. Based on the sequence length, it automatically chooses between optimized Triton kernels (for long sequences) and a PyTorch fallback (for short sequences).
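
The rough computation can be reconstructed from the parameter shapes documented below: z is layer-normalized and projected to one bias value per attention head, masked key positions receive a large negative bias, the bias is added to the attention logits, and the attention output is gated by a sigmoid projection of s before the final output projection. The pure-PyTorch sketch below is an approximation for illustration only; the fused kernels may differ in detail, and the helper name attention_pair_bias_reference, the torch.nn.functional.linear weight layout, the assumption that mask has shape (B, V), and the assumption that U equals S are choices made here rather than guarantees of the library.

import torch
import torch.nn.functional as F

def attention_pair_bias_reference(s, q, k, v, z, mask, num_heads,
                                  w_proj_z, w_proj_g, w_proj_o, w_ln_z, b_ln_z,
                                  b_proj_z=None, b_proj_g=None, b_proj_o=None,
                                  inf=1e6, eps=1e-5, attn_scale=None, multiplicity=1):
    BM, H, U, DH = q.shape
    assert H == num_heads
    B = BM // multiplicity
    # Pair bias: layer-norm z over its feature dimension, then project z_dim -> H (one bias per head).
    z_ln = F.layer_norm(z, (z.shape[-1],), w_ln_z, b_ln_z, eps)   # (B, U, V, z_dim)
    bias = torch.einsum("buvc,hc->bhuv", z_ln, w_proj_z)          # (B, H, U, V)
    if b_proj_z is not None:
        bias = bias + b_proj_z.view(1, H, 1, 1)
    # Mask out invalid key positions with a large negative value (mask assumed to be (B, V)).
    bias = bias.masked_fill(mask.reshape(B, 1, 1, -1) == 0, -inf)
    # Broadcast the bias over the multiplicity dimension and attend.
    bias_m = bias.repeat_interleave(multiplicity, dim=0)          # (B*M, H, U, V)
    scale = attn_scale if attn_scale is not None else DH ** -0.5
    logits = torch.einsum("bhud,bhvd->bhuv", q, k) * scale + bias_m
    out = torch.einsum("bhuv,bhvd->bhud", logits.softmax(dim=-1), v)   # (B*M, H, U, DH)
    out = out.transpose(1, 2).reshape(BM, U, H * DH)              # (B*M, U, D)
    # Gate with a sigmoid projection of s, then apply the output projection.
    gate = torch.sigmoid(F.linear(s, w_proj_g, b_proj_g))
    return F.linear(out * gate, w_proj_o, b_proj_o), bias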

Parameters:
  • s (Tensor) – Input sequence tensor of shape (B * M, S, D) where B is batch size, M is multiplicity (diffusion steps), S is sequence length, and D is feature dimension.

  • q (Tensor) – Query tensor of shape (B * M, H, U, DH) where H is number of heads, U is query sequence length, and DH is head dimension.

  • k (Tensor) – Key tensor of shape (B * M, H, V, DH) where V is key sequence length.

  • v (Tensor) – Value tensor of shape (B * M, H, V, DH).

  • z (Tensor) – Pairwise tensor of shape (B, U, V, z_dim) containing pairwise interactions, where z_dim can be arbitrary. This is the main input for the pairwise bias computation.

  • mask (Tensor) – Attention mask of shape (B, V) or (B * M, V) indicating which positions should be masked (0 = masked, 1 = unmasked).

  • num_heads (int) – Number of attention heads.

  • w_proj_z (Tensor) – Weight matrix for z projection of shape (H, z_dim).

  • w_proj_g (Tensor) – Weight matrix for gating projection of shape (D, D).

  • w_proj_o (Tensor) – Weight matrix for output projection of shape (D, D).

  • w_ln_z (Tensor) – Weight for layer normalization of z tensor of shape (z_dim,).

  • b_ln_z (Tensor) – Bias for layer normalization of z tensor of shape (z_dim,).

  • b_proj_z (Tensor | None) – Bias for z projection of shape (H,). Defaults to None.

  • b_proj_g (Tensor | None) – Bias for gating projection of shape (D,). Defaults to None.

  • b_proj_o (Tensor | None) – Bias for output projection of shape (D,). Defaults to None.

  • inf (float | None) – Large value used for masking invalid attention positions. Defaults to 1e6.

  • eps (float | None) – Epsilon value for layer normalization. Defaults to 1e-5.

  • attn_scale (float | None) – Scaling factor for attention scores. If None, uses 1/sqrt(head_dim). Defaults to None.

  • compute_pair_bias (bool | None) – Whether to compute the pairwise bias from z. If False, the z tensor must already be in the per-head format (B, U, V, H); see the sketch after this parameter list. Defaults to True.

  • multiplicity (int | None) – Multiplicity M (diffusion steps). Should be set explicitly if the multiplicity is greater than 1 and is not reflected in the z tensor. Defaults to 1.
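
When compute_pair_bias=False, z must already be in the per-head format (B, U, V, H) described above. The sketch below shows such a call; whether the layer-normalization and z-projection weights are still used on this path is not stated here, so their shapes are simply kept consistent with z's last dimension (H) as a defensive choice, and the expected output shape follows from the Returns section.

import torch
from cuequivariance_torch import attention_pair_bias

if torch.cuda.is_available():
    device = torch.device("cuda")
    B, S, H, DH = 1, 32, 2, 32
    D = H * DH
    s = torch.randn(B, S, D, device=device)
    q = torch.randn(B, H, S, DH, device=device)
    k = torch.randn(B, H, S, DH, device=device)
    v = torch.randn(B, H, S, DH, device=device)
    # z is already one bias value per head: shape (B, U, V, H).
    z_bias = torch.randn(B, S, S, H, device=device)
    mask = torch.ones(B, S, device=device)
    output, proj_z = attention_pair_bias(
        s=s, q=q, k=k, v=v, z=z_bias, mask=mask, num_heads=H,
        # Required by the signature even on this path; shapes match z's last dim (H).
        w_proj_z=torch.eye(H, device=device),
        w_proj_g=torch.randn(D, D, device=device),
        w_proj_o=torch.randn(D, D, device=device),
        w_ln_z=torch.ones(H, device=device),
        b_ln_z=torch.zeros(H, device=device),
        compute_pair_bias=False,
    )
    print(output.shape)  # expected: torch.Size([1, 32, 64])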

Returns:

  • output (torch.Tensor): Attention output of shape (B * M, S, D) with pairwise bias applied.

  • proj_z (torch.Tensor): Projected z tensor of shape (B, H, U, V) containing the per-head pairwise bias with the mask applied.

Return type:

tuple[torch.Tensor, torch.Tensor]

Notes

  • For short sequences (sequence length ≤ CUEQ_ATTENTION_PAIR_BIAS_FALLBACK_THRESHOLD), the PyTorch fallback implementation is used.

  • For longer sequences, optimized Triton kernels are used, with automatic backend selection (cuDNN, Flash Attention, Efficient Attention).

  • The multiplicity parameter (M) allows processing multiple diffusion timesteps in a single forward pass (see the second example below).

  • The proj_z output is experimental; it is returned now to prevent breakage when caching of the pair bias tensor is enabled in the next release.

Examples

>>> import torch
>>> from cuequivariance_torch import attention_pair_bias
>>> if torch.cuda.is_available():
...     device = torch.device("cuda")
...     batch_size, seq_len, num_heads, heads_dim, hidden_dim = 1, 32, 2, 32, 64
...     query_len, key_len, z_dim = 32, 32, 16
...     # Create input tensors on GPU
...     s = torch.randn(batch_size, seq_len, hidden_dim,
...                     device=device, dtype=torch.float32)
...     q = torch.randn(batch_size, num_heads, query_len, heads_dim,
...                     device=device, dtype=torch.float32)
...     k = torch.randn(batch_size, num_heads, key_len, heads_dim,
...                     device=device, dtype=torch.float32)
...     v = torch.randn(batch_size, num_heads, key_len, heads_dim,
...                     device=device, dtype=torch.float32)
...     z = torch.randn(batch_size, query_len, key_len, z_dim,
...                     device=device, dtype=torch.float32)
...     mask = torch.rand(batch_size, key_len,
...                       device=device) < 0.5
...     w_proj_z = torch.randn(num_heads, z_dim,
...                     device=device, dtype=torch.float32)
...     w_proj_g = torch.randn(hidden_dim, hidden_dim,
...                     device=device, dtype=torch.float32)
...     w_proj_o = torch.randn(hidden_dim, hidden_dim,
...                     device=device, dtype=torch.float32)
...     w_ln_z = torch.randn(z_dim,
...                     device=device, dtype=torch.float32)
...     b_ln_z = torch.randn(z_dim,
...                     device=device, dtype=torch.float32)
...     # Perform operation
...     output, proj_z = attention_pair_bias(
...         s=s,
...         q=q,
...         k=k,
...         v=v,
...         z=z,
...         mask=mask,
...         num_heads=num_heads,
...         w_proj_z=w_proj_z,
...         w_proj_g=w_proj_g,
...         w_proj_o=w_proj_o,
...         w_ln_z=w_ln_z,
...         b_ln_z=b_ln_z,
...     )
...     print(output.shape)  # torch.Size([1, 32, 64])
torch.Size([1, 32, 64])
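
A second, hedged sketch with multiplicity > 1, following the shapes documented in the Parameters and Returns sections (random weights for illustration only):

>>> import torch
>>> from cuequivariance_torch import attention_pair_bias
>>> if torch.cuda.is_available():
...     device = torch.device("cuda")
...     B, M, S, H, DH, z_dim = 1, 2, 32, 2, 32, 16
...     D = H * DH
...     # Per-token tensors carry the multiplicity M in the leading dimension ...
...     s = torch.randn(B * M, S, D, device=device)
...     q = torch.randn(B * M, H, S, DH, device=device)
...     k = torch.randn(B * M, H, S, DH, device=device)
...     v = torch.randn(B * M, H, S, DH, device=device)
...     # ... while the pairwise tensor and mask stay at batch size B.
...     z = torch.randn(B, S, S, z_dim, device=device)
...     mask = torch.ones(B, S, device=device)
...     output, proj_z = attention_pair_bias(
...         s=s, q=q, k=k, v=v, z=z, mask=mask, num_heads=H,
...         w_proj_z=torch.randn(H, z_dim, device=device),
...         w_proj_g=torch.randn(D, D, device=device),
...         w_proj_o=torch.randn(D, D, device=device),
...         w_ln_z=torch.randn(z_dim, device=device),
...         b_ln_z=torch.randn(z_dim, device=device),
...         multiplicity=M,  # M is not reflected in z, so pass it explicitly
...     )
...     assert output.shape == (B * M, S, D)
...     assert proj_z.shape == (B, H, S, S)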