Release Notes#
Latest Changes#
0.4.0 (2025-04-25)#
This release introduces some changes to the API, it introduce the class cue.SegmentedPolynomial
(and corresponding counterparts) which generalizes the notion of segmented tensor product by allowing to construct non-homogeneous polynomials.
Added#
[Torch]
cuet.SegmentedPolynomial
module giving access to the indexing features of the uniform 1d kernel[Torch/JAX] Add full support for float16 and bfloat16
[Torch/JAX] Class
cue.SegmentedOperand
[Torch/JAX] Class
cue.SegmentedPolynomial
[Torch/JAX] Class
cue.EquivariantPolynomial
that contains acue.SegmentedPolynomial
and thecue.Rep
of its inputs and outputs[Torch/JAX] Add caching for
cue.descriptor.symmetric_contraction
[Torch/JAX] Add caching for
cue.SegmentedTensorProduct.symmetrize_operands
[JAX] ARM config support
[JAX]
cuex.segmented_polynomial
andcuex.equivariant_polynomial
[JAX] Advanced Batching capabilities, each input/output of a segmented polynomial can have multiple axes and any of those can be indexed.
[JAX] Implementation of the Dead Code Elimination rule for the primitive
cuex.segmented_polynomial
Breaking Changes#
[Torch/JAX] Rename
SegmentedTensorProduct.flop_cost
toflop
[Torch/JAX] Rename
SegmentedTensorProduct.memory_cost
tomemory
[Torch/JAX] Removed
IrrepsArray
in favor ofRepArray
[Torch/JAX] Change folder structure of cuequivariance and cuequivariance-jax. Now the main subfolders are
segmented_polynomials
andgroup_theory
[Torch/JAX] Deprecate
cue.EquivariantTensorProduct
in favor ofcue.EquivariantPolynomial
. The later will have a limited list of features compared tocue.EquivariantTensorProduct
. It does not containchange_layout
and the methods to move the operands. Please open an issue if you need any of the missing methods.[Torch/JAX] The descriptors return
cue.EquivariantPolynomial
instead ofcue.EquivariantTensorProduct
[Torch/JAX] Change
cue.SegmentedPolynomial.canonicalize_subscripts
behavior for coefficient subscripts. It transposes the coefficients to be ordered the same way as the rest of the subscripts.[Torch] To reduce the size of the so library, we removed support of math dtype fp32 when using IO dtype fp64 in the case of the fully connected tensor product. (It concerns
cuet.FullyConnectedTensorProduct
andcuet.FullyConnectedTensorProductConv
). Please open an issue if you need this feature.
Fixed#
[Torch/JAX]
cue.SegmentedTensorProduct.sort_indices_for_identical_operands
was silently operating on STP with non scalar coefficient, now it will raise an error to say that this case is not implemented. We should implement it at some point.
0.3.0 (2025-03-05)#
The main changes are:
[JAX] New JIT Uniform 1d kernel with JAX bindings
Computes any polynomial based on 1d uniform STPs
Supports arbitrary derivatives
Provides optional fused scatter/gather for the inputs and outputs
🎉 We observed a ~3x speedup for MACE with cuEquivariance-JAX v0.3.0 compared to cuEquivariance-Torch v0.2.0 🎉
[Torch] Adds torch.compile support
[Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel
enable the new kernel by setting the environement variable
CUEQUIVARIANCE_OPS_USE_JIT=1
[Torch] Implements scatter/gather fusion through a beta API for Uniform 1d
this is a temporary API that will change,
cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed
Breaking Changes#
[Torch/JAX] Removed
cue.TensorProductExecution
and addedcue.Operation
which is more lightweight and better aligned with the backend.[JAX] In
cuex.equivariant_tensor_product
, the argumentsdtype_math
anddtype_output
are renamed tomath_dtype
andoutput_dtype
respectively. This change adds consistency with the rest of the library.[JAX] In
cuex.equivariant_tensor_product
, the argumentsalgorithm
,precision
,use_custom_primitive
anduse_custom_kernels
have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argumentimpl: str
has been added instead to select the implementation.[JAX] Removed
cuex.symmetric_tensor_product
. Thecuex.tensor_product
function now handles any non-homogeneous polynomials.[JAX] The batching support (
jax.vmap
) ofcuex.equivariant_tensor_product
is now limited to specific use cases.[JAX] The interface of
cuex.tensor_product
has changed. It now takes a list oftuple[cue.Operation, cue.SegmentedTensorProduct]
instead of a singlecue.SegmentedTensorProduct
. This change allowscuex.tensor_product
to execute any type of non-homogeneous polynomials.[JAX] Removed
cuex.flax_linen.Linear
to reduce maintenance burden. Usecue.descriptor.linear
together withcuex.equivariant_tensor_product
instead.
e = cue.descriptors.linear(input.irreps, output_irreps)
w = self.param(name, jax.random.normal, (e.inputs[0].dim,), input.dtype)
output = cuex.equivariant_tensor_product(e, w, input)
Fixed#
[Torch/JAX] Fixed
cue.descriptor.full_tensor_product
which was ignoring theirreps3_filter
argument.[Torch/JAX] Fixed a rare bug with
np.bincount
when using an old version of numpy. The input is now flattened to make it work with all versions.[Torch] Identified a bug in the CUDA kernel and disabled CUDA kernel for
cuet.TransposeSegments
andcuet.TransposeIrrepsLayout
.
Added#
[Torch/JAX] Added
__mul__
tocue.EquivariantTensorProduct
to allow rescaling the equivariant tensor product.[JAX] Added JAX Bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomials as long as the contraction pattern (subscripts) has only one mode. It handles batched/shared/indexed input/output. The indexed input/output is processed through atomic operations.
[JAX] Added an
indices
argument tocuex.equivariant_tensor_product
andcuex.tensor_product
to handle the scatter/gather fusion.[Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel (enable the new kernel by setting the environement variable
CUEQUIVARIANCE_OPS_USE_JIT=1
)[Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (this is a temporary API that will change,
cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed
)
0.2.0 (2025-01-24)#
Breaking Changes#
Minimal Python version is now 3.10 in all packages.
cuet.TensorProduct
andcuet.EquivariantTensorProduct
now require inputs to be of shape(batch_size, dim)
or(1, dim)
. Inputs of dimension(dim,)
are no longer allowed.cuex.IrrepsArray
is now an alias forcuex.RepArray
.cuex.RepArray.irreps
andcuex.RepArray.segments
are no longer functions. They are now properties.cuex.IrrepsArray.is_simple
has been replaced bycuex.RepArray.is_irreps_array
.The function
cuet.spherical_harmonics
has been replaced by the Torch Modulecuet.SphericalHarmonics
. This change enables the use oftorch.jit.script
andtorch.compile
.
Added#
Added experimental support for
torch.compile
. Known issue: the export in C++ is not working.Added
cue.IrrepsAndLayout
: A simple class that inherits fromcue.Rep
and contains acue.Irreps
and acue.IrrepsLayout
.Added
cuex.RepArray
for representing an array of any kind of representations (not only irreps as was previously possible withcuex.IrrepsArray
).
Fixed#
Added support for empty batch dimension in
cuet
(cuequivariance_torch
).Moved
README.md
andLICENSE
into the source distribution.Fixed
cue.SegmentedTensorProduct.flop_cost
for the special case of 1 operand.
Improved#
Removed special case handling for degree 0 in
cuet.SymmetricTensorProduct
.
0.1.0 (2024-11-18)#
Beta version of cuEquivariance released.