triangle_attention
- cuequivariance_jax.triangle_attention(q, k, v, bias, mask, scale, precision=None)
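Compute triangle attention: masked, scaled dot-product attention with an additive triangle bias that is shared across the N dimension.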
- Parameters:
q (Array) – Query tensor of shape [B, N, H, S_qo, D].
k (Array) – Key tensor of shape [B, N, H, S_kv, D].
v (Array) – Value tensor of shape [B, N, H, S_kv, D].
bias (Array) – Bias tensor of shape [B, 1, H, S_qo, S_kv], shared across the N dimension. Use float32 for the standard kernels; on Blackwell GPUs (sm100f kernel), use the same dtype as q/k/v (bf16/fp16).
mask (Array) – Mask tensor of shape [B, N, 1, 1, S_kv] (boolean, True means valid).
scale (float) – Scaling factor for the dot product.
precision (Precision | None) – Precision for the computation (default is None).
- Returns:
A tuple containing the attention output, log-sum-exp, and maximum value.
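A minimal sketch of inputs laid out as the parameter list describes, assuming bf16 q/k/v. The helper `make_example_inputs` and its `sm100f` flag are hypothetical, and `scale = 1/sqrt(D)` is the conventional attention scale rather than a value this page prescribes.

```python
import jax
import jax.numpy as jnp


def make_example_inputs(B, N, H, S_qo, S_kv, D, dtype=jnp.bfloat16, sm100f=False, seed=0):
    """Hypothetical helper: random inputs with the documented shapes and dtypes."""
    kq, kk, kv, kb = jax.random.split(jax.random.PRNGKey(seed), 4)
    q = jax.random.normal(kq, (B, N, H, S_qo, D), dtype)
    k = jax.random.normal(kk, (B, N, H, S_kv, D), dtype)
    v = jax.random.normal(kv, (B, N, H, S_kv, D), dtype)
    # The bias has a size-1 N axis, so it is shared across N. It is float32 for
    # the standard kernels and matches q/k/v for the Blackwell sm100f kernel.
    bias_dtype = dtype if sm100f else jnp.float32
    bias = jax.random.normal(kb, (B, 1, H, S_qo, S_kv), bias_dtype)
    # Boolean key mask, broadcast over heads and queries; True marks a valid key.
    mask = jnp.ones((B, N, 1, 1, S_kv), dtype=bool)
    # Assumption: the usual 1/sqrt(D) attention scale.
    scale = D ** -0.5
    return q, k, v, bias, mask, scale
```

On a supported GPU, the returned tuple should then match the documented argument order, e.g. `out, lse, amax = cuequivariance_jax.triangle_attention(*make_example_inputs(2, 3, 4, 64, 64, 32))`.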
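\[\text{Attention}_a(Q, K, V, M, T) = \sum_b \mathrm{softmax}_b\left( M_b \cdot (s \, Q_a \cdot K_b + T_{ab}) + (1 - M_b) \cdot (-10^9) \right) V_b\]
where \(Q\), \(K\), and \(V\) are the query, key, and value tensors, \(M\) is the boolean mask (1 for valid keys), \(T\) is the triangle bias, and \(s\) is scale. The index \(a\) runs over query positions (S_qo) and \(b\) over key/value positions (S_kv); the expression is evaluated independently for each batch, N, and head index.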
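A pure-JAX sketch of the attention formula, useful as a numerical reference when testing. It is not the CUDA kernel. The name `triangle_attention_reference` is hypothetical, and so are the exact shapes of the returned log-sum-exp and maximum (here `[B, N, H, S_qo, 1]`); the kernel's layout for these may differ.

```python
import jax
import jax.numpy as jnp


def triangle_attention_reference(q, k, v, bias, mask, scale):
    """Hypothetical reference: q [B,N,H,S_qo,D], k/v [B,N,H,S_kv,D],
    bias [B,1,H,S_qo,S_kv], mask [B,N,1,1,S_kv] (bool, True = valid)."""
    # Work in float32 (an assumption of this sketch, not of the kernel).
    q32, k32, v32 = (x.astype(jnp.float32) for x in (q, k, v))
    # s * Q_a . K_b + T_ab; the bias broadcasts over the N axis.
    logits = scale * jnp.einsum("bnhqd,bnhkd->bnhqk", q32, k32) + bias.astype(jnp.float32)
    # Masked keys get -1e9, i.e. M_b * x + (1 - M_b) * (-1e9) for boolean M.
    logits = jnp.where(mask, logits, -1e9)
    # Row-wise maximum and log-sum-exp over the key index b.
    amax = jnp.max(logits, axis=-1, keepdims=True)
    lse = amax + jnp.log(jnp.sum(jnp.exp(logits - amax), axis=-1, keepdims=True))
    # softmax_b = exp(logit - lse), then the weighted sum over V_b.
    probs = jnp.exp(logits - lse)
    out = jnp.einsum("bnhqk,bnhkd->bnhqd", probs, v32).astype(q.dtype)
    return out, lse, amax
```

Computing the softmax through `lse` keeps the three returned values mutually consistent, which is convenient when comparing against the kernel's outputs.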
Note
This operation uses a custom CUDA kernel for performance. On multiple devices, the inputs must be sharded manually; without explicit sharding, performance degrades significantly. See the JAX shard_map documentation for details on manual parallelism.
Note
On Blackwell GPUs (compute capability 10.0 or 10.3, cu13 builds), the sm100f kernel supports hidden_dim <= 256 for the forward pass and hidden_dim <= 128 for the backward pass (bf16/fp16 only; hidden_dim must be divisible by 8). The sm100f forward kernel also requires S_kv to be a multiple of 8; pad if necessary. The kernel performs best with a “padding mask”, i.e. a mask whose last dimension is all True followed by all False.
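A minimal sketch of padding S_kv up to a multiple of 8 and of building a padding-style mask from per-row key counts. Both helper names (`padding_mask`, `pad_kv_to_multiple`) are hypothetical. The padded key positions are marked False in the mask, so under the formula above they receive zero attention weight. Appending False to an existing padding mask keeps it in the all-True-then-all-False form.

```python
import jax.numpy as jnp


def padding_mask(lengths, s_kv):
    # lengths: [B, N] number of valid keys per row.
    # Returns [B, N, 1, 1, s_kv]: True for the first `lengths` keys, then False.
    return (jnp.arange(s_kv) < lengths[..., None])[:, :, None, None, :]


def pad_kv_to_multiple(k, v, bias, mask, multiple=8):
    # k, v: [B, N, H, S_kv, D]; bias: [B, 1, H, S_qo, S_kv]; mask: [B, N, 1, 1, S_kv].
    pad = (-k.shape[-2]) % multiple
    if pad == 0:
        return k, v, bias, mask
    # Zero-pad the key/value sequence axis (second to last).
    kv_pad = [(0, 0)] * (k.ndim - 2) + [(0, pad), (0, 0)]
    k, v = jnp.pad(k, kv_pad), jnp.pad(v, kv_pad)
    # Zero-pad the key axis of the bias (last axis).
    bias = jnp.pad(bias, [(0, 0)] * (bias.ndim - 1) + [(0, pad)])
    # Mark the new key positions as invalid.
    extra = jnp.zeros(mask.shape[:-1] + (pad,), dtype=bool)
    mask = jnp.concatenate([mask, extra], axis=-1)
    return k, v, bias, mask
```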