triangle_attention#
- cuequivariance_jax.triangle_attention(q, k, v, bias, mask, scale, precision=None)#
Compute triangle attention.
- Parameters:
q (Array) – Query tensor of shape [B, N, H, S_qo, D].
k (Array) – Key tensor of shape [B, N, H, S_kv, D].
v (Array) – Value tensor of shape [B, N, H, S_kv, D].
bias (Array) – Bias tensor of shape [B, 1, H, S_qo, S_kv].
mask (Array) – Mask tensor of shape [B, N, 1, 1, S_kv] (boolean, True means valid).
scale (float) – Scaling factor for the dot product.
precision (Precision | None) – Precision for the computation (default is None).
- Returns:
A tuple containing the attention output, log-sum-exp, and maximum value.
\[\text{Attention}_a(Q, K, V, M, T) = \sum_b \mathrm{softmax}_b\left( M_b \cdot (Q_a K_b + T_{ab}) + (1 - M_b) \cdot (-10^9) \right) V_b\]
where \(Q\), \(K\), and \(V\) are the query, key, and value tensors, \(M\) is the boolean mask (1 for valid, 0 for masked), and \(T\) is the triangle bias.
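A minimal usage sketch is shown below. The concrete shapes, dtypes, random inputs, and the `cuex` import alias are illustrative assumptions; only the call signature comes from the API above, and the custom kernel requires a CUDA device (see the note below).

```python
# Minimal usage sketch; shapes and dtypes are illustrative, not prescribed by the API.
import jax
import jax.numpy as jnp

import cuequivariance_jax as cuex

B, N, H, S, D = 1, 32, 4, 32, 16  # batch, rows, heads, sequence, head dim (S_qo == S_kv == S here)
kq, kk, kv, kb = jax.random.split(jax.random.PRNGKey(0), 4)

q = jax.random.normal(kq, (B, N, H, S, D), dtype=jnp.float32)
k = jax.random.normal(kk, (B, N, H, S, D), dtype=jnp.float32)
v = jax.random.normal(kv, (B, N, H, S, D), dtype=jnp.float32)
bias = jax.random.normal(kb, (B, 1, H, S, S), dtype=jnp.float32)
mask = jnp.ones((B, N, 1, 1, S), dtype=jnp.bool_)  # True marks valid key positions

out, lse, max_val = cuex.triangle_attention(q, k, v, bias, mask, scale=D**-0.5)
print(out.shape)  # attention output; same shape as q, i.e. (B, N, H, S, D)
```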
Note
This operation uses a custom CUDA kernel for performance. When running on multiple devices, manual sharding is required; without explicit sharding, performance is significantly degraded. See the JAX shard_map documentation for details on manual parallelism.
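One possible way to set up the manual sharding mentioned in the note is sketched below with `shard_map`, sharding the leading batch axis over a single mesh axis. The mesh axis name, the choice of sharded dimension, and the placeholder shapes are assumptions, not requirements of the library.

```python
# Hedged sketch: shard the leading batch axis across devices with shard_map.
# The mesh axis name ("data"), the sharded dimension, and shapes are assumptions.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

import cuequivariance_jax as cuex

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

B = jax.device_count()  # batch must divide evenly across devices
N, H, S, D = 32, 4, 32, 16
q = jnp.zeros((B, N, H, S, D))
k = jnp.zeros((B, N, H, S, D))
v = jnp.zeros((B, N, H, S, D))
bias = jnp.zeros((B, 1, H, S, S))
mask = jnp.ones((B, N, 1, 1, S), dtype=jnp.bool_)

def attn(q, k, v, bias, mask):
    # Runs the custom kernel on each device's local batch shard.
    return cuex.triangle_attention(q, k, v, bias, mask, scale=D**-0.5)

sharded_attn = jax.jit(
    shard_map(
        attn,
        mesh=mesh,
        in_specs=(P("data"), P("data"), P("data"), P("data"), P("data")),
        out_specs=P("data"),  # applied to all three outputs
    )
)

out, lse, max_val = sharded_attn(q, k, v, bias, mask)
```

Sharding the batch axis keeps each device's shard self-contained, so the kernel runs locally without cross-device communication inside the attention call.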