Release Notes#
Latest Changes#
0.5.1 (2025-06-18)#
This release includes improvements to triangle multiplicative update with torch.compile support and enhanced tuning configuration options.
Added#
[Torch] torch.compile support for cuet.triangle_multiplicative_update
[Torch] Optional precision argument for cuet.triangle_multiplicative_update:
precision (Precision, optional): Precision mode for matrix multiplications. If None, uses TF32 if enabled in PyTorch using torch.backends.cuda.matmul.allow_tf32, otherwise uses DEFAULT precision.
Available options:
DEFAULT: Use default precision setting of triton.language.dot
TF32: Use TensorFloat-32 precision
TF32x3: Use TensorFloat-32 precision with 3x accumulation
IEEE: Use IEEE 754 precision
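The fallback rule for a precision of None can be sketched in plain Python. This is only an illustration of the rule as described in these notes, not the library's code: the Precision enum below is a hypothetical stand-in (its member values are invented), and allow_tf32 stands in for the value of torch.backends.cuda.matmul.allow_tf32.

```python
from enum import Enum
from typing import Optional


class Precision(Enum):
    """Hypothetical stand-in; only the member names come from the notes."""
    DEFAULT = "default"
    TF32 = "tf32"
    TF32x3 = "tf32x3"
    IEEE = "ieee"


def resolve_precision(precision: Optional[Precision], allow_tf32: bool) -> Precision:
    """Return the effective precision under the documented fallback rule.

    allow_tf32 plays the role of torch.backends.cuda.matmul.allow_tf32.
    """
    if precision is not None:
        return precision  # an explicit choice is used as given
    # None: TF32 when PyTorch allows it, otherwise DEFAULT
    return Precision.TF32 if allow_tf32 else Precision.DEFAULT


print(resolve_precision(None, allow_tf32=True))
print(resolve_precision(None, allow_tf32=False))
print(resolve_precision(Precision.IEEE, allow_tf32=True))
```

An explicit argument is never overridden by the PyTorch flag in this sketch; the flag only matters when the argument is left as None.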
Improved#
[Torch] Enhanced tuning configuration for cuet.triangle_multiplicative_update with support for multi-process tuning. Our tuning modes:
Quick testing: Default configuration. Existing tuning configs are looked up if present; otherwise default kernel parameters are used. No tuning is performed.
On-Demand tuning: Set CUEQ_TRITON_TUNING_MODE = "ONDEMAND" to auto-tune for new shapes encountered on first run (may take several minutes)
AOT tuning: Set CUEQ_TRITON_TUNING_MODE = "AOT" to perform full ahead-of-time tuning for optimal performance (may take several hours)
Ignore cache: Set CUEQ_TRITON_IGNORE_EXISTING_CACHE to ignore both the default settings that come with the package and any user-local settings previously saved with AOT/ONDEMAND tuning. May be used to regenerate optimal settings for a particular setup.
Cache directory: Set CUEQ_TRITON_CACHE_DIR to specify where tuning configurations are stored. Default location is ${HOME}/.cache/cuequivariance-triton. Note: When running in containers where $HOME is inside the container (typically /root), tuning changes may be lost on container restart unless the container is committed or a persistent cache directory is specified.
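As a minimal sketch, the tuning variables above can be collected in one place and applied from Python. The helper tuning_env and the example cache path are hypothetical. The value "1" for CUEQ_TRITON_IGNORE_EXISTING_CACHE is an assumption, since the notes only say to set the variable. The notes also do not say when the variables are read, so applying them before the library is imported is the conservative choice.

```python
import os
from typing import Dict, Optional


def tuning_env(mode: Optional[str] = None,
               cache_dir: Optional[str] = None,
               ignore_existing_cache: bool = False) -> Dict[str, str]:
    """Build the environment variables for a tuning setup (hypothetical helper)."""
    if mode not in (None, "ONDEMAND", "AOT"):
        raise ValueError(f"unknown tuning mode: {mode!r}")
    env: Dict[str, str] = {}
    if mode is not None:
        # Leaving the variable unset keeps the default quick-testing behavior
        env["CUEQ_TRITON_TUNING_MODE"] = mode
    if cache_dir is not None:
        env["CUEQ_TRITON_CACHE_DIR"] = cache_dir
    if ignore_existing_cache:
        # Assumed value: the notes do not state what the variable must contain
        env["CUEQ_TRITON_IGNORE_EXISTING_CACHE"] = "1"
    return env


# Example: on-demand tuning with a cache directory on a persistent volume
env = tuning_env(mode="ONDEMAND", cache_dir="/workspace/cueq-cache")
os.environ.update(env)  # do this before importing cuequivariance_torch
print(env)
```

Pointing CUEQ_TRITON_CACHE_DIR at a mounted volume is one way to address the container note above, since the tuned configurations then survive a container restart.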
Fixed#
[Torch] Fixed torch.compile compatibility issues with triangle multiplicative update
[Torch] Tuning issues for cuet.triangle_multiplicative_update with multiple processes.
Limitations#
PyTorch does not currently bundle the latest Triton version as pytorch-triton. As a result, Blackwell GPU users may occasionally experience hangs or instability during model execution. Users may attempt installation with the latest Triton from source at their own risk. We are monitoring this issue and will remedy as soon as possible.
[Torch] Tuning for cuet.triangle_multiplicative_update is always performed for GPU-0 and may not be the optimal setting for all GPUs in a heterogeneous multi-GPU setting.
0.5.0 (2025-06-10)#
This release introduces triangle_attention and triangle_multiplicative_update.
This is the last release with cuda11 support. In the next release we will drop cuda11.
Added#
[Torch] Add cuet.triangle_attention
[Torch] Add cuet.triangle_multiplicative_update
[JAX] Add cuex.experimental.indexed_linear. Note that this function does not work with cuda11 because it requires cuBLAS 12.5.
[Torch/JAX] Add argument simplify_irreps3: bool = False to cue.descriptors.channelwise_tensor_product
[Torch/JAX] Add method permute_inputs to SegmentedPolynomial
Improved#
[Torch/JAX] In some settings, accelerate the CUDA kernel for uniform 1d segmented polynomials (like symmetric contraction and channelwise tensor product). While most operation speeds are unchanged, we observe up to 2x speedup in some cases.
Limitations#
PyTorch does not currently bundle the latest Triton version as pytorch-triton. As a result, Blackwell GPU users may occasionally experience hangs or instability during model execution. Users may attempt installation with the latest Triton from source at their own risk. We are monitoring this issue and will remedy as soon as possible.
Documentation#
cuet.triangle_multiplicative_update: Auto-tuning behavior can be controlled through environment variables:
Default: Full Ahead-of-Time (AOT) auto-tuning enabled for optimal performance (may take several hours)
Quick testing: Set CUEQ_DISABLE_AOT_TUNING = 1 and CUEQ_DEFAULT_CONFIG = 1 to disable all tuning
On-Demand tuning: Set CUEQ_DISABLE_AOT_TUNING = 1 to auto-tune for new shapes encountered on first run (may take several minutes)
Note: When using Docker with default or on-demand tuning enabled, commit the container to persist tuning changes
Note: When running in a multi-GPU setup, we recommend setting CUEQ_DISABLE_AOT_TUNING = 1 and CUEQ_DEFAULT_CONFIG = 1.
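For the 0.5.0 variables, the three documented modes can be written as a lookup table and turned into an environment for a child process. This is a sketch: the mode labels ("default", "quick", "ondemand") and the helper env_for_mode are hypothetical names, and only the variable names and values come from the notes above.

```python
from typing import Dict, Mapping

# Variables controlling auto-tuning in 0.5.0, as listed in the notes
TUNING_VARS = ("CUEQ_DISABLE_AOT_TUNING", "CUEQ_DEFAULT_CONFIG")

# Hypothetical mode labels mapped to the documented settings
TUNING_MODES_0_5_0: Dict[str, Dict[str, str]] = {
    "default": {},  # full AOT tuning, nothing to set
    "quick": {"CUEQ_DISABLE_AOT_TUNING": "1", "CUEQ_DEFAULT_CONFIG": "1"},
    "ondemand": {"CUEQ_DISABLE_AOT_TUNING": "1"},
}


def env_for_mode(mode: str, base: Mapping[str, str]) -> Dict[str, str]:
    """Return a copy of base with the tuning variables set for the given mode."""
    env = dict(base)
    for name in TUNING_VARS:
        env.pop(name, None)  # clear stale settings from a previous mode
    env.update(TUNING_MODES_0_5_0[mode])
    return env


print(env_for_mode("quick", {"PATH": "/usr/bin"}))
```

The returned dictionary can be passed as the env argument of subprocess.run, for example with the "quick" mode when launching one worker per GPU, in line with the multi-GPU recommendation above.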
0.4.0 (2025-04-25)#
This release introduces some changes to the API. It introduces the class cue.SegmentedPolynomial (and corresponding counterparts), which generalizes the notion of segmented tensor product by allowing the construction of non-homogeneous polynomials.
Added#
[Torch] cuet.SegmentedPolynomial module giving access to the indexing features of the uniform 1d kernel
[Torch/JAX] Add full support for float16 and bfloat16
[Torch/JAX] Class cue.SegmentedOperand
[Torch/JAX] Class cue.SegmentedPolynomial
[Torch/JAX] Class cue.EquivariantPolynomial that contains a cue.SegmentedPolynomial and the cue.Rep of its inputs and outputs
[Torch/JAX] Add caching for cue.descriptors.symmetric_contraction
[Torch/JAX] Add caching for cue.SegmentedTensorProduct.symmetrize_operands
[JAX] ARM config support
[JAX] cuex.segmented_polynomial and cuex.equivariant_polynomial
[JAX] Advanced batching capabilities: each input/output of a segmented polynomial can have multiple axes, and any of those can be indexed.
[JAX] Implementation of the Dead Code Elimination rule for the primitive cuex.segmented_polynomial
Breaking Changes#
[Torch/JAX] Rename SegmentedTensorProduct.flop_cost to flop
[Torch/JAX] Rename SegmentedTensorProduct.memory_cost to memory
[Torch/JAX] Removed IrrepsArray in favor of RepArray
[Torch/JAX] Change folder structure of cuequivariance and cuequivariance-jax. Now the main subfolders are segmented_polynomials and group_theory
[Torch/JAX] Deprecate cue.EquivariantTensorProduct in favor of cue.EquivariantPolynomial. The latter will have a limited list of features compared to cue.EquivariantTensorProduct. It does not contain change_layout and the methods to move the operands. Please open an issue if you need any of the missing methods.
[Torch/JAX] The descriptors return cue.EquivariantPolynomial instead of cue.EquivariantTensorProduct
[Torch/JAX] Change cue.SegmentedPolynomial.canonicalize_subscripts behavior for coefficient subscripts. It transposes the coefficients to be ordered the same way as the rest of the subscripts.
[Torch] To reduce the size of the .so library, we removed support of math dtype fp32 when using IO dtype fp64 in the case of the fully connected tensor product. (It concerns cuet.FullyConnectedTensorProduct and cuet.FullyConnectedTensorProductConv). Please open an issue if you need this feature.
Fixed#
[Torch/JAX] cue.SegmentedTensorProduct.sort_indices_for_identical_operands was silently operating on STPs with non-scalar coefficients. It now raises an error stating that this case is not implemented. We should implement it at some point.
0.3.0 (2025-03-05)#
The main changes are:
[JAX] New JIT Uniform 1d kernel with JAX bindings
Computes any polynomial based on 1d uniform STPs
Supports arbitrary derivatives
Provides optional fused scatter/gather for the inputs and outputs
🎉 We observed a ~3x speedup for MACE with cuEquivariance-JAX v0.3.0 compared to cuEquivariance-Torch v0.2.0 🎉
[Torch] Adds torch.compile support
[Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel
Enable the new kernel by setting the environment variable CUEQUIVARIANCE_OPS_USE_JIT=1
[Torch] Implements scatter/gather fusion through a beta API for Uniform 1d
This is a temporary API that will change: cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed
Breaking Changes#
[Torch/JAX] Removed cue.TensorProductExecution and added cue.Operation which is more lightweight and better aligned with the backend.
[JAX] In cuex.equivariant_tensor_product, the arguments dtype_math and dtype_output are renamed to math_dtype and output_dtype respectively. This change adds consistency with the rest of the library.
[JAX] In cuex.equivariant_tensor_product, the arguments algorithm, precision, use_custom_primitive and use_custom_kernels have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argument impl: str has been added instead to select the implementation.
[JAX] Removed cuex.symmetric_tensor_product. The cuex.tensor_product function now handles any non-homogeneous polynomials.
[JAX] The batching support (jax.vmap) of cuex.equivariant_tensor_product is now limited to specific use cases.
[JAX] The interface of cuex.tensor_product has changed. It now takes a list of tuple[cue.Operation, cue.SegmentedTensorProduct] instead of a single cue.SegmentedTensorProduct. This change allows cuex.tensor_product to execute any type of non-homogeneous polynomials.
[JAX] Removed cuex.flax_linen.Linear to reduce maintenance burden. Use cue.descriptors.linear together with cuex.equivariant_tensor_product instead.
e = cue.descriptors.linear(input.irreps, output_irreps)
w = self.param(name, jax.random.normal, (e.inputs[0].dim,), input.dtype)
output = cuex.equivariant_tensor_product(e, w, input)
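The keyword changes to cuex.equivariant_tensor_product listed in these breaking changes can be applied mechanically to old call sites. The sketch below works on a plain dictionary of keyword arguments; migrate_kwargs is a hypothetical helper, not part of the library, and it does not try to choose a value for the new impl argument.

```python
from typing import Any, Dict

# Renames and removals taken from the breaking changes listed above
RENAMED = {"dtype_math": "math_dtype", "dtype_output": "output_dtype"}
REMOVED = ("algorithm", "precision", "use_custom_primitive", "use_custom_kernels")


def migrate_kwargs(old: Dict[str, Any]) -> Dict[str, Any]:
    """Translate pre-0.3.0 keyword arguments to their 0.3.0 names."""
    new: Dict[str, Any] = {}
    for key, value in old.items():
        if key in REMOVED:
            continue  # no direct replacement; select an implementation via impl
        new[RENAMED.get(key, key)] = value
    return new


old_call = {"dtype_math": "float32", "dtype_output": "float16", "use_custom_kernels": True}
print(migrate_kwargs(old_call))
# → {'math_dtype': 'float32', 'output_dtype': 'float16'}
```

Dropped arguments are discarded silently here; a stricter variant could raise instead, so that call sites relying on a specific algorithm are reviewed by hand.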
Fixed#
[Torch/JAX] Fixed cue.descriptors.full_tensor_product which was ignoring the irreps3_filter argument.
[Torch/JAX] Fixed a rare bug with np.bincount when using an old version of numpy. The input is now flattened to make it work with all versions.
[Torch] Identified a bug in the CUDA kernel and disabled the CUDA kernel for cuet.TransposeSegments and cuet.TransposeIrrepsLayout.
Added#
[Torch/JAX] Added __mul__ to cue.EquivariantTensorProduct to allow rescaling the equivariant tensor product.
[JAX] Added JAX Bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomials as long as the contraction pattern (subscripts) has only one mode. It handles batched/shared/indexed input/output. The indexed input/output is processed through atomic operations.
[JAX] Added an indices argument to cuex.equivariant_tensor_product and cuex.tensor_product to handle the scatter/gather fusion.
[Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel (enable the new kernel by setting the environment variable CUEQUIVARIANCE_OPS_USE_JIT=1)
[Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (this is a temporary API that will change, cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed)
0.2.0 (2025-01-24)#
Breaking Changes#
Minimal Python version is now 3.10 in all packages.
cuet.TensorProduct and cuet.EquivariantTensorProduct now require inputs to be of shape (batch_size, dim) or (1, dim). Inputs of shape (dim,) are no longer allowed.
cuex.IrrepsArray is now an alias for cuex.RepArray.
cuex.RepArray.irreps and cuex.RepArray.segments are no longer functions. They are now properties.
cuex.IrrepsArray.is_simple has been replaced by cuex.RepArray.is_irreps_array.
The function cuet.spherical_harmonics has been replaced by the Torch Module cuet.SphericalHarmonics. This change enables the use of torch.jit.script and torch.compile.
Added#
Added experimental support for torch.compile. Known issue: the export in C++ is not working.
Added cue.IrrepsAndLayout: A simple class that inherits from cue.Rep and contains a cue.Irreps and a cue.IrrepsLayout.
Added cuex.RepArray for representing an array of any kind of representations (not only irreps as was previously possible with cuex.IrrepsArray).
Fixed#
Added support for empty batch dimension in cuet (cuequivariance_torch).
Moved README.md and LICENSE into the source distribution.
Fixed cue.SegmentedTensorProduct.flop_cost for the special case of 1 operand.
Improved#
Removed special case handling for degree 0 in cuet.SymmetricTensorProduct.
0.1.0 (2024-11-18)#
Beta version of cuEquivariance released.