triangle_attention
cuequivariance_torch.triangle_attention(q, k, v, bias, mask=None, scale=None, return_aux=False)
Triangle Attention
\[\begin{split}\text{Attention}_q(Q, K, V, B, M) = \sum_k\left[\text{softmax}_k\left(\begin{cases} s\, Q_q \cdot K_k + B_{qk} & \text{if } M_k = 1 \\ -10^9 & \text{otherwise} \end{cases}\right) V_k \right]\end{split}\]
- Parameters:
q (torch.Tensor) – Query tensor of shape (B, N, H, Q, D). For B=1, can also be (N, H, Q, D).
k (torch.Tensor) – Key tensor of shape (B, N, H, K, D). For B=1, can also be (N, H, K, D).
v (torch.Tensor) – Value tensor of shape (B, N, H, K, D). For B=1, can also be (N, H, K, D).
bias (torch.Tensor) – Bias tensor of shape (B, 1, H, Q, K). For B=1, can also be (1, H, Q, K). Cast to float32 for the standard kernels. On Blackwell GPUs (sm100f kernel, compute capability 10.0 or 10.3), cast to match the q/k/v dtype (bf16/fp16) for best performance.
mask (torch.Tensor, optional) – Mask tensor of shape (B, N, 1, 1, K). For B=1, can also be (N, 1, 1, K). Will be cast to bool internally.
scale (float, optional) – Scale factor s in the equation, applied to the q·k product. If None, 1/sqrt(D) is used.
return_aux (bool) – If True, two auxiliary tensors are returned along with the result. Defaults to False.
Note
B: batch size
N: number of tokens
H: number of heads
Q: number of query tokens
K: number of key tokens
D: attention dimension
- Returns:
output (torch.Tensor) – Output tensor of shape (B, N, H, Q, D). dtype=q.dtype
lse (torch.Tensor) – Auxiliary result (for special use only), returned when return_aux is True. dtype=float32
max (torch.Tensor) – Auxiliary result (for special use only), returned when return_aux is True. dtype=float32
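For checking shapes or comparing results, the formula and the documented broadcasting of bias and mask can be sketched in plain PyTorch. This is an unfused reference under stated assumptions, not the library's kernel: the name triangle_attention_reference is hypothetical, it accepts only the 5-D (batched) shapes, it computes in float32, and it does not return the lse and max auxiliary tensors.

```python
import math
import torch


def triangle_attention_reference(q, k, v, bias, mask=None, scale=None):
    """Unfused sketch of the documented formula (hypothetical helper)."""
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])  # default 1/sqrt(D)
    # (B, N, H, Q, D) x (B, N, H, K, D) -> (B, N, H, Q, K)
    logits = scale * torch.einsum("bnhqd,bnhkd->bnhqk", q.float(), k.float())
    # bias of shape (B, 1, H, Q, K) broadcasts over N
    logits = logits + bias.float()
    if mask is not None:
        # mask of shape (B, N, 1, 1, K) broadcasts over H and Q;
        # masked-out keys have their logit replaced by -1e9, as in the equation
        logits = logits.masked_fill(~mask.bool(), -1e9)
    weights = torch.softmax(logits, dim=-1)  # softmax over keys
    out = torch.einsum("bnhqk,bnhkd->bnhqd", weights, v.float())
    return out.to(q.dtype)


# Small CPU example with the documented shapes
torch.manual_seed(0)
B, N, H, Q, K, D = 1, 4, 2, 4, 4, 8
q = torch.randn(B, N, H, Q, D)
k = torch.randn(B, N, H, K, D)
v = torch.randn(B, N, H, K, D)
bias = torch.randn(B, 1, H, Q, K)
mask = torch.ones(B, N, 1, 1, K, dtype=torch.bool)
mask[..., -1] = False  # hide the last key everywhere
out = triangle_attention_reference(q, k, v, bias, mask)
print(out.shape)  # torch.Size([1, 4, 2, 4, 8])
```

Because a masked key receives a logit of -1e9, its softmax weight underflows to zero in float32, so the values stored at that key position do not influence the output.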
Notes
Context is saved for backward pass. You don’t need to save it manually.
Kernel precision (fp32, bf16, fp16) is determined by the input dtypes. To use tf32, enable it through PyTorch's global settings.
The standard triangle attention kernels support hidden_dim<=32 divisible by 4 for tf32/fp32, and hidden_dim<=128 divisible by 8 for bf16/fp16. On Blackwell GPUs (compute capability 10.0 or 10.3), the sm100f kernel supports hidden_dim<=256 for the forward pass and hidden_dim<=128 for the backward pass. In the rare case that no kernel supports an input configuration, the call falls back to torch instead of raising an error.
Blackwell-optimized kernels (compute capabilities 10.0 and 10.3) provide better performance, especially for long sequences and larger head dimensions. These kernels require the sequence length N to be a multiple of 8 for the forward pass; pad the sequence if necessary. Performance is best with a "padding mask", that is, a mask whose last dimension is all True followed by all False. Currently, this feature is supported only for cu13 builds.
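The hidden_dim limits quoted in these notes for the standard kernels can be written down as a small predicate, which may help when deciding up front whether a configuration is likely to hit the torch fallback. This is only a sketch of the documented limits: standard_kernel_supports is a hypothetical helper, not part of cuequivariance_torch, it assumes hidden_dim is the attention dimension D, and it does not cover the Blackwell sm100f limits.

```python
import torch


def standard_kernel_supports(hidden_dim: int, dtype: torch.dtype) -> bool:
    """Hypothetical check mirroring the documented standard-kernel limits."""
    if dtype == torch.float32:
        # fp32 inputs (also the tf32 case): hidden_dim <= 32 and divisible by 4
        return hidden_dim <= 32 and hidden_dim % 4 == 0
    if dtype in (torch.float16, torch.bfloat16):
        # fp16/bf16 inputs: hidden_dim <= 128 and divisible by 8
        return hidden_dim <= 128 and hidden_dim % 8 == 0
    # Other dtypes are not mentioned in the notes
    return False


print(standard_kernel_supports(32, torch.float32))    # True
print(standard_kernel_supports(64, torch.float32))    # False
print(standard_kernel_supports(128, torch.bfloat16))  # True
print(standard_kernel_supports(100, torch.float16))   # False
```

A False result does not mean the call fails; per the notes, an unsupported configuration falls back to torch.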
Example
>>> import torch
>>> import math
>>> from cuequivariance_torch import triangle_attention
>>> if torch.cuda.is_available():
...     device = torch.device("cuda")
...     # Set up dimensions
...     batch_size, seq_len, num_heads, hidden_dim = 1, 128, 2, 32
...     # Create input tensors on GPU with float16 precision
...     q = torch.randn(batch_size, seq_len, num_heads, seq_len, hidden_dim,
...                     device=device, dtype=torch.float16, requires_grad=True)
...     k = torch.randn(batch_size, seq_len, num_heads, seq_len, hidden_dim,
...                     device=device, dtype=torch.float16, requires_grad=True)
...     v = torch.randn(batch_size, seq_len, num_heads, seq_len, hidden_dim,
...                     device=device, dtype=torch.float16, requires_grad=True)
...     bias = torch.randn(batch_size, 1, num_heads, seq_len, seq_len,
...                        device=device, dtype=torch.float32, requires_grad=True)
...     # Create optional mask
...     mask = torch.rand(batch_size, seq_len, 1, 1, seq_len,
...                       device=device) < 0.5
...     # Calculate scale
...     scale = 1 / math.sqrt(hidden_dim)
...     # Forward pass
...     output, lse, max_val = triangle_attention(
...         q=q, k=k, v=v, bias=bias, mask=mask, scale=scale, return_aux=True)
...     print(output.shape)  # torch.Size([1, 128, 2, 128, 32])
...     # Create gradient tensor and perform backward pass
...     grad_out = torch.randn_like(output)
...     output.backward(grad_out)
...     # Access gradients
...     print(q.grad.shape)  # torch.Size([1, 128, 2, 128, 32])
...     print(k.grad.shape)  # torch.Size([1, 128, 2, 128, 32])
...     print(v.grad.shape)  # torch.Size([1, 128, 2, 128, 32])
...     print(bias.grad.shape)  # torch.Size([1, 1, 2, 128, 128])
torch.Size([1, 128, 2, 128, 32])
torch.Size([1, 128, 2, 128, 32])
torch.Size([1, 128, 2, 128, 32])
torch.Size([1, 128, 2, 128, 32])
torch.Size([1, 1, 2, 128, 128])
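When the sequence length is not a multiple of 8, the inputs can be padded before the call and the output sliced afterwards. The sketch below is one possible way to do this and runs on CPU since it only manipulates shapes. It rests on assumptions not stated in the documentation: pad_token_axes is a hypothetical helper, the inputs use the 5-D shapes, and the N, Q and K axes are all padded, as in the example where all three equal seq_len. Padded keys are marked False, so an all-True input mask becomes a padding mask (all True followed by all False).

```python
import torch
import torch.nn.functional as F


def pad_token_axes(q, k, v, bias, mask=None, multiple=8):
    """Hypothetical helper: zero-pad the N, Q and K axes to a multiple of `multiple`."""
    b, n, _, n_q, _ = q.shape
    n_k = k.shape[3]
    pn, pq, pk = (-n) % multiple, (-n_q) % multiple, (-n_k) % multiple
    # F.pad takes (left, right) pairs starting from the last dimension.
    q = F.pad(q, (0, 0, 0, pq, 0, 0, 0, pn))  # pad Q (dim 3) and N (dim 1)
    k = F.pad(k, (0, 0, 0, pk, 0, 0, 0, pn))  # pad K (dim 3) and N (dim 1)
    v = F.pad(v, (0, 0, 0, pk, 0, 0, 0, pn))
    bias = F.pad(bias, (0, pk, 0, pq))        # pad K (last dim) and Q
    # Start from an all-False mask so every padded position is masked out,
    # then copy the original mask (or all True) into the unpadded region.
    padded_mask = torch.zeros(b, n + pn, 1, 1, n_k + pk,
                              dtype=torch.bool, device=q.device)
    padded_mask[:, :n, :, :, :n_k] = True if mask is None else mask.bool()
    return q, k, v, bias, padded_mask


seq_len = 10  # not a multiple of 8
q = torch.randn(1, seq_len, 2, seq_len, 16)
k = torch.randn(1, seq_len, 2, seq_len, 16)
v = torch.randn(1, seq_len, 2, seq_len, 16)
bias = torch.randn(1, 1, 2, seq_len, seq_len)
qp, kp, vp, bp, mp = pad_token_axes(q, k, v, bias)
print(qp.shape)  # torch.Size([1, 16, 2, 16, 16])
print(mp.shape)  # torch.Size([1, 16, 1, 1, 16])
# On a GPU, the padded call would then be sliced back to the original size:
# out = triangle_attention(qp, kp, vp, bp, mask=mp)[:, :seq_len, :, :seq_len]
```

Slicing the output back to the original N and Q sizes discards the rows computed for padded positions, and the False mask entries keep padded keys from contributing to the remaining rows.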