triangle_multiplicative_update

cuequivariance_torch.triangle_multiplicative_update(
x,
direction='outgoing',
mask=None,
norm_in_weight=None,
norm_in_bias=None,
p_in_weight=None,
p_in_bias=None,
g_in_weight=None,
g_in_bias=None,
norm_out_weight=None,
norm_out_bias=None,
p_out_weight=None,
p_out_bias=None,
g_out_weight=None,
g_out_bias=None,
eps=1e-05,
precision=None,
)

Apply the triangle multiplicative update operation.

This function performs the triangle multiplicative update, a key component of the AlphaFold2 architecture. The operation consists of:

  1. Input normalization, projection, and gating

  2. Triangular projection (either outgoing or incoming)

  3. Output normalization, projection, and gating
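To make the data flow of these three steps concrete, here is a minimal, unfused sketch in plain PyTorch following the AlphaFold2 formulation. It is not the library's Triton kernel. Several details are assumptions not stated on this page: the (2*D)-channel input projection is split into two D-channel halves `a` and `b`, the mask zeroes the gated projections before the triangular contraction (as AlphaFold2 does), and the output gate is computed from the normalized input. `tri_mul_reference` and `random_weights` are hypothetical helpers.

```python
import torch
import torch.nn.functional as F


def random_weights(D, seed=0):
    """Hypothetical helper: random parameters with the shapes listed under Parameters."""
    g = torch.Generator().manual_seed(seed)
    shapes = {
        "norm_in_weight": (D,), "norm_in_bias": (D,),
        "p_in_weight": (2 * D, D), "p_in_bias": (2 * D,),
        "g_in_weight": (2 * D, D), "g_in_bias": (2 * D,),
        "norm_out_weight": (D,), "norm_out_bias": (D,),
        "p_out_weight": (D, D), "p_out_bias": (D,),
        "g_out_weight": (D, D), "g_out_bias": (D,),
    }
    return {name: 0.1 * torch.randn(shape, generator=g) for name, shape in shapes.items()}


def tri_mul_reference(x, w, direction="outgoing", mask=None, eps=1e-5):
    """Unfused reference sketch; x has shape (B, N, N, D), w maps parameter names to tensors."""
    D = x.shape[-1]
    # Step 1: LayerNorm, then a sigmoid-gated projection from D to 2*D channels.
    x_norm = F.layer_norm(x, (D,), w["norm_in_weight"], w["norm_in_bias"], eps)
    h = F.linear(x_norm, w["p_in_weight"], w["p_in_bias"])
    h = h * torch.sigmoid(F.linear(x_norm, w["g_in_weight"], w["g_in_bias"]))
    if mask is not None:
        h = h * mask.unsqueeze(-1)  # assumption: mask zeroes edges before the contraction
    a, b = h.chunk(2, dim=-1)  # assumption: first D channels -> a, last D -> b
    # Step 2: triangular contraction over the third node k.
    if direction == "outgoing":
        t = torch.einsum("bikd,bjkd->bijd", a, b)  # combines edges (i, k) and (j, k)
    elif direction == "incoming":
        t = torch.einsum("bkid,bkjd->bijd", a, b)  # combines edges (k, i) and (k, j)
    else:
        raise ValueError("direction must be 'outgoing' or 'incoming'")
    # Step 3: LayerNorm, D -> D projection, gated by the normalized input (assumption).
    t = F.layer_norm(t, (D,), w["norm_out_weight"], w["norm_out_bias"], eps)
    out = F.linear(t, w["p_out_weight"], w["p_out_bias"])
    return out * torch.sigmoid(F.linear(x_norm, w["g_out_weight"], w["g_out_bias"]))


x = torch.randn(1, 16, 16, 32)
out = tri_mul_reference(x, random_weights(32), direction="incoming")
print(out.shape)  # torch.Size([1, 16, 16, 32])
```

Writing step 2 as an explicit einsum makes its O(B·N³·D) cost visible; this contraction dominates the O(B·N²·D²) linear layers when N is large relative to D.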

The function supports both ahead-of-time (AOT) and just-in-time (JIT, "on-demand") auto-tuning of its Triton kernels. Tuning behavior is controlled through environment variables:

  • Quick testing (default): existing tuning configurations are looked up; if none is found for a shape, default kernel parameters are used. No tuning is performed.

  • On-demand tuning: set CUEQ_TRITON_TUNING_MODE="ONDEMAND" to auto-tune each new shape the first time it is encountered (this may take several minutes).

  • AOT tuning: set CUEQ_TRITON_TUNING_MODE="AOT" to perform full ahead-of-time tuning for optimal performance (this may take several hours).

  • Ignore existing cache: set CUEQ_TRITON_IGNORE_EXISTING_CACHE to ignore both the default settings shipped with the package and any user-local settings previously saved by AOT or ONDEMAND tuning. Use this to regenerate optimal settings for a particular setup.

  • Cache directory: set CUEQ_TRITON_CACHE_DIR to specify where tuning configurations are stored.

  • Note: when running in a Docker container with tuning enabled, commit the container to persist the tuning results.
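The tuning mode can also be selected from Python instead of the shell. The sketch below only sets the environment variables named above. It assumes they are read when the library initializes its kernels, so it sets them before cuequivariance_torch is imported. The helper name `configure_tuning` is hypothetical, and so is the value "1" used for CUEQ_TRITON_IGNORE_EXISTING_CACHE: this page does not say which values enable that flag.

```python
import os


def configure_tuning(mode=None, cache_dir=None, ignore_existing_cache=False):
    """Hypothetical helper that sets the CUEQ_TRITON_* variables described above.

    mode: None (default look-up only, no tuning), "ONDEMAND", or "AOT".
    """
    if mode is None:
        # Default behavior: no tuning mode set at all.
        os.environ.pop("CUEQ_TRITON_TUNING_MODE", None)
    elif mode in ("ONDEMAND", "AOT"):
        os.environ["CUEQ_TRITON_TUNING_MODE"] = mode
    else:
        raise ValueError("mode must be None, 'ONDEMAND' or 'AOT'")
    if cache_dir is not None:
        # Where tuning configurations are stored.
        os.environ["CUEQ_TRITON_CACHE_DIR"] = str(cache_dir)
    if ignore_existing_cache:
        os.environ["CUEQ_TRITON_IGNORE_EXISTING_CACHE"] = "1"  # assumed truthy value
    else:
        os.environ.pop("CUEQ_TRITON_IGNORE_EXISTING_CACHE", None)


# Configure first; the cache path here is a placeholder.
configure_tuning(mode="ONDEMAND", cache_dir="./cueq_tuning_cache")
# import cuequivariance_torch  # import only after the environment is configured
```

Exporting the same variables in the shell before launching the Python process has the same effect.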

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, N, N, D), where B is the batch size, N is the sequence length, and D is the hidden dimension.

  • direction (str) – Direction of the triangular projection. Must be either “outgoing” or “incoming”.

  • mask (torch.Tensor) – Optional mask tensor of shape (B, N, N) for masking the output.

  • norm_in_weight (torch.Tensor) – Optional weight tensor of shape (D,) for input normalization.

  • norm_in_bias (torch.Tensor) – Optional bias tensor of shape (D,) for input normalization.

  • p_in_weight (torch.Tensor) – Optional weight tensor of shape (2*D, D) for the input projection.

  • p_in_bias (torch.Tensor) – Optional bias tensor of shape (2*D,) for the input projection.

  • g_in_weight (torch.Tensor) – Optional weight tensor of shape (2*D, D) for input gating.

  • g_in_bias (torch.Tensor) – Optional bias tensor of shape (2*D,) for input gating.

  • norm_out_weight (torch.Tensor) – Optional weight tensor of shape (D,) for output normalization.

  • norm_out_bias (torch.Tensor) – Optional bias tensor of shape (D,) for output normalization.

  • p_out_weight (torch.Tensor) – Optional weight tensor of shape (D, D) for the output projection.

  • p_out_bias (torch.Tensor) – Optional bias tensor of shape (D,) for the output projection.

  • g_out_weight (torch.Tensor) – Optional weight tensor of shape (D, D) for output gating.

  • g_out_bias (torch.Tensor) – Optional bias tensor of shape (D,) for output gating.

  • eps (float, optional) – Small constant for numerical stability in normalization. Defaults to 1e-5.

  • precision (TriMulPrecision, optional) – Precision mode for matrix multiplications. Available options:
      - None: for non-fp32 inputs, uses the default precision of triton.language.dot; for fp32 inputs, uses tf32 if torch.backends.cuda.matmul.allow_tf32 is True and tf32x3 if it is False.
      - IEEE: use full IEEE 754 precision.
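The documented shapes can be checked before calling the function. The stdlib-only sketch below validates shapes against the Parameters list and the multiple-of-32 limitation given in the Notes. It works on shape tuples (for example `tuple(t.shape)`), so it does not depend on the library. `expected_param_shapes` and `check_trimul_shapes` are hypothetical helpers, not part of cuequivariance_torch.

```python
def expected_param_shapes(D):
    """Shapes listed under Parameters, for hidden dimension D."""
    return {
        "norm_in_weight": (D,), "norm_in_bias": (D,),
        "p_in_weight": (2 * D, D), "p_in_bias": (2 * D,),
        "g_in_weight": (2 * D, D), "g_in_bias": (2 * D,),
        "norm_out_weight": (D,), "norm_out_bias": (D,),
        "p_out_weight": (D, D), "p_out_bias": (D,),
        "g_out_weight": (D, D), "g_out_bias": (D,),
    }


def check_trimul_shapes(x_shape, direction="outgoing", mask_shape=None, param_shapes=None):
    """Raise ValueError if the shapes violate the documented constraints.

    param_shapes maps parameter names to shape tuples; parameters left out are skipped.
    """
    if len(x_shape) != 4 or x_shape[1] != x_shape[2]:
        raise ValueError(f"x must have shape (B, N, N, D), got {tuple(x_shape)}")
    B, N, _, D = x_shape
    if D % 32 != 0:
        raise ValueError(f"hidden dimension D={D} must be a multiple of 32")
    if direction not in ("outgoing", "incoming"):
        raise ValueError(f"direction must be 'outgoing' or 'incoming', got {direction!r}")
    if mask_shape is not None and tuple(mask_shape) != (B, N, N):
        raise ValueError(f"mask must have shape {(B, N, N)}, got {tuple(mask_shape)}")
    expected = expected_param_shapes(D)
    for name, shape in (param_shapes or {}).items():
        if name not in expected:
            raise ValueError(f"unknown parameter {name!r}")
        if tuple(shape) != expected[name]:
            raise ValueError(f"{name} must have shape {expected[name]}, got {tuple(shape)}")


# Shapes from the example below; passes silently.
check_trimul_shapes((1, 128, 128, 128), "outgoing", mask_shape=(1, 128, 128),
                    param_shapes={"p_in_weight": (256, 128), "p_out_weight": (128, 128)})
```

PyTorch stores `torch.nn.Linear` weights as (out_features, in_features), so the weight of `torch.nn.Linear(D, 2 * D)` already has the documented (2*D, D) shape.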

Returns:

Output tensor of shape (B, N, N, D), the same shape as x.

Return type:

Tensor

Notes

  1. The context needed for the backward pass is saved automatically; you do not need to save it manually.

  2. Kernel precision (fp32, bf16, fp16) is determined by the input dtypes. To use tf32, set the global flag torch.backends.cuda.matmul.allow_tf32 = True.

  3. Limitation: hidden_dim (D) must currently be a multiple of 32.

  4. The tf32 path now uses round-to-nearest (RN) instead of the previous default, round-towards-zero (RZ), for better accuracy in cuex.triangle_multiplicative_update. In rare cases, this may cause minor differences in results.

  5. When using torch.compile, call cuequivariance_ops_torch.init_triton_cache() to initialize the Triton cache before calling the compiled triangle multiplicative update.
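To illustrate what the switch from RZ to RN in note 4 means, the stdlib-only sketch below rounds an fp32 value to tf32 precision in both ways by manipulating its bit pattern. tf32 keeps 10 explicit mantissa bits instead of fp32's 23. This only illustrates the two rounding modes and is not the kernel's conversion code. NaN, Inf, and overflow handling are omitted, and the function names are hypothetical.

```python
import struct

DROPPED_BITS = 13                   # fp32 has 23 mantissa bits, tf32 keeps 10
LOW_MASK = (1 << DROPPED_BITS) - 1  # 0x1FFF: the bits that tf32 discards


def _to_bits(x: float) -> int:
    """Reinterpret a Python float (rounded to fp32) as its 32-bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def _from_bits(bits: int) -> float:
    """Reinterpret a 32-bit pattern as an fp32 value."""
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def to_tf32_rz(x: float) -> float:
    """Round toward zero (RZ): truncate the 13 low mantissa bits."""
    return _from_bits(_to_bits(x) & ~LOW_MASK)


def to_tf32_rn(x: float) -> float:
    """Round to nearest, ties to even (RN). NaN/Inf handling is omitted."""
    bits = _to_bits(x)
    kept_lsb = (bits >> DROPPED_BITS) & 1  # lowest mantissa bit that tf32 keeps
    # Add half an ulp minus one, plus the kept LSB, so exact ties round to even.
    bits += (LOW_MASK >> 1) + kept_lsb
    return _from_bits(bits & ~LOW_MASK)


x = 1.0 + 3 * 2**-12   # exactly representable in fp32
print(to_tf32_rz(x))   # 1.0
print(to_tf32_rn(x))   # 1.0009765625 (= 1 + 2**-10, the nearest tf32 value)
```

Truncation can be off by almost one tf32 ulp and always errs toward zero. Round-to-nearest errs by at most half an ulp, with no systematic bias.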

Example

>>> import torch
>>> from cuequivariance_torch import triangle_multiplicative_update
>>> if torch.cuda.is_available():
...     device = torch.device("cuda")
...     batch_size, seq_len, hidden_dim = 1, 128, 128
...     # Create input tensor
...     x = torch.randn(batch_size, seq_len, seq_len, hidden_dim, requires_grad=True, device=device)
...     # Create mask (1 for valid positions, 0 for masked)
...     mask = torch.ones(batch_size, seq_len, seq_len, device=device)
...     # Apply the triangle multiplicative update
...     output = triangle_multiplicative_update(
...         x=x,
...         direction="outgoing",  # or "incoming"
...         mask=mask,
...     )
...     print(output.shape)  # torch.Size([1, 128, 128, 128])
...     # Create gradient tensor and perform backward pass
...     grad_out = torch.randn_like(output)
...     output.backward(grad_out)
...     # Access gradients
...     print(x.grad.shape)  # torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 128, 128])