triangle_multiplicative_update

cuequivariance_torch.triangle_multiplicative_update(
x: Tensor,
direction: str = 'outgoing',
mask: Tensor | None = None,
norm_in_weight: Tensor | None = None,
norm_in_bias: Tensor | None = None,
p_in_weight: Tensor | None = None,
g_in_weight: Tensor | None = None,
norm_out_weight: Tensor | None = None,
norm_out_bias: Tensor | None = None,
p_out_weight: Tensor | None = None,
g_out_weight: Tensor | None = None,
eps: float = 1e-05,
) → Tensor

Apply triangle multiplicative update operation.

This function performs a triangle multiplicative update, a key component of the AlphaFold2 architecture (a plain-PyTorch reference sketch follows the list below). The operation consists of:

  1. Input normalization and gating

  2. Triangular projection (either outgoing or incoming)

  3. Output normalization and gating
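For intuition, here is a rough, unofficial reference sketch of the outgoing variant written with plain PyTorch ops. The function name is illustrative, and the sigmoid gating, LayerNorm placement, and mask handling follow the AlphaFold2 formulation; the fused kernel may differ in ordering and numerics.

import torch

def reference_triangle_update_outgoing(x, mask, norm_in_w, norm_in_b, p_in_w, g_in_w,
                                        norm_out_w, norm_out_b, p_out_w, g_out_w, eps=1e-5):
    # 1. Input normalization and gating: LayerNorm over the hidden dimension,
    #    then a sigmoid-gated projection to 2*D channels split into edge features a and b.
    h = torch.nn.functional.layer_norm(x, x.shape[-1:], norm_in_w, norm_in_b, eps)
    ab = torch.sigmoid(h @ g_in_w.T) * (h @ p_in_w.T)
    if mask is not None:
        ab = ab * mask[..., None]
    a, b = ab.chunk(2, dim=-1)
    # 2. Triangular projection: "outgoing" sums over the shared third node k.
    #    The "incoming" variant would use torch.einsum("bkid,bkjd->bijd", a, b).
    t = torch.einsum("bikd,bjkd->bijd", a, b)
    # 3. Output normalization and gating.
    t = torch.nn.functional.layer_norm(t, t.shape[-1:], norm_out_w, norm_out_b, eps)
    return torch.sigmoid(h @ g_out_w.T) * (t @ p_out_w.T)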

The function supports both ahead-of-time (AOT) and just-in-time (JIT) tuning. Auto-tuning behavior can be controlled through environment variables (see the snippet after this list):

  • Default: full ahead-of-time (AOT) auto-tuning is enabled for optimal performance (may take several hours).

  • Quick testing: set CUEQ_DISABLE_AOT_TUNING=1 and CUEQ_DEFAULT_CONFIG=1 to disable all tuning.

  • On-demand tuning: set CUEQ_DISABLE_AOT_TUNING=1; new shapes are auto-tuned when first encountered (may take several minutes).

  • Note: when using Docker with default or on-demand tuning enabled, commit the container to persist the tuning results.
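For example, to disable all tuning for quick testing, the variables can be set from Python before the first kernel call (exporting them in the shell before launching Python works equally well):

import os

# Disable AOT auto-tuning and fall back to the default kernel configuration.
# It is safest to set these before importing cuequivariance_torch, since they
# are read when the kernels are set up.
os.environ["CUEQ_DISABLE_AOT_TUNING"] = "1"
os.environ["CUEQ_DEFAULT_CONFIG"] = "1"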

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, N, N, D), where B is the batch size, N is the sequence length, and D is the hidden dimension.

  • direction (str) – Direction of the triangular projection. Must be either “outgoing” or “incoming”.

  • mask (torch.Tensor, optional) – Mask tensor of shape (B, N, N) used to mask the output (1 for valid positions, 0 for masked).

  • norm_in_weight (torch.Tensor) – Weight tensor for input normalization of shape (D,).

  • norm_in_bias (torch.Tensor) – Bias tensor for input normalization of shape (D,).

  • p_in_weight (torch.Tensor) – Weight tensor for input projection of shape (2D, D).

  • g_in_weight (torch.Tensor) – Weight tensor for input gating of shape (2D, D).

  • norm_out_weight (torch.Tensor) – Weight tensor for output normalization of shape (D,).

  • norm_out_bias (torch.Tensor) – Bias tensor for output normalization of shape (D,).

  • p_out_weight (torch.Tensor) – Weight tensor for output projection of shape (D, D).

  • g_out_weight (torch.Tensor) – Weight tensor for output gating of shape (D, D).

  • eps (float, optional) – Small constant for numerical stability in normalization. Defaults to 1e-5.

Returns:

Output tensor of shape (B, N, N, D), matching the input shape.
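The weight arguments are optional; when supplied they must match the shapes documented above. A minimal sketch with explicitly constructed, randomly initialized weights (purely illustrative; in practice they would come from a trained module's parameters):

import torch
from cuequivariance_torch import triangle_multiplicative_update

B, N, D = 1, 64, 64                      # D must be a multiple of 32
device = torch.device("cuda")
x = torch.randn(B, N, N, D, device=device)

weights = dict(
    norm_in_weight=torch.ones(D, device=device),
    norm_in_bias=torch.zeros(D, device=device),
    p_in_weight=torch.randn(2 * D, D, device=device),   # (2D, D)
    g_in_weight=torch.randn(2 * D, D, device=device),    # (2D, D)
    norm_out_weight=torch.ones(D, device=device),
    norm_out_bias=torch.zeros(D, device=device),
    p_out_weight=torch.randn(D, D, device=device),       # (D, D)
    g_out_weight=torch.randn(D, D, device=device),        # (D, D)
)
out = triangle_multiplicative_update(x, direction="incoming", **weights)
print(out.shape)  # torch.Size([1, 64, 64, 64])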

Notes

  1. The context needed for the backward pass is saved automatically; you do not need to save it manually.

  2. Kernel precision (fp32, bf16, fp16) follows the input dtypes. To use tf32, enable it through PyTorch's global settings (see the snippet after these notes).

  3. Limitation: currently only hidden_dim values that are multiples of 32 are supported.
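Regarding note 2, tf32 is not a separate tensor dtype, so it cannot be requested through the inputs. A minimal sketch of the usual global PyTorch switches (exactly which of these the fused kernel consults is not specified here):

import torch

# Allow TensorFloat-32 for float32 matmuls and cuDNN ops globally.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Equivalently, on newer PyTorch: torch.set_float32_matmul_precision("high")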

Example

>>> import torch
>>> from cuequivariance_torch import triangle_multiplicative_update
>>> if torch.cuda.is_available():
...     device = torch.device("cuda")
...     batch_size, seq_len, hidden_dim = 1, 128, 128
...     # Create input tensor
...     x = torch.randn(batch_size, seq_len, seq_len, hidden_dim, requires_grad=True, device=device)
...     # Create mask (1 for valid positions, 0 for masked)
...     mask = torch.ones(batch_size, seq_len, seq_len, device=device)
...     # Perform triangular multiplication. If CUEQ_DISABLE_AOT_TUNING is not set
...     # to 1, a stored JSON config is looked up when available; otherwise the
...     # kernel is auto-tuned using ahead-of-time compilation.
...     output = triangle_multiplicative_update(
...         x=x,
...         direction="outgoing",  # or "incoming"
...         mask=mask,
...     )
...     print(output.shape)  # torch.Size([1, 128, 128, 128])
...     # Create gradient tensor and perform backward pass
...     grad_out = torch.randn_like(output)
...     output.backward(grad_out)
...     # Access gradients
...     print(x.grad.shape)  # torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 128, 128])