indexed_linear

cuequivariance_jax.experimental.indexed_linear(
    poly,
    counts,
    w,
    x,
    math_dtype=None,
    method='indexed_linear',
)

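Linear layer that applies different weights to different partitions of the input. The rows of x are split into C consecutive partitions of sizes counts[0], ..., counts[C-1], and every row in partition c is transformed with the weights w[c].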
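To illustrate the partitioning only, the sketch below computes the same kind of map with plain jax.numpy. It assumes one dense (num_inputs, num_outputs) matrix per partition and that each partition's rows are contiguous and in order. The real function instead takes flat weights of shape (C, num_weights) whose layout is defined by poly, so indexed_linear_reference and w_dense are hypothetical names, not part of cuequivariance_jax.

```python
import jax
import jax.numpy as jnp


def indexed_linear_reference(counts, w_dense, x):
    """Hypothetical dense reference for a partitioned linear layer.

    counts:  (C,) integer array of rows per partition, sum(counts) == Z.
    w_dense: (C, num_inputs, num_outputs), one dense matrix per partition.
             Assumed layout for illustration; indexed_linear itself takes flat
             weights of shape (C, num_weights) laid out according to `poly`.
    x:       (Z, num_inputs), rows grouped contiguously by partition.
    """
    num_rows = x.shape[0]
    num_partitions = counts.shape[0]
    # Partition index of every row: counts[0] zeros, then counts[1] ones, ...
    idx = jnp.repeat(jnp.arange(num_partitions), counts,
                     total_repeat_length=num_rows)
    # Pick each row's matrix and apply it: y[z] = x[z] @ w_dense[idx[z]].
    return jnp.einsum("zi,zio->zo", x, w_dense[idx])


counts = jnp.array([3, 4, 3])
x = jax.random.normal(jax.random.key(0), (10, 8))
w_dense = jax.random.normal(jax.random.key(1), (3, 8, 16))
y = indexed_linear_reference(counts, w_dense, x)
# Rows 3..6 form the second partition and are multiplied by w_dense[1].
print(y.shape)  # (10, 16)
```

Gathering one matrix per row keeps the sketch short but materializes a (Z, num_inputs, num_outputs) tensor, so it is only meant for checking shapes and semantics on small inputs, not as a substitute for indexed_linear.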

Parameters:
  • poly (SegmentedPolynomial) – The polynomial descriptor. It must describe a linear layer, for example the polynomial of cue.descriptors.linear(...).

  • counts (Array) – Number of rows of x in each partition. Shape (C,).

  • w (Array) – Weights of the linear layer. Shape (C, num_weights).

  • x (Array) – Input data. Shape (Z, num_inputs), where Z must equal the sum of counts.

  • math_dtype (dtype | None) – Data type used for internal computation. If None, it is determined automatically from the input dtypes. Defaults to None.

  • method (str) – Implementation method to use. Defaults to 'indexed_linear'.
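Because counts only records partition sizes, data whose rows are interleaved across partitions presumably has to be grouped before the call. A minimal sketch, assuming each row carries an integer partition label in [0, C): the hypothetical helpers group_by_partition and ungroup below (not part of cuequivariance_jax) sort the rows by label, derive counts with jnp.bincount, and restore the original row order afterwards.

```python
import jax.numpy as jnp


def group_by_partition(labels, x, num_partitions):
    """Hypothetical helper: reorder rows of x so each partition is contiguous.

    labels: (Z,) integer partition label of each row, in [0, num_partitions).
    Returns (x_sorted, counts, perm) with x_sorted[i] == x[perm[i]].
    """
    perm = jnp.argsort(labels)  # row order that groups equal labels together
    counts = jnp.bincount(labels, length=num_partitions)  # rows per partition
    return x[perm], counts, perm


def ungroup(y_sorted, perm):
    """Hypothetical helper: move rows back to their original positions."""
    # y_sorted[i] belongs to original row perm[i].
    return jnp.zeros_like(y_sorted).at[perm].set(y_sorted)


labels = jnp.array([2, 0, 1, 0, 2])
x = jnp.arange(10.0).reshape(5, 2)
x_sorted, counts, perm = group_by_partition(labels, x, num_partitions=3)
print(counts)  # [2 1 2]
```

With these helpers, a call might look like `y = ungroup(cuex.experimental.indexed_linear(poly, counts, w, x_sorted), perm)`, where `w[c]` holds the weights for label `c`.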

Returns:

Output data. Shape (Z, num_outputs).

Return type:

Array

Examples

This example applies indexed_linear to a batch of 10 samples belonging to 3 species, with one set of weights per species:

>>> import jax
>>> import jax.numpy as jnp
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>>
>>> # Define problem parameters
>>> num_species_total = 3  # Total number of different species
>>> batch_size = 10        # Number of samples in batch
>>> input_dim = 8          # Input feature dimension
>>> output_dim = 16        # Output feature dimension
>>> dtype = jnp.float32
>>>
>>> # Define how many elements belong to each species
>>> num_species = jnp.array([3, 4, 3], dtype=jnp.int32)  # Sum equals batch_size
>>>
>>> # Generate random input data
>>> input_array = jax.random.normal(jax.random.key(0), (batch_size, input_dim), dtype)
>>>
>>> # Define irreps for input and output features
>>> input_irreps = cue.Irreps(cue.O3, f"{input_dim}x0e")   # Scalar features
>>> output_irreps = cue.Irreps(cue.O3, f"{output_dim}x0e") # Scalar features
>>>
>>> # Create a linear descriptor
>>> e = cue.descriptors.linear(input_irreps, output_irreps)
>>>
>>> # Generate weights for each species
>>> w = jax.random.normal(jax.random.key(1), (num_species_total, e.inputs[0].dim), dtype)
>>>
>>> # Apply the indexed linear layer
>>> result = cuex.experimental.indexed_linear(e.polynomial, num_species, w, input_array)
>>>
>>> # Verify output shape
>>> assert result.shape == (batch_size, output_dim)