SphericalHarmonics

class cuequivariance_torch.SphericalHarmonics(
ls: list[int],
normalize: bool = True,
device: device | None = None,
math_dtype: dtype | None = None,
use_fallback: bool | None = None,
)

A torch module (torch.nn.Module) that computes the spherical harmonics of the input vectors for each degree l in ls.
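To make concrete what is computed for each input vector, here is a minimal pure-Python sketch for a single 3-vector. It is not the cuequivariance_torch implementation. The function name `spherical_harmonics_sketch`, the real-harmonic basis with "component" normalization (each degree-l block has squared norm 2l+1 on the unit sphere), and the m = -l..l component order are all assumptions, and only l <= 2 is covered. The `normalize` flag is assumed here to mean rescaling the input to unit length first.

```python
import math


def spherical_harmonics_sketch(vector, ls, normalize=True):
    """Hypothetical reference: real spherical harmonics of one 3-vector.

    Assumes "component" normalization and m = -l..l ordering; the actual
    library may use a different convention. Supports only l in {0, 1, 2}.
    """
    x, y, z = vector
    if normalize:
        r = math.sqrt(x * x + y * y + z * z)
        if r == 0.0:
            raise ValueError("cannot normalize a zero vector")
        x, y, z = x / r, y / r, z / r
    r2 = x * x + y * y + z * z  # equals 1 (up to rounding) when normalize=True
    s3, s5, s15 = math.sqrt(3.0), math.sqrt(5.0), math.sqrt(15.0)
    out = []
    for l in ls:
        if l == 0:
            out.append(1.0)
        elif l == 1:
            # m = -1, 0, 1 -> y, z, x (assumed ordering)
            out.extend([s3 * y, s3 * z, s3 * x])
        elif l == 2:
            # homogeneous degree-2 polynomials, m = -2..2 (assumed ordering)
            out.extend([
                s15 * x * y,
                s15 * y * z,
                0.5 * s5 * (3.0 * z * z - r2),
                s15 * x * z,
                0.5 * s15 * (x * x - y * y),
            ])
        else:
            raise NotImplementedError(f"sketch only covers l <= 2, got l={l}")
    return out


# Output length is the sum of 2*l+1 over ls: 1 + 3 + 5 = 9.
print(len(spherical_harmonics_sketch((1.0, 2.0, 3.0), [0, 1, 2])))
```

Because each block is a homogeneous polynomial of degree l, calling this sketch with normalize=False on a vector of length r scales the degree-l block by r**l.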

Forward Pass

forward(vectors: Tensor) → Tensor
Parameters:

vectors (torch.Tensor) – Input vectors of shape (batch, 3).

Returns:

The spherical harmonics of the input vectors, as a tensor of shape (batch, dim), where dim is the sum of 2*l+1 over all l in ls.

Return type:

torch.Tensor
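The output width follows directly from ls. The pure-Python sketch below computes dim and, under the assumption (not stated above) that the degree-l blocks are stored contiguously in the order given by ls, the column range each degree occupies. The helper name `sh_layout` is hypothetical.

```python
def sh_layout(ls):
    """Return (dim, [(l, slice), ...]) for the output columns of each degree.

    Hypothetical helper: assumes the degree-l blocks are stored contiguously
    in the order given by ls.
    """
    layout = []
    offset = 0
    for l in ls:
        width = 2 * l + 1  # number of components of degree l
        layout.append((l, slice(offset, offset + width)))
        offset += width
    return offset, layout


dim, layout = sh_layout([0, 1, 2])
print(dim)     # 9
print(layout)  # [(0, slice(0, 1, None)), (1, slice(1, 4, None)), (2, slice(4, 9, None))]
```

If that layout assumption holds, `out[:, s]` on an output tensor `out` of shape (batch, dim) would select the degree-l block for the matching slice `s`.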