triangle_multiplicative_update
cuequivariance_jax.triangle_multiplicative_update(
    x,
    direction='outgoing',
    key=None,
    mask=None,
    norm_in_weight=None,
    norm_in_bias=None,
    p_in_weight=None,
    p_in_bias=None,
    g_in_weight=None,
    g_in_bias=None,
    norm_out_weight=None,
    norm_out_bias=None,
    p_out_weight=None,
    p_out_bias=None,
    g_out_weight=None,
    g_out_bias=None,
    eps=1e-05,
    precision=Precision.DEFAULT,
    fallback=None,
)
Apply the triangle multiplicative update operation.

This function performs the triangle multiplicative update, a key component of the AlphaFold2 architecture. The operation consists of three stages:

1. Input normalization and gating
2. Triangular projection (either outgoing or incoming)
3. Output normalization and gating
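For intuition, here is a minimal pure-JAX sketch of the three stages, assuming the standard AlphaFold2 formulation (Algorithms 11 and 12 of the supplement). The names layer_norm and triangle_update_sketch are illustrative, not part of this API; bias terms and masking are omitted, and the fused kernel may order or batch these steps differently:

import jax
import jax.numpy as jnp

def layer_norm(x, weight, bias, eps=1e-5):
    # LayerNorm over the channel (last) dimension with affine parameters.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * weight + bias

def triangle_update_sketch(x, direction, norm_in_w, norm_in_b, p_in_w, g_in_w,
                           norm_out_w, norm_out_b, p_out_w, g_out_w):
    # 1. Input normalization and gating; the projection and gate each
    # produce 2*D_in channels that split into the two triangle inputs.
    h = layer_norm(x, norm_in_w, norm_in_b)
    ab = (h @ p_in_w.T) * jax.nn.sigmoid(h @ g_in_w.T)  # (..., N, N, 2*D_in)
    a, b = jnp.split(ab, 2, axis=-1)                    # two (..., N, N, D_in)
    # 2. Triangular projection: contract over the third node k.
    if direction == "outgoing":
        t = jnp.einsum("...ikd,...jkd->...ijd", a, b)
    else:  # "incoming"
        t = jnp.einsum("...kid,...kjd->...ijd", a, b)
    # 3. Output normalization and gating.
    t = layer_norm(t, norm_out_w, norm_out_b)
    return jax.nn.sigmoid(h @ g_out_w.T) * (t @ p_out_w.T)  # (..., N, N, D_out)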
Parameters:
x (jax.Array) – Input tensor of shape (…, N, N, D_in), where … represents arbitrary batch dimensions, N is the sequence length, and D_in is the input hidden dimension.
direction (str) – Direction of the triangular projection. Must be either “outgoing” or “incoming”.
key (jax.Array, optional) – JAX random key for weight initialization. Required if any weights are None.
mask (jax.Array, optional) – Optional mask tensor of shape (…, N, N) for masking the output. Must be broadcastable with the input tensor’s batch dimensions.
norm_in_weight (jax.Array, optional) – Weight tensor for input normalization of shape (D_in,). If None, initialized to ones.
norm_in_bias (jax.Array, optional) – Bias tensor for input normalization of shape (D_in,). If None, initialized to zeros.
p_in_weight (jax.Array, optional) – Weight tensor for input projection of shape (2*D_in, D_in). If None, initialized with LeCun normal distribution.
p_in_bias (jax.Array, optional) – Bias tensor for input projection of shape (2*D_in,). If None, no bias is applied to the input projection.
g_in_weight (jax.Array, optional) – Weight tensor for input gating of shape (2*D_in, D_in). If None, initialized with LeCun normal distribution.
g_in_bias (jax.Array, optional) – Bias tensor for input gating of shape (2*D_in,). If None, no bias is applied to the input gating.
norm_out_weight (jax.Array, optional) – Weight tensor for output normalization of shape (D_in,). If None, initialized to ones.
norm_out_bias (jax.Array, optional) – Bias tensor for output normalization of shape (D_in,). If None, initialized to zeros.
p_out_weight (jax.Array, optional) – Weight tensor for output projection of shape (D_out, D_in). If None, initialized with LeCun normal distribution.
p_out_bias (jax.Array, optional) – Bias tensor for output projection of shape (D_out,). If None, no bias is applied to the output projection.
g_out_weight (jax.Array, optional) – Weight tensor for output gating of shape (D_out, D_in). If None, initialized with LeCun normal distribution.
g_out_bias (jax.Array, optional) – Bias tensor for output gating of shape (D_out,). If None, no bias is applied to the output gating.
eps (float) – Small constant for numerical stability in normalization. Defaults to 1e-5.
precision (Precision) – Precision mode for matrix multiplications. Available options: DEFAULT (use the default precision setting), TF32 (TensorFloat-32 precision), TF32x3 (TensorFloat-32 precision with 3x accumulation), and IEEE (IEEE 754 precision). See the usage sketch after this parameter list.
fallback (bool | None)
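A brief usage sketch for the precision option. The import location of the Precision enum is an assumption here; check where your installed cuequivariance_jax actually exposes it:

# Assumed import location for Precision; adjust if the enum lives elsewhere.
from cuequivariance_jax import Precision, triangle_multiplicative_update

out = triangle_multiplicative_update(
    x,
    direction="incoming",
    key=key,                   # x and key as in the Example below
    precision=Precision.TF32,  # request TensorFloat-32 matmuls
)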
Returns:
Output tensor of shape (…, N, N, D_out), where D_out is determined by the first dimension of g_out_weight. If g_out_weight is not provided, D_out equals D_in (the input hidden dimension).
Return type:
jax.Array
Notes

- Unlike PyTorch, JAX arrays are immutable, so weight initialization returns new arrays.
- If output weights are not provided, they are initialized with D_out = D_in (preserving the input dimension).
- If weights are not provided, they are initialized with appropriate values; in practice you should pass learned parameters.
- Supports arbitrary batch dimensions through broadcasting.
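Because the function is stateless, a common pattern (a sketch, not part of this API) is to build the weights once with the documented shapes and defaults, keep them in a dict, and pass them with **. The init_params helper below is hypothetical; it mirrors the documented defaults (ones/zeros for the layer norms, LeCun normal for projections and gates, no biases):

import jax
import jax.numpy as jnp

def init_params(key, d_in, d_out=None):
    d_out = d_in if d_out is None else d_out
    # Weights are laid out (out, in), so fan-in is the last axis.
    lecun = jax.nn.initializers.lecun_normal(in_axis=-1, out_axis=-2)
    k1, k2, k3, k4 = jax.random.split(key, 4)
    return {
        "norm_in_weight": jnp.ones(d_in),
        "norm_in_bias": jnp.zeros(d_in),
        "p_in_weight": lecun(k1, (2 * d_in, d_in)),
        "g_in_weight": lecun(k2, (2 * d_in, d_in)),
        "norm_out_weight": jnp.ones(d_in),
        "norm_out_bias": jnp.zeros(d_in),
        "p_out_weight": lecun(k3, (d_out, d_in)),
        "g_out_weight": lecun(k4, (d_out, d_in)),
    }

params = init_params(jax.random.key(0), d_in=128)
# out = triangle_multiplicative_update(x, direction="outgoing", **params)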
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from cuequivariance_jax import triangle_multiplicative_update
>>> # Create input tensor with arbitrary batch dimensions
>>> key = jax.random.key(0)
>>> key, subkey = jax.random.split(key)
>>> batch_dim1, batch_dim2, seq_len, D_in = 2, 3, 128, 128
>>> x = jax.random.normal(subkey, (batch_dim1, batch_dim2, seq_len, seq_len, D_in), dtype=jnp.float32)
>>> # Create mask (1 for valid positions, 0 for masked)
>>> mask = jnp.ones((batch_dim1, batch_dim2, seq_len, seq_len))
>>> # Create weight parameters (in practice, these would be learned)
>>> norm_in_weight = jnp.ones(D_in)
>>> norm_in_bias = jnp.zeros(D_in)
>>> # Optional bias parameters for projection and gating layers
>>> p_in_bias = jnp.zeros(2 * D_in)  # Optional input projection bias
>>> g_in_bias = jnp.zeros(2 * D_in)  # Optional input gating bias
>>> p_out_bias = jnp.zeros(D_in)  # Optional output projection bias (would be D_out if dimension changes)
>>> g_out_bias = jnp.zeros(D_in)  # Optional output gating bias (would be D_out if dimension changes)
>>> # Initialize other weights using the key
>>> key, subkey = jax.random.split(key)
>>> # Perform triangular multiplication
>>> output = triangle_multiplicative_update(
...     x=x,
...     direction="outgoing",  # or "incoming"
...     key=subkey,  # Only needed if some weights are None
...     mask=mask,
...     norm_in_weight=norm_in_weight,
...     norm_in_bias=norm_in_bias,
...     p_in_bias=p_in_bias,  # Can be None to skip bias
...     g_in_bias=g_in_bias,  # Can be None to skip bias
...     p_out_bias=p_out_bias,  # Can be None to skip bias
...     g_out_bias=g_out_bias,  # Can be None to skip bias
...     # ... pass other weights or let them initialize ...
... )
>>> print(output.shape)
(2, 3, 128, 128, 128)
>>> # Example with dimension change: input 128 -> output 256
>>> g_out_weight_256 = jax.random.normal(jax.random.key(1), (256, 128))
>>> p_out_weight_256 = jax.random.normal(jax.random.key(2), (256, 128))
>>> key, subkey2 = jax.random.split(key)
>>> output_256 = triangle_multiplicative_update(
...     x=x,
...     direction="outgoing",
...     key=subkey2,  # Key needed for other weight initialization
...     g_out_weight=g_out_weight_256,
...     p_out_weight=p_out_weight_256,
... )
>>> print(output_256.shape)
(2, 3, 128, 128, 256)
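For repeated calls, the function can be wrapped in jax.jit. This sketch assumes all weights are supplied (so key is unused) and that the operation is jit-compatible like the rest of the library:

import functools
import jax

# params as in the initializer sketch above; with all weights given,
# no key is needed and the call is fully deterministic.
fwd = jax.jit(functools.partial(
    triangle_multiplicative_update, direction="outgoing", **params))
out = fwd(x, mask=mask)  # same shapes as in the Example above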