triangle_multiplicative_update
cuequivariance_jax.triangle_multiplicative_update(
    x,
    direction='outgoing',
    key=None,
    mask=None,
    norm_in_weight=None,
    norm_in_bias=None,
    p_in_weight=None,
    p_in_bias=None,
    g_in_weight=None,
    g_in_bias=None,
    norm_out_weight=None,
    norm_out_bias=None,
    p_out_weight=None,
    p_out_bias=None,
    g_out_weight=None,
    g_out_bias=None,
    eps=1e-05,
    precision=Precision.DEFAULT,
    fallback=None,
)
Apply the triangle multiplicative update operation.
This function performs a triangle multiplicative update, a key component of the AlphaFold2 architecture. The operation consists of:
1. Input normalization and gating
2. Triangular projection (either outgoing or incoming)
3. Output normalization and gating
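The triangular projection in step 2 can be sketched with a plain einsum; the following NumPy snippet is a simplified illustration of the reduction pattern from the AlphaFold2 formulation, not the fused kernel this function dispatches to, and the function name and argument names are hypothetical:

```python
import numpy as np

def triangle_projection(a, b, direction="outgoing"):
    """Simplified sketch of the triangular projection step.

    a, b: gated input projections of shape (B, N, N, D).
    For "outgoing", pair (i, j) aggregates over edges (i, k) and (j, k);
    for "incoming", it aggregates over edges (k, i) and (k, j).
    """
    if direction == "outgoing":
        return np.einsum("bikd,bjkd->bijd", a, b)
    elif direction == "incoming":
        return np.einsum("bkid,bkjd->bijd", a, b)
    raise ValueError("direction must be 'outgoing' or 'incoming'")
```

Steps 1 and 3 wrap this projection in layer normalization and sigmoid gating on each side.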
- Parameters:
x (jax.Array) – Input tensor of shape (B, N, N, D), where B is the batch size, N is the sequence length, and D is the hidden dimension. A 3D input of shape (N, N, D) is expanded to 4D.
direction (str) – Direction of the triangular projection. Must be either “outgoing” or “incoming”.
key (jax.Array, optional) – JAX random key for weight initialization. Required if any weights are None.
mask (jax.Array, optional) – Optional mask tensor of shape (B, N, N) for masking the output. Can also be 2D (N, N) which will be expanded to 3D.
norm_in_weight (jax.Array, optional) – Weight tensor for input normalization of shape (D,). If None, initialized to ones.
norm_in_bias (jax.Array, optional) – Bias tensor for input normalization of shape (D,). If None, initialized to zeros.
p_in_weight (jax.Array, optional) – Weight tensor for input projection of shape (2D, D). If None, initialized with LeCun normal distribution.
p_in_bias (jax.Array, optional) – Bias tensor for input projection of shape (2D,). If None, no bias is applied to the input projection.
g_in_weight (jax.Array, optional) – Weight tensor for input gating of shape (2D, D). If None, initialized with LeCun normal distribution.
g_in_bias (jax.Array, optional) – Bias tensor for input gating of shape (2D,). If None, no bias is applied to the input gating.
norm_out_weight (jax.Array, optional) – Weight tensor for output normalization of shape (D,). If None, initialized to ones.
norm_out_bias (jax.Array, optional) – Bias tensor for output normalization of shape (D,). If None, initialized to zeros.
p_out_weight (jax.Array, optional) – Weight tensor for output projection of shape (D, D). If None, initialized with LeCun normal distribution.
p_out_bias (jax.Array, optional) – Bias tensor for output projection of shape (D,). If None, no bias is applied to the output projection.
g_out_weight (jax.Array, optional) – Weight tensor for output gating of shape (D, D). If None, initialized with LeCun normal distribution.
g_out_bias (jax.Array, optional) – Bias tensor for output gating of shape (D,). If None, no bias is applied to the output gating.
eps (float) – Small constant for numerical stability in normalization. Defaults to 1e-5.
precision (Precision) – Precision mode for matrix multiplications. Available options:
- DEFAULT: use the default precision setting
- TF32: use TensorFloat-32 precision
- TF32x3: use TensorFloat-32 precision with 3x accumulation
- IEEE: use IEEE 754 precision
fallback (bool | None)
- Returns:
Output tensor of shape (B, N, N, D). A 4D tensor is always returned, even if the input was 3D.
- Return type:
jax.Array
Notes
- Unlike PyTorch, JAX arrays are immutable, so weight initialization returns new arrays.
- The hidden dimension D must be divisible by 64 for the BND_BND layout used in layer normalization.
- If weights are not provided, they are initialized with appropriate defaults, but in practice you should pass learned parameters.
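Putting the parameter list and the notes together, a full learned parameter set has the following shapes. This NumPy sketch mirrors the documented defaults (ones/zeros for normalization, LeCun normal for projections and gates); the `lecun_normal` helper is a hypothetical stand-in, approximating the initializer as a normal with std 1/sqrt(fan_in):

```python
import numpy as np

D = 128  # hidden dimension; must be divisible by 64 (see Notes)
rng = np.random.default_rng(0)

def lecun_normal(rng, shape):
    # LeCun normal approximation: std = 1 / sqrt(fan_in),
    # taking the last axis as fan_in for these (out, in) weight matrices.
    return rng.normal(0.0, 1.0 / np.sqrt(shape[-1]), size=shape)

params = {
    "norm_in_weight": np.ones(D),              # (D,)
    "norm_in_bias": np.zeros(D),               # (D,)
    "p_in_weight": lecun_normal(rng, (2 * D, D)),   # (2D, D)
    "g_in_weight": lecun_normal(rng, (2 * D, D)),   # (2D, D)
    "norm_out_weight": np.ones(D),             # (D,)
    "norm_out_bias": np.zeros(D),              # (D,)
    "p_out_weight": lecun_normal(rng, (D, D)), # (D, D)
    "g_out_weight": lecun_normal(rng, (D, D)), # (D, D)
}
```

The optional biases (p_in_bias, g_in_bias, p_out_bias, g_out_bias) are omitted here; when left as None, no bias is applied to the corresponding projection or gate.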
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from cuequivariance_jax import triangle_multiplicative_update
>>> # Create input tensor
>>> key = jax.random.key(0)
>>> key, subkey = jax.random.split(key)
>>> batch_size, seq_len, hidden_dim = 1, 128, 128
>>> x = jax.random.normal(subkey, (batch_size, seq_len, seq_len, hidden_dim), dtype=jnp.float32)
>>> # Create mask (1 for valid positions, 0 for masked)
>>> mask = jnp.ones((batch_size, seq_len, seq_len))
>>> # Create weight parameters (in practice, these would be learned)
>>> norm_in_weight = jnp.ones(hidden_dim)
>>> norm_in_bias = jnp.zeros(hidden_dim)
>>> # Optional bias parameters for projection and gating layers
>>> p_in_bias = jnp.zeros(2 * hidden_dim)  # Optional input projection bias
>>> g_in_bias = jnp.zeros(2 * hidden_dim)  # Optional input gating bias
>>> p_out_bias = jnp.zeros(hidden_dim)  # Optional output projection bias
>>> g_out_bias = jnp.zeros(hidden_dim)  # Optional output gating bias
>>> # Initialize other weights using the key
>>> key, subkey = jax.random.split(key)
>>> # Perform triangular multiplication
>>> output = triangle_multiplicative_update(
...     x=x,
...     direction="outgoing",  # or "incoming"
...     key=subkey,  # Only needed if some weights are None
...     mask=mask,
...     norm_in_weight=norm_in_weight,
...     norm_in_bias=norm_in_bias,
...     p_in_bias=p_in_bias,  # Can be None to skip bias
...     g_in_bias=g_in_bias,  # Can be None to skip bias
...     p_out_bias=p_out_bias,  # Can be None to skip bias
...     g_out_bias=g_out_bias,  # Can be None to skip bias
...     # ... pass other weights or let them initialize ...
... )
>>> print(output.shape)
(1, 128, 128, 128)