triangle_attention#

cuequivariance_jax.triangle_attention(q, k, v, bias, mask, scale, precision=None)#

Compute triangle attention: masked, biased, scaled dot-product attention over the last two axes of the inputs.

Parameters:
  • q (Array) – Query tensor of shape [B, N, H, S_qo, D].

  • k (Array) – Key tensor of shape [B, N, H, S_kv, D].

  • v (Array) – Value tensor of shape [B, N, H, S_kv, D].

  • bias (Array) – Bias tensor of shape [B, 1, H, S_qo, S_kv].

  • mask (Array) – Mask tensor of shape [B, N, 1, 1, S_kv] (boolean, True means valid).

  • scale (float) – Scaling factor for the dot product.

  • precision (Precision | None) – Precision for the computation (default is None).

Returns:

A tuple containing the attention output, log-sum-exp, and maximum value.

\[\text{Attention}_a(Q, K, V, M, T) = \sum_b \mathrm{softmax}_b\left( M_b \cdot (s \, Q_a \cdot K_b + T_{ab}) + (1 - M_b) \cdot (-10^9) \right) V_b\]

where \(Q\), \(K\), and \(V\) are the query, key, and value tensors, \(M\) is the boolean mask, \(T\) is the triangle bias, and \(s\) is the scale factor applied to the dot product.
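The formula above can be sketched in pure JAX. The function name `triangle_attention_reference` below is illustrative, not part of the library: it is a naive reference that materializes the full attention matrix, whereas the library call is a fused kernel. Like the documented API, it returns the output together with the per-query log-sum-exp and maximum of the logits, which are the quantities typically needed for numerically stable chunked or backward computation.

```python
import jax
import jax.numpy as jnp


def triangle_attention_reference(q, k, v, bias, mask, scale):
    """Naive reference for the attention formula (illustrative, not the fused kernel).

    Shapes follow the documented API:
      q:    [B, N, H, S_qo, D]
      k, v: [B, N, H, S_kv, D]
      bias: [B, 1, H, S_qo, S_kv]
      mask: [B, N, 1, 1, S_kv]  (boolean, True means valid)
    """
    # Logits: scale * <Q_a, K_b> + T_ab, with invalid keys pushed to -1e9.
    logits = scale * jnp.einsum("bnhad,bnhcd->bnhac", q, k) + bias
    logits = jnp.where(mask, logits, -1e9)
    # Numerically stable softmax over the key axis b.
    amax = jnp.max(logits, axis=-1, keepdims=True)
    p = jnp.exp(logits - amax)
    denom = jnp.sum(p, axis=-1, keepdims=True)
    out = jnp.einsum("bnhac,bnhcd->bnhad", p / denom, v)
    lse = jnp.log(denom) + amax
    return out, jnp.squeeze(lse, -1), jnp.squeeze(amax, -1)


# Small demo with the documented shapes (all keys valid).
B, N, H, Sq, Sk, D = 1, 2, 2, 3, 4, 5
kq, kk, kv, kb = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(kq, (B, N, H, Sq, D))
k = jax.random.normal(kk, (B, N, H, Sk, D))
v = jax.random.normal(kv, (B, N, H, Sk, D))
bias = jax.random.normal(kb, (B, 1, H, Sq, Sk))
mask = jnp.ones((B, N, 1, 1, Sk), dtype=bool)
out, lse, amax = triangle_attention_reference(q, k, v, bias, mask, scale=D**-0.5)
```

Because the softmax weights sum to one along the key axis, the output has the same shape as the query tensor, `[B, N, H, S_qo, D]`.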