Rotation
- class cuequivariance_torch.Rotation
A class that represents a rotation layer for SO3 or O3 representations.
- Parameters:
irreps (Irreps) – The irreducible representations of the tensor to rotate.
layout (IrrepsLayout, optional) – The memory layout of the tensor, cue.ir_mul is preferred.
layout_in (IrrepsLayout, optional) – The layout of the input irreducible representations, by default layout.
layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations, by default layout.
device (torch.device, optional) – The device to use for the operation.
math_dtype (torch.dtype or string, optional) – The dtype to use for the math operations. By default it follows the dtype of the input tensors if possible, otherwise the torch default dtype (see SegmentedPolynomial for more details).
method (str, optional) – The method to use for the operation. Defaults to “uniform_1d” (a CUDA kernel) if all segments have the same shape, otherwise “naive” (a PyTorch implementation).
use_fallback (bool, optional, deprecated) – Deprecated in favor of method, to which it now maps. If True, the “naive” method is used. If False, the “uniform_1d” method is used (make sure all segments have the same shape).
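The selection rule described for method and use_fallback can be sketched in plain Python. This is only an illustration of the documented behavior, not the library's code: resolve_method and its segments_uniform argument are hypothetical names, and letting an explicit method take precedence over use_fallback is an assumption, since the text does not say what happens when both are given.

```python
from typing import Optional


def resolve_method(
    method: Optional[str],
    use_fallback: Optional[bool],
    segments_uniform: bool,
) -> str:
    """Hypothetical helper mirroring the documented defaults."""
    if method is not None:
        # Assumption: an explicit method wins over the deprecated flag.
        return method
    if use_fallback is True:
        # Deprecated flag set: use the PyTorch implementation.
        return "naive"
    if use_fallback is False:
        # Deprecated flag cleared: use the CUDA kernel.
        return "uniform_1d"
    # Neither argument given: choose based on the segment shapes.
    return "uniform_1d" if segments_uniform else "naive"


default_choice = resolve_method(None, None, segments_uniform=True)
mixed_shapes_choice = resolve_method(None, None, segments_uniform=False)
```

In new code, passing method directly avoids relying on the deprecated flag.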
Forward Pass
- forward(gamma, beta, alpha, x)
Forward pass of the rotation layer.
- Parameters:
gamma (torch.Tensor) – The gamma angles. First rotation around the y-axis.
beta (torch.Tensor) – The beta angles. Second rotation around the x-axis.
alpha (torch.Tensor) – The alpha angles. Third rotation around the y-axis.
x (torch.Tensor) – The input tensor.
- Returns:
The rotated tensor.
- Return type:
torch.Tensor
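To make the order of the three angles concrete, the following sketch composes the rotations for an ordinary 3D vector using only the standard library. It assumes right-handed rotation matrices acting on (x, y, z) column vectors. The basis ordering and sign conventions the library uses for its irreps may differ, so treat this as an illustration of "gamma first, beta second, alpha third" rather than a reproduction of forward. The helper names rot_y, rot_x and euler_yxy are hypothetical.

```python
import math


def rot_y(t):
    """Right-handed rotation by angle t around the y-axis."""
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def rot_x(t):
    """Right-handed rotation by angle t around the x-axis."""
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def matmul(a, b):
    """Product of two 3x3 matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def matvec(m, v):
    """Apply a 3x3 matrix to a length-3 vector."""
    return [sum(m[i][k] * v[k] for k in range(3)) for i in range(3)]


def euler_yxy(gamma, beta, alpha):
    """Compose the rotations; the one applied first sits rightmost."""
    return matmul(rot_y(alpha), matmul(rot_x(beta), rot_y(gamma)))


# Rotating the z unit vector by gamma = pi/2 alone (under this sketch's
# conventions) moves it onto the x-axis, up to floating-point rounding.
rotated = matvec(euler_yxy(math.pi / 2, 0.0, 0.0), [0.0, 0.0, 1.0])
```

With beta set to zero the two y-rotations simply add, so euler_yxy(g, 0, a) matches rot_y(g + a) in this sketch.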
- __init__(irreps, *, layout=None, layout_in=None, layout_out=None, device=None, math_dtype=None, use_fallback=None, method=None)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
irreps (Irreps)
layout (IrrepsLayout | None)
layout_in (IrrepsLayout | None)
layout_out (IrrepsLayout | None)
device (device | None)
use_fallback (bool | None)
method (str | None)
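The layout, layout_in and layout_out arguments choose how the multiplicity and irrep-component indices of each segment are ordered in memory. As a rough sketch, assuming cue.mul_ir keeps the multiplicity as the outer index and cue.ir_mul keeps the irrep component as the outer index, the two orderings of a single flattened segment are related as shown below. The helper names are hypothetical and operate on plain Python lists, not on library tensors. For a tensor with several segments, the reordering would apply to each segment separately.

```python
def mul_ir_to_ir_mul(values, mul, dim):
    """Reorder one flattened segment from (mul, dim) to (dim, mul) order."""
    assert len(values) == mul * dim
    return [values[u * dim + i] for i in range(dim) for u in range(mul)]


def ir_mul_to_mul_ir(values, mul, dim):
    """Inverse reordering: from (dim, mul) back to (mul, dim) order."""
    assert len(values) == mul * dim
    return [values[i * mul + u] for u in range(mul) for i in range(dim)]


# Two copies (a and b) of a 3-component irrep; the 0/1/2 suffixes are
# only labels for the three components.
segment_mul_ir = ["a0", "a1", "a2", "b0", "b1", "b2"]
segment_ir_mul = mul_ir_to_ir_mul(segment_mul_ir, mul=2, dim=3)
round_trip = ir_mul_to_mul_ir(segment_ir_mul, mul=2, dim=3)
```

Here segment_ir_mul interleaves the copies component by component, and round_trip recovers the original ordering.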
- classmethod __new__(*args, **kwargs)