Rotation

class cuequivariance_torch.Rotation(
irreps: Irreps,
*,
layout: IrrepsLayout | None = None,
layout_in: IrrepsLayout | None = None,
layout_out: IrrepsLayout | None = None,
device: device | None = None,
math_dtype: dtype | None = None,
use_fallback: bool | None = None,
)

A rotation layer (a torch.nn.Module) that applies a 3D rotation to a tensor of SO3 or O3 irreducible representations.

Parameters:
  • irreps (Irreps) – The irreducible representations of the tensor to rotate.

  • layout (IrrepsLayout, optional) – The memory layout of the tensor. cue.ir_mul is preferred.
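As an illustration of what the layout controls, the following pure-Python sketch reorders one irreps block of multiplicity mul and irrep dimension dim between the two orderings. It assumes the usual cuequivariance meaning of the names: cue.mul_ir stores the block as a (mul, dim) array flattened row-major, and cue.ir_mul stores it as (dim, mul). The helpers mul_ir_to_ir_mul and ir_mul_to_mul_ir are hypothetical and are not part of the library.

```python
def mul_ir_to_ir_mul(block, mul, dim):
    """Reorder a flat block from (mul, dim) order to (dim, mul) order."""
    assert len(block) == mul * dim
    # In mul_ir order, element (u, i) sits at index u * dim + i.
    # In ir_mul order, the same element sits at index i * mul + u.
    return [block[u * dim + i] for i in range(dim) for u in range(mul)]


def ir_mul_to_mul_ir(block, mul, dim):
    """Inverse reordering: (dim, mul) order back to (mul, dim) order."""
    assert len(block) == mul * dim
    return [block[i * mul + u] for u in range(mul) for i in range(dim)]


# Two copies (mul=2) of a 3-dimensional irrep (l=1), labelled by copy and component.
x_mul_ir = ["a0", "a1", "a2", "b0", "b1", "b2"]
x_ir_mul = mul_ir_to_ir_mul(x_mul_ir, mul=2, dim=3)
print(x_ir_mul)  # ['a0', 'b0', 'a1', 'b1', 'a2', 'b2']
assert ir_mul_to_mul_ir(x_ir_mul, mul=2, dim=3) == x_mul_ir
```

In ir_mul order, the same component of every copy is stored contiguously; when mul is 1, both layouts coincide.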

Forward Pass

forward(
gamma: Tensor,
beta: Tensor,
alpha: Tensor,
x: Tensor,
) → Tensor

Forward pass of the rotation layer.

Parameters:
  • gamma (torch.Tensor) – The gamma angles: the first rotation, around the y-axis.

  • beta (torch.Tensor) – The beta angles: the second rotation, around the x-axis.

  • alpha (torch.Tensor) – The alpha angles: the third rotation, around the y-axis.

  • x (torch.Tensor) – The input tensor.
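Because gamma is applied first and alpha last, the overall rotation is R = R_y(alpha) R_x(beta) R_y(gamma). The pure-Python sketch below checks that composition on a plain 3D vector. It assumes standard right-handed rotation matrices in Cartesian (x, y, z) coordinates, which may differ from the basis and sign conventions the library uses internally for its irreps; the helpers rot_x, rot_y and euler_yxy are hypothetical.

```python
import math


def rot_x(t):
    """Right-handed rotation by angle t about the x-axis (assumed convention)."""
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def rot_y(t):
    """Right-handed rotation by angle t about the y-axis (assumed convention)."""
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def matvec(m, v):
    return [sum(m[i][k] * v[k] for k in range(3)) for i in range(3)]


def norm(u):
    return math.sqrt(sum(c * c for c in u))


def euler_yxy(gamma, beta, alpha):
    """R = R_y(alpha) @ R_x(beta) @ R_y(gamma): gamma acts first, alpha last."""
    return matmul(rot_y(alpha), matmul(rot_x(beta), rot_y(gamma)))


v = [1.0, 2.0, 3.0]
R = euler_yxy(0.3, 1.1, -0.7)
# Applying the composed matrix equals applying the three rotations in order.
step_by_step = matvec(rot_y(-0.7), matvec(rot_x(1.1), matvec(rot_y(0.3), v)))
assert all(math.isclose(p, q, abs_tol=1e-12) for p, q in zip(matvec(R, v), step_by_step))
# A rotation preserves the vector's length.
assert math.isclose(norm(matvec(R, v)), norm(v))
```

With beta = 0 the two y-axis rotations merge into a single rotation by gamma + alpha, which is why the middle rotation uses a different axis.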

Returns:

The rotated tensor.

Return type:

torch.Tensor
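Conceptually, a rotation acts on each irrep of x through that irrep's rotation matrix (its Wigner D-matrix), so an l=0 entry is left unchanged while an l=1 block rotates like a vector. The plain-Python sketch below illustrates this block-diagonal action for irreps 0 + 1 (one scalar followed by one vector per row), using only a rotation about the y-axis and assuming Cartesian coordinates for the l=1 block. The helper rotate_0_plus_1 is hypothetical and only mimics the expected behavior, not the library's kernel.

```python
import math


def rot_y(t):
    """Right-handed rotation by angle t about the y-axis (assumed convention)."""
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def rotate_0_plus_1(batch, angle):
    """Rotate each row [s, vx, vy, vz] of a "0 + 1" tensor about the y-axis.

    The scalar (l=0) entry is invariant; the vector (l=1) entries are
    multiplied by the 3x3 rotation matrix, so the rotation is block-diagonal.
    """
    d1 = rot_y(angle)
    out = []
    for row in batch:
        s, vec = row[0], row[1:4]
        rotated = [sum(d1[i][k] * vec[k] for k in range(3)) for i in range(3)]
        out.append([s] + rotated)
    return out


x = [[5.0, 1.0, 0.0, 0.0], [-2.0, 0.0, 1.0, 0.0]]
y = rotate_0_plus_1(x, math.pi / 2)
print([[round(c, 6) for c in row] for row in y])
# [[5.0, 0.0, 0.0, -1.0], [-2.0, 0.0, 1.0, 0.0]]
```

The scalars 5.0 and -2.0 pass through unchanged, the x-axis vector is carried to the negative z-axis, and the y-axis vector is fixed by a rotation about y.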