triangle_multiplicative_update

cuequivariance_jax.triangle_multiplicative_update(
x,
direction='outgoing',
key=None,
mask=None,
norm_in_weight=None,
norm_in_bias=None,
p_in_weight=None,
p_in_bias=None,
g_in_weight=None,
g_in_bias=None,
norm_out_weight=None,
norm_out_bias=None,
p_out_weight=None,
p_out_bias=None,
g_out_weight=None,
g_out_bias=None,
eps=1e-05,
precision=Precision.DEFAULT,
fallback=None,
)

Apply the triangle multiplicative update operation.

This function performs the triangle multiplicative update, a key component of the AlphaFold2 architecture. The operation consists of three stages:

  1. Input normalization, followed by gated input projections

  2. Triangular projection along either outgoing or incoming edges

  3. Output normalization, projection, and gating
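The three stages above can be sketched in pure JAX as a readable reference. This is a minimal sketch following the AlphaFold2 formulation, not the library's CUDA kernel. The helper names (`triangle_update_reference`, `_layer_norm`, `_linear`), the `params` dictionary, splitting the 2*D_in projection into two halves `a` and `b`, applying `mask` to the projected pair features, and gating the output with the normalized input are all assumptions.

```python
import jax
import jax.numpy as jnp


def _layer_norm(x, weight, bias, eps):
    # Normalize over the last (hidden) axis, then apply the affine parameters.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * weight + bias


def _linear(x, weight, bias=None):
    # Weights are stored as (out_features, in_features), matching the documented shapes.
    y = x @ weight.T
    return y if bias is None else y + bias


def triangle_update_reference(x, params, direction="outgoing", mask=None, eps=1e-5):
    """Hypothetical pure-JAX reference; x has shape (..., N, N, D_in)."""
    d_in = x.shape[-1]
    # 1. Input normalization, then gated projections to 2*D_in channels.
    x_norm = _layer_norm(x, params["norm_in_weight"], params["norm_in_bias"], eps)
    gate_in = jax.nn.sigmoid(_linear(x_norm, params["g_in_weight"], params.get("g_in_bias")))
    ab = gate_in * _linear(x_norm, params["p_in_weight"], params.get("p_in_bias"))
    if mask is not None:
        ab = ab * mask[..., None]  # AlphaFold2-style masking (assumption)
    a, b = ab[..., :d_in], ab[..., d_in:]  # assumed split into two halves
    # 2. Triangular projection: sum over the third node k of each triangle (i, j, k).
    if direction == "outgoing":
        z = jnp.einsum("...ikd,...jkd->...ijd", a, b)  # edges i->k and j->k
    elif direction == "incoming":
        z = jnp.einsum("...kid,...kjd->...ijd", a, b)  # edges k->i and k->j
    else:
        raise ValueError(f"direction must be 'outgoing' or 'incoming', got {direction!r}")
    # 3. Output normalization and projection, gated by the normalized input.
    z = _layer_norm(z, params["norm_out_weight"], params["norm_out_bias"], eps)
    gate_out = jax.nn.sigmoid(_linear(x_norm, params["g_out_weight"], params.get("g_out_bias")))
    return gate_out * _linear(z, params["p_out_weight"], params.get("p_out_bias"))
```

For a symmetric input (x[..., i, j, :] equal to x[..., j, i, :]), the outgoing and incoming variants of this sketch produce the same result, which makes a convenient sanity check.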

Parameters:
  • x (jax.Array) – Input tensor of shape (…, N, N, D_in), where … represents arbitrary batch dimensions, N is the sequence length, and D_in is the input hidden dimension.

  • direction (str) – Direction of the triangular projection. Must be either “outgoing” or “incoming”.

  • key (jax.Array, optional) – JAX random key for weight initialization. Required if any weights are None.

  • mask (jax.Array, optional) – Mask tensor of shape (…, N, N) used to mask the output, with 1 for valid positions and 0 for masked positions. Its batch dimensions must be broadcastable with those of the input tensor.

  • norm_in_weight (jax.Array, optional) – Weight tensor for input normalization of shape (D_in,). If None, initialized to ones.

  • norm_in_bias (jax.Array, optional) – Bias tensor for input normalization of shape (D_in,). If None, initialized to zeros.

  • p_in_weight (jax.Array, optional) – Weight tensor for input projection of shape (2*D_in, D_in). If None, initialized with LeCun normal distribution.

  • p_in_bias (jax.Array, optional) – Bias tensor for input projection of shape (2*D_in,). If None, no bias is applied to the input projection.

  • g_in_weight (jax.Array, optional) – Weight tensor for input gating of shape (2*D_in, D_in). If None, initialized with LeCun normal distribution.

  • g_in_bias (jax.Array, optional) – Bias tensor for input gating of shape (2*D_in,). If None, no bias is applied to the input gating.

  • norm_out_weight (jax.Array, optional) – Weight tensor for output normalization of shape (D_in,). If None, initialized to ones.

  • norm_out_bias (jax.Array, optional) – Bias tensor for output normalization of shape (D_in,). If None, initialized to zeros.

  • p_out_weight (jax.Array, optional) – Weight tensor for output projection of shape (D_out, D_in). If None, initialized with LeCun normal distribution.

  • p_out_bias (jax.Array, optional) – Bias tensor for output projection of shape (D_out,). If None, no bias is applied to the output projection.

  • g_out_weight (jax.Array, optional) – Weight tensor for output gating of shape (D_out, D_in). If None, initialized with LeCun normal distribution.

  • g_out_bias (jax.Array, optional) – Bias tensor for output gating of shape (D_out,). If None, no bias is applied to the output gating.

  • eps (float) – Small constant for numerical stability in normalization. Defaults to 1e-5.

  • precision (Precision) – Precision mode for matrix multiplications. Available options: DEFAULT (use the default precision setting), TF32 (use TensorFloat-32 precision), TF32x3 (use three TensorFloat-32 passes to approximate FP32 accuracy), and IEEE (use IEEE 754 precision).

  • fallback (bool | None)
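The default initializations described in the parameter list can be reproduced explicitly, for example to create the parameters once and reuse them across calls. The helper `init_triangle_params` below is hypothetical. It uses `jax.nn.initializers.lecun_normal` as a stand-in for the library's LeCun normal initialization (JAX's version draws from a truncated normal), and it omits the biases, since passing None for a bias simply disables it.

```python
import jax
import jax.numpy as jnp


def init_triangle_params(key, d_in, d_out=None, dtype=jnp.float32):
    """Hypothetical helper: build every weight with the shapes documented above."""
    d_out = d_in if d_out is None else d_out  # D_out defaults to D_in
    # Weights are (out_features, in_features), so the fan-in is the last axis.
    lecun = jax.nn.initializers.lecun_normal(in_axis=-1, out_axis=-2)
    k_p_in, k_g_in, k_p_out, k_g_out = jax.random.split(key, 4)
    return {
        "norm_in_weight": jnp.ones((d_in,), dtype),
        "norm_in_bias": jnp.zeros((d_in,), dtype),
        "p_in_weight": lecun(k_p_in, (2 * d_in, d_in), dtype),
        "g_in_weight": lecun(k_g_in, (2 * d_in, d_in), dtype),
        "norm_out_weight": jnp.ones((d_in,), dtype),
        "norm_out_bias": jnp.zeros((d_in,), dtype),
        "p_out_weight": lecun(k_p_out, (d_out, d_in), dtype),
        "g_out_weight": lecun(k_g_out, (d_out, d_in), dtype),
    }
```

Because the dictionary keys match the keyword arguments in the signature, the result should be passable as `triangle_multiplicative_update(x, direction="outgoing", **params)`, in which case no `key` is needed for initialization.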

Returns:

Output tensor of shape (…, N, N, D_out), where D_out is the first dimension of g_out_weight. If g_out_weight is not provided, D_out equals D_in (the input hidden dimension).

Return type:

jax.Array

Notes

  • Unlike PyTorch, JAX arrays are immutable, so weight initialization returns new arrays

  • If output weights are not provided, they are initialized with D_out = D_in (preserving input dimension)

  • If weights are not provided, they are initialized with the defaults listed under Parameters (ones, zeros, or LeCun normal), but in practice you should pass learned parameters

  • Supports arbitrary batch dimensions through broadcasting
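Padded batches are a common source of masks. Assuming a per-residue padding mask of shape (…, N), a pair mask of shape (…, N, N) that marks pair (i, j) as valid only when both residues are valid can be built with an outer product. The helper `pair_mask_from_sequence_mask` is hypothetical; the last lines show a mask broadcasting against extra batch dimensions.

```python
import jax.numpy as jnp


def pair_mask_from_sequence_mask(seq_mask):
    """Hypothetical helper: (..., N) residue mask -> (..., N, N) pair mask.

    A pair (i, j) is valid only if both residue i and residue j are valid.
    """
    seq_mask = seq_mask.astype(jnp.float32)
    return seq_mask[..., :, None] * seq_mask[..., None, :]


# Two sequences padded to length 4; the second has only 2 real residues.
lengths = jnp.array([4, 2])
seq_mask = jnp.arange(4)[None, :] < lengths[:, None]  # (2, 4) booleans
pair_mask = pair_mask_from_sequence_mask(seq_mask)    # (2, 4, 4)

# A (2, 1, 4, 4) mask broadcasts against batch dimensions (2, 3).
batched = jnp.broadcast_to(pair_mask[:, None], (2, 3, 4, 4))
```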

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from cuequivariance_jax import triangle_multiplicative_update
>>> # Create input tensor with arbitrary batch dimensions
>>> key = jax.random.key(0)
>>> key, subkey = jax.random.split(key)
>>> batch_dim1, batch_dim2, seq_len, D_in = 2, 3, 128, 128
>>> x = jax.random.normal(subkey, (batch_dim1, batch_dim2, seq_len, seq_len, D_in), dtype=jnp.float32)
>>> # Create mask (1 for valid positions, 0 for masked)
>>> mask = jnp.ones((batch_dim1, batch_dim2, seq_len, seq_len))
>>> # Create weight parameters (in practice, these would be learned)
>>> norm_in_weight = jnp.ones(D_in)
>>> norm_in_bias = jnp.zeros(D_in)
>>> # Optional bias parameters for projection and gating layers
>>> p_in_bias = jnp.zeros(2 * D_in)  # Optional input projection bias
>>> g_in_bias = jnp.zeros(2 * D_in)  # Optional input gating bias
>>> p_out_bias = jnp.zeros(D_in)     # Optional output projection bias (would be D_out if dimension changes)
>>> g_out_bias = jnp.zeros(D_in)     # Optional output gating bias (would be D_out if dimension changes)
>>> # Initialize other weights using the key
>>> key, subkey = jax.random.split(key)
>>> # Perform triangular multiplication
>>> output = triangle_multiplicative_update(
...     x=x,
...     direction="outgoing",  # or "incoming"
...     key=subkey,  # Only needed if some weights are None
...     mask=mask,
...     norm_in_weight=norm_in_weight,
...     norm_in_bias=norm_in_bias,
...     p_in_bias=p_in_bias,  # Can be None to skip bias
...     g_in_bias=g_in_bias,  # Can be None to skip bias
...     p_out_bias=p_out_bias,  # Can be None to skip bias
...     g_out_bias=g_out_bias,  # Can be None to skip bias
...     # ... pass other weights or let them initialize ...
... )
>>> print(output.shape)
(2, 3, 128, 128, 128)
>>> # Example with dimension change: input 128 -> output 256
>>> g_out_weight_256 = jax.random.normal(jax.random.key(1), (256, 128))
>>> p_out_weight_256 = jax.random.normal(jax.random.key(2), (256, 128))
>>> key, subkey2 = jax.random.split(key)
>>> output_256 = triangle_multiplicative_update(
...     x=x,
...     direction="outgoing",
...     key=subkey2,  # Key needed for other weight initialization
...     g_out_weight=g_out_weight_256,
...     p_out_weight=p_out_weight_256,
... )
>>> print(output_256.shape)
(2, 3, 128, 128, 256)

Note

This operation uses a custom CUDA kernel for performance. When running on multiple devices, the computation must be sharded manually; without explicit sharding, performance is significantly degraded. See the JAX shard_map documentation for details on manual parallelism.
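Manual sharding can be sketched with `shard_map` by splitting the leading batch axis across devices, so that each device processes complete N x N pair tensors. This is a minimal sketch under stated assumptions: `toy_triangle_op` is a hypothetical stand-in for the real kernel (so the example runs without a GPU), `shard_over_batch` is a hypothetical helper, and the `try`/`except` import covers both the newer `jax.shard_map` and the older `jax.experimental.shard_map` locations.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public API in newer JAX releases
except ImportError:  # older releases ship it under jax.experimental
    from jax.experimental.shard_map import shard_map


def toy_triangle_op(x, mask):
    # Hypothetical stand-in for triangle_multiplicative_update(x, mask=mask, ...):
    # an outgoing-style triangular product over each example's own N x N block.
    xm = x * mask[..., None]
    return jnp.einsum("bikd,bjkd->bijd", xm, xm)


def shard_over_batch(fn, mesh, axis="batch"):
    # Shard dim 0 (batch) of every input and output; other dims stay whole,
    # so fn needs no cross-device communication.
    spec = P(axis)
    return jax.jit(shard_map(fn, mesh=mesh, in_specs=(spec, spec), out_specs=spec))


mesh = Mesh(np.array(jax.devices()), ("batch",))
batch = 2 * len(jax.devices())  # batch size must divide evenly across devices
x = jax.random.normal(jax.random.PRNGKey(0), (batch, 8, 8, 4))
mask = jnp.ones((batch, 8, 8))
out = shard_over_batch(toy_triangle_op, mesh)(x, mask)
```

In practice the stand-in would be replaced by a function that calls `triangle_multiplicative_update`. Sharding over the batch axis is a natural choice here because the triangular contraction only mixes positions within a single N x N block.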