Release Notes#
Latest Changes#
0.3.0 (2025-03-05)#
The main changes are:
[JAX] New JIT Uniform 1d kernel with JAX bindings
Computes any polynomial based on 1d uniform STPs
Supports arbitrary derivatives
Provides optional fused scatter/gather for the inputs and outputs
🎉 We observed a ~3x speedup for MACE with cuEquivariance-JAX v0.3.0 compared to cuEquivariance-Torch v0.2.0 🎉
[Torch] Adds torch.compile support
[Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel
enable the new kernel by setting the environment variable
CUEQUIVARIANCE_OPS_USE_JIT=1
[Torch] Implements scatter/gather fusion through a beta API for Uniform 1d
this is a temporary API that will change,
cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed
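One way to set the flag from Python is sketched below. It assumes the variable is read when the Torch bindings are loaded, which these notes do not state, so the assignment is placed before the import. The import is guarded so the snippet also runs where cuequivariance_torch is not installed.

```python
import os

# Assumption: the flag is read when the Torch bindings are loaded,
# so it is set before any cuequivariance import.
os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"

try:
    import cuequivariance_torch as cuet  # noqa: F401
except ImportError:
    # The package is optional for this sketch.
    cuet = None

print(os.environ["CUEQUIVARIANCE_OPS_USE_JIT"])
```

Exporting CUEQUIVARIANCE_OPS_USE_JIT=1 in the shell before launching Python puts the same value in the process environment.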
Breaking Changes#
[Torch/JAX] Removed cue.TensorProductExecution and added cue.Operation, which is more lightweight and better aligned with the backend.
[JAX] In cuex.equivariant_tensor_product, the arguments dtype_math and dtype_output are renamed to math_dtype and output_dtype respectively. This change adds consistency with the rest of the library.
[JAX] In cuex.equivariant_tensor_product, the arguments algorithm, precision, use_custom_primitive and use_custom_kernels have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argument impl: str has been added instead to select the implementation.
[JAX] Removed cuex.symmetric_tensor_product. The cuex.tensor_product function now handles any non-homogeneous polynomials.
[JAX] The batching support (jax.vmap) of cuex.equivariant_tensor_product is now limited to specific use cases.
[JAX] The interface of cuex.tensor_product has changed. It now takes a list of tuple[cue.Operation, cue.SegmentedTensorProduct] instead of a single cue.SegmentedTensorProduct. This change allows cuex.tensor_product to execute any type of non-homogeneous polynomials.
[JAX] Removed cuex.flax_linen.Linear to reduce maintenance burden. Use cue.descriptors.linear together with cuex.equivariant_tensor_product instead.
e = cue.descriptors.linear(input.irreps, output_irreps)
w = self.param(name, jax.random.normal, (e.inputs[0].dim,), input.dtype)
output = cuex.equivariant_tensor_product(e, w, input)
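For call sites that still pass the 0.2.0 keyword arguments to cuex.equivariant_tensor_product, the renames and removals listed above can be applied mechanically. The helper below is a hypothetical sketch (migrate_kwargs is not part of cuEquivariance). It only rewrites a dict of keyword arguments and does not pick a value for the new impl argument, whose accepted values are not listed in these notes. The argument values in the example are placeholders.

```python
# Hypothetical migration helper; not part of cuEquivariance.
RENAMED = {"dtype_math": "math_dtype", "dtype_output": "output_dtype"}
REMOVED = {"algorithm", "precision", "use_custom_primitive", "use_custom_kernels"}


def migrate_kwargs(old_kwargs: dict) -> tuple[dict, list[str]]:
    """Translate 0.2.0-style keyword arguments to their 0.3.0 names.

    Returns the new kwargs and the sorted names of the dropped arguments.
    """
    new_kwargs = {}
    dropped = []
    for name, value in old_kwargs.items():
        if name in REMOVED:
            dropped.append(name)  # removed in 0.3.0, no replacement name
        else:
            new_kwargs[RENAMED.get(name, name)] = value
    return new_kwargs, sorted(dropped)


# Placeholder values, used only to show the renaming.
new_kwargs, dropped = migrate_kwargs(
    {"dtype_math": "float32", "dtype_output": "float16", "algorithm": "placeholder"}
)
print(new_kwargs)  # {'math_dtype': 'float32', 'output_dtype': 'float16'}
print(dropped)  # ['algorithm']
```

Reporting the dropped names rather than silently discarding them makes it easier to decide, per call site, whether an impl value is needed.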
Fixed#
[Torch/JAX] Fixed cue.descriptors.full_tensor_product, which was ignoring the irreps3_filter argument.
[Torch/JAX] Fixed a rare bug with np.bincount when using an old version of numpy. The input is now flattened to make it work with all versions.
[Torch] Identified a bug in the CUDA kernel and disabled the CUDA kernel for cuet.TransposeSegments and cuet.TransposeIrrepsLayout.
Added#
[Torch/JAX] Added __mul__ to cue.EquivariantTensorProduct to allow rescaling the equivariant tensor product.
[JAX] Added JAX Bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomials as long as the contraction pattern (subscripts) has only one mode. It handles batched/shared/indexed input/output. The indexed input/output is processed through atomic operations.
[JAX] Added an indices argument to cuex.equivariant_tensor_product and cuex.tensor_product to handle the scatter/gather fusion.
[Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel (enable the new kernel by setting the environment variable CUEQUIVARIANCE_OPS_USE_JIT=1)
[Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (this is a temporary API that will change, cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed)
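To make the scatter/gather terminology concrete, the unfused pattern can be written in plain Python: gather inputs through an index list, apply some per-element operation, then accumulate the results into outputs through a second index list. This is only a conceptual reference with made-up names (gather_compute_scatter, op). It does not call cuEquivariance and does not describe how the kernel is implemented.

```python
from typing import Callable, Sequence


def gather_compute_scatter(
    inputs: Sequence[float],
    in_indices: Sequence[int],
    out_indices: Sequence[int],
    num_outputs: int,
    op: Callable[[float], float],
) -> list[float]:
    """Unfused reference: gather, apply op, then scatter-add."""
    outputs = [0.0] * num_outputs
    for i_in, i_out in zip(in_indices, out_indices):
        gathered = inputs[i_in]  # gather: read the input selected by the index
        result = op(gathered)  # stand-in for the per-element computation
        outputs[i_out] += result  # scatter-add: several results may share a slot
    return outputs


# Three work items read inputs 0, 2, 2 and accumulate into outputs 1, 1, 0.
out = gather_compute_scatter(
    inputs=[1.0, 2.0, 3.0],
    in_indices=[0, 2, 2],
    out_indices=[1, 1, 0],
    num_outputs=2,
    op=lambda x: 2.0 * x,
)
print(out)  # [6.0, 8.0]
```

Two of the work items write to output slot 1. When such writes run concurrently on a GPU, the accumulation has to be atomic, which matches the note above that indexed input/output is processed through atomic operations.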
0.2.0 (2025-01-24)#
Breaking Changes#
Minimal Python version is now 3.10 in all packages.
cuet.TensorProduct and cuet.EquivariantTensorProduct now require inputs to be of shape (batch_size, dim) or (1, dim). Inputs of dimension (dim,) are no longer allowed.
cuex.IrrepsArray is now an alias for cuex.RepArray.
cuex.RepArray.irreps and cuex.RepArray.segments are no longer functions. They are now properties.
cuex.IrrepsArray.is_simple has been replaced by cuex.RepArray.is_irreps_array.
The function cuet.spherical_harmonics has been replaced by the Torch Module cuet.SphericalHarmonics. This change enables the use of torch.jit.script and torch.compile.
Added#
Added experimental support for torch.compile. Known issue: the export in C++ is not working.
Added cue.IrrepsAndLayout: A simple class that inherits from cue.Rep and contains a cue.Irreps and a cue.IrrepsLayout.
Added cuex.RepArray for representing an array of any kind of representations (not only irreps as was previously possible with cuex.IrrepsArray).
Fixed#
Added support for empty batch dimension in cuet (cuequivariance_torch).
Moved README.md and LICENSE into the source distribution.
Fixed cue.SegmentedTensorProduct.flop_cost for the special case of 1 operand.
Improved#
Removed special case handling for degree 0 in cuet.SymmetricTensorProduct.
0.1.0 (2024-11-18)#
Beta version of cuEquivariance released.