IrrepsIndexedLinear

class cuequivariance_jax.nnx.IrrepsIndexedLinear
__init__(
irreps_in,
irreps_out,
num_indices,
scale=1.0,
*,
name='indexed_linear',
dtype=jax.numpy.float32,
rngs,
)
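The signature does not spell out the computation. Judging from the name and the `num_indices` parameter, the layer presumably keeps `num_indices` separate weight sets and applies to each input row the set selected by that row's integer index (for example, one weight set per atom species). The snippet below is a minimal plain-JAX sketch of that idea only, under stated assumptions. It is not the `cuequivariance_jax` implementation or its call signature. It treats features as flat vectors and ignores the irreps block structure that `irreps_in` and `irreps_out` impose. The function name `indexed_linear` and its per-row `indices` argument are hypothetical.

```python
import jax
import jax.numpy as jnp


def indexed_linear(x, indices, weights):
    """Hypothetical sketch: apply a per-row weight matrix chosen by an index.

    x:       (batch, dim_in) input features (flat; irreps structure ignored)
    indices: (batch,) integers in [0, num_indices)
    weights: (num_indices, dim_in, dim_out), one weight matrix per index
    """
    # Gather the weight matrix for every row: (batch, dim_in, dim_out).
    w = weights[indices]
    # Row-wise matrix-vector product: out[b] = x[b] @ w[b].
    return jnp.einsum("bi,bio->bo", x, w)


key = jax.random.PRNGKey(0)
num_indices, dim_in, dim_out = 3, 4, 2
weights = jax.random.normal(key, (num_indices, dim_in, dim_out))
x = jnp.arange(5 * dim_in, dtype=jnp.float32).reshape(5, dim_in)
indices = jnp.array([0, 2, 1, 0, 2])
y = indexed_linear(x, indices, weights)
```

Gathering a full weight matrix per row keeps the sketch short but materializes a `(batch, dim_in, dim_out)` array. An optimized implementation would likely avoid that, for example by grouping rows that share an index.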
Parameters: