triangle_attention#
- cuequivariance_jax.triangle_attention(q, k, v, bias, mask, scale, precision=None)#
Compute triangle attention.
- Parameters:
q (Array) – Query tensor of shape [B, N, H, S_qo, D].
k (Array) – Key tensor of shape [B, N, H, S_kv, D].
v (Array) – Value tensor of shape [B, N, H, S_kv, D].
bias (Array) – Bias tensor of shape [B, 1, H, S_qo, S_kv].
mask (Array) – Mask tensor of shape [B, N, 1, 1, S_kv] (boolean, True means valid).
scale (float) – Scaling factor for the dot product.
precision (Precision | None) – Precision for the computation (default is None).
- Returns:
A tuple containing the attention output, the log-sum-exp values, and the maximum attention logits.
\[\text{Attention}_a(Q, K, V, M, T) = \sum_b \mathrm{softmax}_b\left( M_b \cdot (Q_a K_b + T_{ab}) + (1 - M_b) \cdot (-10^9) \right) V_b\]

where \(Q\), \(K\), and \(V\) are the query, key, and value tensors, \(M\) is the boolean mask, and \(T\) is the triangle bias.
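The formula above can be sketched as a plain NumPy reference implementation. This is an illustrative sketch, not the fused kernel the library provides; the function name `triangle_attention_reference` is hypothetical, and two details are assumptions: `scale` is applied to the \(Q K\) dot product before the bias is added (the displayed formula omits `scale`), and the returned log-sum-exp and maximum are taken over the key axis, matching the Returns description.

```python
import numpy as np

def triangle_attention_reference(q, k, v, bias, mask, scale):
    """NumPy sketch of triangle attention (illustrative, not the CUDA kernel).

    Shapes follow the parameter list above:
      q:    [B, N, H, S_qo, D]
      k, v: [B, N, H, S_kv, D]
      bias: [B, 1, H, S_qo, S_kv]   (broadcast over N)
      mask: [B, N, 1, 1, S_kv]      (boolean, True means valid; broadcast over H, S_qo)
    """
    # Scaled dot product over D: [..., S_qo, D] x [..., S_kv, D] -> [..., S_qo, S_kv].
    # Assumption: scale multiplies the dot product only, not the bias.
    logits = scale * np.einsum("...ad,...bd->...ab", q, k) + bias
    # Invalid keys get a large negative logit, as in the formula's (1 - M_b) * (-1e9) term.
    logits = np.where(mask, logits, -1e9)
    # Numerically stable softmax over the key axis S_kv.
    max_logit = logits.max(axis=-1, keepdims=True)
    exp = np.exp(logits - max_logit)
    denom = exp.sum(axis=-1, keepdims=True)
    out = np.einsum("...ab,...bd->...ad", exp / denom, v)
    # Assumed return semantics: (output, log-sum-exp, max) per query position.
    lse = np.squeeze(max_logit + np.log(denom), axis=-1)
    return out, lse, np.squeeze(max_logit, axis=-1)
```

With `v` set to all ones, an unmasked call returns an output of all ones, since the softmax weights along the key axis sum to one; this is a quick sanity check for the mask and normalization logic.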