triangle_attention
cuequivariance_torch.triangle_attention(q: Tensor, k: Tensor, v: Tensor, bias: Tensor, mask: Tensor | None = None, scale: float | None = None, return_aux: bool = False)
Triangle Attention
\[\text{Attention}_q(Q, K, V, B, M) = \sum_k\left[\text{softmax}_k\left(\begin{cases} s\, Q_q \cdot K_k + B_{qk} & \text{if } M_k = 1 \\ -10^9 & \text{otherwise} \end{cases}\right) V_k \right]\]
- Parameters:
q (torch.Tensor) – Query tensor of shape (B, N, H, Q, D). For B=1, can also be (N, H, Q, D).
k (torch.Tensor) – Key tensor of shape (B, N, H, K, D). For B=1, can also be (N, H, K, D).
v (torch.Tensor) – Value tensor of shape (B, N, H, K, D). For B=1, can also be (N, H, K, D).
bias (torch.Tensor) – Bias tensor of shape (B, 1, H, Q, K). For B=1, can also be (1, H, Q, K). Will be cast to float32 internally.
mask (torch.Tensor, optional) – Mask tensor of shape (B, N, 1, 1, K). For B=1, can also be (N, 1, 1, K). Will be cast to bool internally.
scale (float, optional) – Scale factor s applied to q (see the equation). If None, 1/sqrt(D) is used.
return_aux (bool) – If True, two auxiliary tensors are returned along with the result. Defaults to False.
Note
B: batch size
N: number of tokens
H: number of heads
Q: number of query tokens
K: number of key tokens
D: attention dimension
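As a reading aid, the formula above can be written as a naive pure-PyTorch sketch. This is not the fused CUDA kernel; the helper name naive_triangle_attention is illustrative only, and it assumes batched 5-D inputs as described in the dimension key.

import math
import torch

def naive_triangle_attention(q, k, v, bias, mask=None, scale=None):
    # Shapes: q (B, N, H, Q, D); k, v (B, N, H, K, D); bias (B, 1, H, Q, K); mask (B, N, 1, 1, K)
    s = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1])
    # s * Q_q . K_k + B_qk, computed in float32 for the reference
    logits = s * torch.einsum("bnhqd,bnhkd->bnhqk", q.float(), k.float()) + bias.float()
    if mask is not None:
        # Masked-out keys (M_k = 0) receive a large negative logit before the softmax
        logits = logits.masked_fill(~mask.bool(), -1e9)
    attn = torch.softmax(logits, dim=-1)   # softmax over the key axis
    return (attn @ v.float()).to(q.dtype)  # (B, N, H, Q, D)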
- Returns:
output (torch.Tensor): Output tensor of shape (B, N, H, Q, D), dtype=q.dtype.
lse (torch.Tensor): Auxiliary result, returned only if return_aux=True (for special use only), dtype=float32.
max (torch.Tensor): Auxiliary result, returned only if return_aux=True (for special use only), dtype=float32.
- Return type:
output (torch.Tensor), with lse and max returned alongside it when return_aux=True
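To make the return convention concrete, here is a hedged sketch of the two call forms (assuming, per the return type above, that the default call yields just the output tensor):

out = triangle_attention(q=q, k=k, v=v, bias=bias, mask=mask)  # (B, N, H, Q, D)
out, lse, max_val = triangle_attention(
    q=q, k=k, v=v, bias=bias, mask=mask, return_aux=True)      # plus two float32 auxiliary tensors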
Notes
Context is saved for backward pass. You don’t need to save it manually.
Kernel precision (fp32, bf16, fp16) follows the input dtypes. To use tf32, enable it from the global torch scope (see the snippet after these notes).
Limitation: full FP32 is not supported for the backward pass. Please set torch.backends.cuda.matmul.allow_tf32 = True.
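For instance, tf32 can be enabled from the global torch scope before calling the kernel (standard PyTorch flags, shown here as a sketch):

import torch
torch.backends.cuda.matmul.allow_tf32 = True  # TF32 matmuls; also addresses the FP32 backward limitation above
torch.backends.cudnn.allow_tf32 = True        # optional: TF32 for cuDNN kernels as well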
Example
>>> import torch
>>> import math
>>> from cuequivariance_torch import triangle_attention
>>> if torch.cuda.is_available():
...     device = torch.device("cuda")
...     # Set up dimensions
...     batch_size, seq_len, num_heads, hidden_dim = 1, 128, 2, 32
...     # Create input tensors on GPU with float16 precision
...     q = torch.randn(batch_size, seq_len, num_heads, seq_len, hidden_dim,
...         device=device, dtype=torch.float16, requires_grad=True)
...     k = torch.randn(batch_size, seq_len, num_heads, seq_len, hidden_dim,
...         device=device, dtype=torch.float16, requires_grad=True)
...     v = torch.randn(batch_size, seq_len, num_heads, seq_len, hidden_dim,
...         device=device, dtype=torch.float16, requires_grad=True)
...     bias = torch.randn(batch_size, 1, num_heads, seq_len, seq_len,
...         device=device, dtype=torch.float32, requires_grad=True)
...     # Create optional mask
...     mask = torch.rand(batch_size, seq_len, 1, 1, seq_len,
...         device=device) < 0.5
...     # Calculate scale
...     scale = 1 / math.sqrt(hidden_dim)
...     # Forward pass
...     output, lse, max_val = triangle_attention(
...         q=q, k=k, v=v, bias=bias, mask=mask, scale=scale, return_aux=True)
...     print(output.shape)  # torch.Size([1, 128, 2, 128, 32])
...     # Create gradient tensor and perform backward pass
...     grad_out = torch.randn_like(output)
...     output.backward(grad_out)
...     # Access gradients
...     print(q.grad.shape)  # torch.Size([1, 128, 2, 128, 32])
...     print(k.grad.shape)  # torch.Size([1, 128, 2, 128, 32])
...     print(v.grad.shape)  # torch.Size([1, 128, 2, 128, 32])
...     print(bias.grad.shape)  # torch.Size([1, 1, 2, 128, 128])
torch.Size([1, 128, 2, 128, 32])
torch.Size([1, 128, 2, 128, 32])
torch.Size([1, 128, 2, 128, 32])
torch.Size([1, 128, 2, 128, 32])
torch.Size([1, 1, 2, 128, 128])
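If desired, the fused result can be sanity-checked against the naive_triangle_attention sketch from the dimension-key section above (loose tolerances, since the kernel runs in reduced precision; this check is an illustration, not part of the documented API):

ref = naive_triangle_attention(q, k, v, bias, mask=mask, scale=scale)
print(torch.allclose(output.float(), ref.float(), atol=1e-2, rtol=1e-2))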