indexed_linear
- cuequivariance_jax.experimental.indexed_linear(poly, counts, w, x, math_dtype=None, method='indexed_linear')
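Linear layer with different weights for different parts of the input. The rows of x are split into C contiguous partitions of sizes counts[0], ..., counts[C-1], and the rows of partition c are transformed with the weights w[c].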
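To make the row-to-weight assignment concrete, here is a minimal pure-JAX sketch of the same partitioning. It is not the library's implementation: the helper name reference_indexed_linear is hypothetical, and it assumes one dense (num_inputs, num_outputs) matrix per partition, whereas indexed_linear takes flat weight vectors whose layout and normalization are defined by the descriptor poly.

```python
import jax
import jax.numpy as jnp


def reference_indexed_linear(counts, w_dense, x):
    """Hypothetical reference for the partitioning semantics only.

    counts:  (C,) int array, number of rows in each contiguous partition.
    w_dense: (C, num_inputs, num_outputs), one dense matrix per partition
             (an assumed layout, not the descriptor's flat weight layout).
    x:       (Z, num_inputs), with Z == counts.sum().
    """
    num_rows = x.shape[0]
    # Partition id of every row: [0]*counts[0] + [1]*counts[1] + ...
    # total_repeat_length keeps the output shape static under jax.jit.
    row_to_partition = jnp.repeat(
        jnp.arange(counts.shape[0]), counts, total_repeat_length=num_rows
    )
    # Gather each row's matrix and apply it: (Z, in) x (Z, in, out) -> (Z, out).
    return jnp.einsum("zi,zio->zo", x, w_dense[row_to_partition])


# Three partitions of 3, 4 and 3 rows, mapping 8 input features to 16 outputs.
counts = jnp.array([3, 4, 3], dtype=jnp.int32)
x = jax.random.normal(jax.random.key(0), (10, 8))
w_dense = jax.random.normal(jax.random.key(1), (3, 8, 16))
y = reference_indexed_linear(counts, w_dense, x)
```

Gathering a full matrix per row is simple but memory-hungry, so this sketch is only meant as a readable reference for small inputs.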
- Parameters:
poly (SegmentedPolynomial) – The polynomial descriptor. It must describe a linear layer, for example the polynomial of cue.descriptors.linear.
counts (Array) – Number of rows of x in each partition. Shape (C,).
w (Array) – Weights of the linear layer, one row per partition. Shape (C, num_weights).
x (Array) – Input data. Shape (Z, num_inputs), where Z must equal the sum of counts.
math_dtype (dtype | None) – Data type used for internal computation. If None, it is determined automatically from the input dtypes. Defaults to None.
method (str) – Defaults to 'indexed_linear'.
- Returns:
Output data. Shape (Z, num_outputs).
- Return type:
Array
Examples
This example demonstrates using indexed_linear for a batch of inputs with different species:
>>> import jax
>>> import jax.numpy as jnp
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>>
>>> # Define problem parameters
>>> num_species_total = 3  # Total number of different species
>>> batch_size = 10  # Number of samples in batch
>>> input_dim = 8  # Input feature dimension
>>> output_dim = 16  # Output feature dimension
>>> dtype = jnp.float32
>>>
>>> # Define how many elements belong to each species
>>> num_species = jnp.array([3, 4, 3], dtype=jnp.int32)  # Sum equals batch_size
>>>
>>> # Generate random input data
>>> input_array = jax.random.normal(jax.random.key(0), (batch_size, input_dim), dtype)
>>>
>>> # Define irreps for input and output features
>>> input_irreps = cue.Irreps(cue.O3, f"{input_dim}x0e")  # Scalar features
>>> output_irreps = cue.Irreps(cue.O3, f"{output_dim}x0e")  # Scalar features
>>>
>>> # Create a linear descriptor
>>> e = cue.descriptors.linear(input_irreps, output_irreps)
>>>
>>> # Generate weights for each species
>>> w = jax.random.normal(jax.random.key(1), (num_species_total, e.inputs[0].dim), dtype)
>>>
>>> # Apply the indexed linear layer
>>> result = cuex.experimental.indexed_linear(e.polynomial, num_species, w, input_array)
>>>
>>> # Verify output shape
>>> assert result.shape == (batch_size, output_dim)
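The example above assumes the rows of input_array are already grouped by species. If each row instead carries a species label in arbitrary order, one possible way to prepare the inputs is to reorder the rows by label, compute counts with jnp.bincount, and undo the permutation on the output. The helper group_rows_by_label below is a hypothetical sketch under that assumption, not part of cuequivariance_jax.

```python
import jax.numpy as jnp


def group_rows_by_label(labels, num_partitions):
    """Hypothetical helper: reorder rows so that equal labels are contiguous.

    Returns (order, counts, inverse_order):
      x[order] lists the rows partition by partition (label 0 first),
      counts[c] is the number of rows with label c,
      y_sorted[inverse_order] restores the original row order.
    """
    # Sorting the labels groups rows with the same label together.
    order = jnp.argsort(labels)
    # length= fixes the output shape, so this also works under jax.jit.
    counts = jnp.bincount(labels, length=num_partitions)
    # argsort of a permutation is its inverse permutation.
    inverse_order = jnp.argsort(order)
    return order, counts, inverse_order


# Seven rows carrying species labels 0, 1 or 2 in arbitrary order.
labels = jnp.array([2, 0, 1, 0, 2, 1, 1])
order, counts, inverse_order = group_rows_by_label(labels, num_partitions=3)
```

Under these assumptions, the layer would be called on the reordered rows, e.g. indexed_linear(poly, counts, w, x[order]), and indexing its result with inverse_order would give the outputs in the original row order.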