triangle_multiplicative_update

cuequivariance_jax.triangle_multiplicative_update(
x,
direction='outgoing',
key=None,
mask=None,
norm_in_weight=None,
norm_in_bias=None,
p_in_weight=None,
p_in_bias=None,
g_in_weight=None,
g_in_bias=None,
norm_out_weight=None,
norm_out_bias=None,
p_out_weight=None,
p_out_bias=None,
g_out_weight=None,
g_out_bias=None,
eps=1e-05,
precision=Precision.DEFAULT,
fallback=None,
)

Apply the triangle multiplicative update operation.

This function performs a triangle multiplicative update operation, which is a key component in the AlphaFold2 architecture. The operation consists of:

  1. Input normalization and gating

  2. Triangular projection (either outgoing or incoming)

  3. Output normalization and gating
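These three steps can be sketched in pure JAX, following the standard AlphaFold2 formulation (this is a minimal reference sketch, not the fused kernel this function actually dispatches to; the helper names are illustrative and not part of the API):

```python
import jax
import jax.numpy as jnp

def layer_norm(x, w, b, eps=1e-5):
    # Normalize over the channel dimension, then scale and shift.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps) * w + b

def triangle_update_sketch(x, direction,
                           norm_in_w, norm_in_b, p_in_w, g_in_w,
                           norm_out_w, norm_out_b, p_out_w, g_out_w):
    # 1. Input normalization and gating (weights use (out, in) layout).
    h = layer_norm(x, norm_in_w, norm_in_b)              # (..., N, N, D)
    ab = (h @ p_in_w.T) * jax.nn.sigmoid(h @ g_in_w.T)   # (..., N, N, 2D)
    a, b = jnp.split(ab, 2, axis=-1)
    # 2. Triangular projection: contract over the shared index k.
    if direction == "outgoing":
        t = jnp.einsum("...ikd,...jkd->...ijd", a, b)
    else:  # "incoming"
        t = jnp.einsum("...kid,...kjd->...ijd", a, b)
    # 3. Output normalization and gating.
    t = layer_norm(t, norm_out_w, norm_out_b)
    return jax.nn.sigmoid(h @ g_out_w.T) * (t @ p_out_w.T)  # (..., N, N, D_out)
```

The two directions differ only in which index the einsum contracts: "outgoing" sums over the trailing row index, "incoming" over the leading one.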

Parameters:
  • x (jax.Array) – Input tensor of shape (…, N, N, D_in), where … represents arbitrary batch dimensions, N is the sequence length, and D_in is the input hidden dimension.

  • direction (str) – Direction of the triangular projection. Must be either “outgoing” or “incoming”.

  • key (jax.Array, optional) – JAX random key for weight initialization. Required if any weights are None.

  • mask (jax.Array, optional) – Mask tensor of shape (…, N, N) applied to the output. Must be broadcastable with the input tensor’s batch dimensions.

  • norm_in_weight (jax.Array, optional) – Weight tensor for input normalization of shape (D_in,). If None, initialized to ones.

  • norm_in_bias (jax.Array, optional) – Bias tensor for input normalization of shape (D_in,). If None, initialized to zeros.

  • p_in_weight (jax.Array, optional) – Weight tensor for input projection of shape (2*D_in, D_in). If None, initialized with LeCun normal distribution.

  • p_in_bias (jax.Array, optional) – Bias tensor for input projection of shape (2*D_in,). If None, no bias is applied to the input projection.

  • g_in_weight (jax.Array, optional) – Weight tensor for input gating of shape (2*D_in, D_in). If None, initialized with LeCun normal distribution.

  • g_in_bias (jax.Array, optional) – Bias tensor for input gating of shape (2*D_in,). If None, no bias is applied to the input gating.

  • norm_out_weight (jax.Array, optional) – Weight tensor for output normalization of shape (D_in,). If None, initialized to ones.

  • norm_out_bias (jax.Array, optional) – Bias tensor for output normalization of shape (D_in,). If None, initialized to zeros.

  • p_out_weight (jax.Array, optional) – Weight tensor for output projection of shape (D_out, D_in). If None, initialized with LeCun normal distribution.

  • p_out_bias (jax.Array, optional) – Bias tensor for output projection of shape (D_out,). If None, no bias is applied to the output projection.

  • g_out_weight (jax.Array, optional) – Weight tensor for output gating of shape (D_out, D_in). If None, initialized with LeCun normal distribution.

  • g_out_bias (jax.Array, optional) – Bias tensor for output gating of shape (D_out,). If None, no bias is applied to the output gating.

  • eps (float) – Small constant for numerical stability in normalization. Defaults to 1e-5.

  • precision (Precision) – Precision mode for matrix multiplications. Available options: DEFAULT (use the default precision setting), TF32 (TensorFloat-32 precision), TF32x3 (TensorFloat-32 precision with 3x accumulation), and IEEE (IEEE 754 precision).

  • fallback (bool | None)
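Per the defaults above, normalization parameters initialize to ones/zeros and projection/gating weights to a LeCun normal distribution. If you prefer to create the learned parameters yourself rather than pass key and let the function initialize them, a sketch of an initializer matching the shapes listed above (the helper name and dict layout are illustrative, not part of the API):

```python
import jax
import jax.numpy as jnp

def init_triangle_params(key, d_in, d_out=None):
    """Build a parameter dict with the shapes listed above."""
    d_out = d_in if d_out is None else d_out
    # Weights use (out_features, in_features) layout, so fan_in is the last axis.
    lecun = jax.nn.initializers.lecun_normal(in_axis=-1, out_axis=-2)
    k1, k2, k3, k4 = jax.random.split(key, 4)
    return {
        "norm_in_weight": jnp.ones(d_in),
        "norm_in_bias": jnp.zeros(d_in),
        "p_in_weight": lecun(k1, (2 * d_in, d_in)),
        "g_in_weight": lecun(k2, (2 * d_in, d_in)),
        "norm_out_weight": jnp.ones(d_in),
        "norm_out_bias": jnp.zeros(d_in),
        "p_out_weight": lecun(k3, (d_out, d_in)),
        "g_out_weight": lecun(k4, (d_out, d_in)),
    }
```

Because the dict keys match the keyword arguments, the result can be splatted into the call, e.g. `triangle_multiplicative_update(x, direction="outgoing", **params)`; bias parameters omitted here simply go unapplied, and with all weights supplied no key is required.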

Returns:

Output tensor of shape (…, N, N, D_out), where D_out is determined by the first dimension of g_out_weight. If g_out_weight is not provided, D_out equals D_in (the input hidden dimension).

Return type:

jax.Array

Notes

  • Unlike PyTorch, JAX arrays are immutable, so weight initialization returns new arrays

  • If output weights are not provided, they are initialized with D_out = D_in (preserving input dimension)

  • If weights are not provided, they are initialized with appropriate values, but in practice you should pass learned parameters

  • Supports arbitrary batch dimensions through broadcasting

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from cuequivariance_jax import triangle_multiplicative_update
>>> # Create input tensor with arbitrary batch dimensions
>>> key = jax.random.key(0)
>>> key, subkey = jax.random.split(key)
>>> batch_dim1, batch_dim2, seq_len, D_in = 2, 3, 128, 128
>>> x = jax.random.normal(subkey, (batch_dim1, batch_dim2, seq_len, seq_len, D_in), dtype=jnp.float32)
>>> # Create mask (1 for valid positions, 0 for masked)
>>> mask = jnp.ones((batch_dim1, batch_dim2, seq_len, seq_len))
>>> # Create weight parameters (in practice, these would be learned)
>>> norm_in_weight = jnp.ones(D_in)
>>> norm_in_bias = jnp.zeros(D_in)
>>> # Optional bias parameters for projection and gating layers
>>> p_in_bias = jnp.zeros(2 * D_in)  # Optional input projection bias
>>> g_in_bias = jnp.zeros(2 * D_in)  # Optional input gating bias
>>> p_out_bias = jnp.zeros(D_in)     # Optional output projection bias (would be D_out if dimension changes)
>>> g_out_bias = jnp.zeros(D_in)     # Optional output gating bias (would be D_out if dimension changes)
>>> # Initialize other weights using the key
>>> key, subkey = jax.random.split(key)
>>> # Perform triangular multiplication
>>> output = triangle_multiplicative_update(
...     x=x,
...     direction="outgoing",  # or "incoming"
...     key=subkey,  # Only needed if some weights are None
...     mask=mask,
...     norm_in_weight=norm_in_weight,
...     norm_in_bias=norm_in_bias,
...     p_in_bias=p_in_bias,  # Can be None to skip bias
...     g_in_bias=g_in_bias,  # Can be None to skip bias
...     p_out_bias=p_out_bias,  # Can be None to skip bias
...     g_out_bias=g_out_bias,  # Can be None to skip bias
...     # ... pass other weights or let them initialize ...
... )
>>> print(output.shape)
(2, 3, 128, 128, 128)
>>> # Example with dimension change: input 128 -> output 256
>>> g_out_weight_256 = jax.random.normal(jax.random.key(1), (256, 128))
>>> p_out_weight_256 = jax.random.normal(jax.random.key(2), (256, 128))
>>> key, subkey2 = jax.random.split(key)
>>> output_256 = triangle_multiplicative_update(
...     x=x,
...     direction="outgoing",
...     key=subkey2,  # Key needed for other weight initialization
...     g_out_weight=g_out_weight_256,
...     p_out_weight=p_out_weight_256,
... )
>>> print(output_256.shape)
(2, 3, 128, 128, 256)