triangle_multiplicative_update#

cuequivariance_jax.triangle_multiplicative_update(
x,
direction='outgoing',
key=None,
mask=None,
norm_in_weight=None,
norm_in_bias=None,
p_in_weight=None,
p_in_bias=None,
g_in_weight=None,
g_in_bias=None,
norm_out_weight=None,
norm_out_bias=None,
p_out_weight=None,
p_out_bias=None,
g_out_weight=None,
g_out_bias=None,
eps=1e-05,
precision=Precision.DEFAULT,
fallback=None,
)#

Apply the triangle multiplicative update operation.

This function performs a triangle multiplicative update operation, which is a key component in the AlphaFold2 architecture. The operation consists of:

  1. Input normalization and gating

  2. Triangular projection (either outgoing or incoming)

  3. Output normalization and gating
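The three steps above can be sketched in plain JAX. This is an illustrative reference, not the library's implementation: the `params` dict and helper names are hypothetical stand-ins for the weight arguments listed below, optional biases and masking are omitted for brevity, and the fused CUDA kernel handles layout and precision differently.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, w, b, eps=1e-5):
    # Normalize over the hidden dimension, then scale and shift.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps) * w + b

def reference_triangle_update(x, direction, params, eps=1e-5):
    """Illustrative sketch; `params` holds the same weights this API accepts."""
    # 1. Input normalization and gating -> two projected halves a, b
    h = layer_norm(x, params["norm_in_weight"], params["norm_in_bias"], eps)
    ab = (h @ params["p_in_weight"].T) * jax.nn.sigmoid(h @ params["g_in_weight"].T)
    a, b = jnp.split(ab, 2, axis=-1)  # each (B, N, N, D)
    # 2. Triangular projection: contract over the third node k
    if direction == "outgoing":
        t = jnp.einsum("bikd,bjkd->bijd", a, b)
    else:  # "incoming"
        t = jnp.einsum("bkid,bkjd->bijd", a, b)
    # 3. Output normalization, projection, and gating
    out = layer_norm(t, params["norm_out_weight"], params["norm_out_bias"], eps)
    out = out @ params["p_out_weight"].T
    return out * jax.nn.sigmoid(h @ params["g_out_weight"].T)
```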

Parameters:
  • x (jax.Array) – Input tensor of shape (B, N, N, D), where B is the batch size, N is the sequence length, and D is the hidden dimension. A 3D input (N, N, D) is expanded to 4D.

  • direction (str) – Direction of the triangular projection. Must be either “outgoing” or “incoming”.

  • key (jax.Array, optional) – JAX random key for weight initialization. Required if any weights are None.

  • mask (jax.Array, optional) – Mask tensor of shape (B, N, N) for masking the output. A 2D mask (N, N) is expanded to 3D.

  • norm_in_weight (jax.Array, optional) – Weight tensor for input normalization of shape (D,). If None, initialized to ones.

  • norm_in_bias (jax.Array, optional) – Bias tensor for input normalization of shape (D,). If None, initialized to zeros.

  • p_in_weight (jax.Array, optional) – Weight tensor for input projection of shape (2D, D). If None, initialized with LeCun normal distribution.

  • p_in_bias (jax.Array, optional) – Bias tensor for input projection of shape (2D,). If None, no bias is applied to the input projection.

  • g_in_weight (jax.Array, optional) – Weight tensor for input gating of shape (2D, D). If None, initialized with LeCun normal distribution.

  • g_in_bias (jax.Array, optional) – Bias tensor for input gating of shape (2D,). If None, no bias is applied to the input gating.

  • norm_out_weight (jax.Array, optional) – Weight tensor for output normalization of shape (D,). If None, initialized to ones.

  • norm_out_bias (jax.Array, optional) – Bias tensor for output normalization of shape (D,). If None, initialized to zeros.

  • p_out_weight (jax.Array, optional) – Weight tensor for output projection of shape (D, D). If None, initialized with LeCun normal distribution.

  • p_out_bias (jax.Array, optional) – Bias tensor for output projection of shape (D,). If None, no bias is applied to the output projection.

  • g_out_weight (jax.Array, optional) – Weight tensor for output gating of shape (D, D). If None, initialized with LeCun normal distribution.

  • g_out_bias (jax.Array, optional) – Bias tensor for output gating of shape (D,). If None, no bias is applied to the output gating.

  • eps (float) – Small constant for numerical stability in normalization. Defaults to 1e-5.

  • precision (Precision) – Precision mode for matrix multiplications. Available options:
      - DEFAULT: Use the default precision setting
      - TF32: Use TensorFloat-32 precision
      - TF32x3: Use TensorFloat-32 precision with 3x accumulation
      - IEEE: Use IEEE 754 precision

  • fallback (bool | None)

Returns:

Output tensor of shape (B, N, N, D). The result is always 4D, even if the input was 3D.

Return type:

jax.Array

Notes

  • Unlike in PyTorch, JAX arrays are immutable, so weight initialization returns new arrays rather than mutating existing ones

  • Hidden dimension D must be divisible by 64 for the BND_BND layout in layer normalization

  • If weights are not provided, they are initialized with appropriate values, but in practice you should pass learned parameters
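The shape rules and the divisibility constraint noted above can be mirrored in a small pre-processing sketch. The `prepare_inputs` helper is hypothetical, not part of the library; it only illustrates the documented conventions before a call into the API.

```python
import jax.numpy as jnp

def prepare_inputs(x, mask=None):
    """Illustrative pre-processing mirroring the documented shape rules."""
    if x.ndim == 3:  # (N, N, D) -> (1, N, N, D)
        x = x[None]
    # Hidden dimension must be divisible by 64 (BND_BND layer-norm layout).
    assert x.shape[-1] % 64 == 0, "hidden dimension D must be divisible by 64"
    if mask is not None and mask.ndim == 2:  # (N, N) -> (1, N, N)
        mask = mask[None]
    return x, mask

x, mask = prepare_inputs(jnp.zeros((16, 16, 64)), jnp.ones((16, 16)))
print(x.shape, mask.shape)  # (1, 16, 16, 64) (1, 16, 16)
```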

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from cuequivariance_jax import triangle_multiplicative_update
>>> # Create input tensor
>>> key = jax.random.key(0)
>>> key, subkey = jax.random.split(key)
>>> batch_size, seq_len, hidden_dim = 1, 128, 128
>>> x = jax.random.normal(subkey, (batch_size, seq_len, seq_len, hidden_dim), dtype=jnp.float32)
>>> # Create mask (1 for valid positions, 0 for masked)
>>> mask = jnp.ones((batch_size, seq_len, seq_len))
>>> # Create weight parameters (in practice, these would be learned)
>>> norm_in_weight = jnp.ones(hidden_dim)
>>> norm_in_bias = jnp.zeros(hidden_dim)
>>> # Optional bias parameters for projection and gating layers
>>> p_in_bias = jnp.zeros(2 * hidden_dim)  # Optional input projection bias
>>> g_in_bias = jnp.zeros(2 * hidden_dim)  # Optional input gating bias
>>> p_out_bias = jnp.zeros(hidden_dim)     # Optional output projection bias
>>> g_out_bias = jnp.zeros(hidden_dim)     # Optional output gating bias
>>> # Split off a fresh key for initializing the remaining weights
>>> key, subkey = jax.random.split(key)
>>> # Perform triangular multiplication
>>> output = triangle_multiplicative_update(
...     x=x,
...     direction="outgoing",  # or "incoming"
...     key=subkey,  # Only needed if some weights are None
...     mask=mask,
...     norm_in_weight=norm_in_weight,
...     norm_in_bias=norm_in_bias,
...     p_in_bias=p_in_bias,  # Can be None to skip bias
...     g_in_bias=g_in_bias,  # Can be None to skip bias
...     p_out_bias=p_out_bias,  # Can be None to skip bias
...     g_out_bias=g_out_bias,  # Can be None to skip bias
...     # ... pass other weights or let them initialize ...
... )
>>> print(output.shape)
(1, 128, 128, 128)