Linear

class cuequivariance_torch.Linear

An equivariant linear layer that maps features transforming under irreps_in to features transforming under irreps_out.
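
To make "equivariant linear layer" concrete, here is a minimal pure-Python sketch of what such a layer computes on one flattened sample, assuming the mul_ir layout. Irreps are modelled as hypothetical (mul, ir, dim) tuples. Copies of the same irrep are mixed by a learned mul_in-by-mul_out matrix, different irreps never mix, and the components inside each irrep are left untouched, which is why the map commutes with rotations. The flat weight ordering and the absence of normalization factors are assumptions of this sketch, not a description of the library's kernels.

```python
# Sketch of an equivariant linear map (assumptions: mul_ir layout, no normalization).
# Irreps are hypothetical (mul, ir, dim) tuples, e.g. (2, "1o", 3) for 2x1o.


def weight_numel(irreps_in, irreps_out):
    """Count one weight per (input copy, output copy) pair of the same irrep."""
    return sum(
        mi * mo
        for mi, ir_i, _ in irreps_in
        for mo, ir_o, _ in irreps_out
        if ir_i == ir_o
    )


def equivariant_linear(x, weights, irreps_in, irreps_out):
    """Apply the map to one flattened sample x stored in mul_ir layout."""
    # Start offset of each input segment inside x.
    in_offsets, pos = [], 0
    for mul, _, dim in irreps_in:
        in_offsets.append(pos)
        pos += mul * dim
    assert pos == len(x), "x does not match irreps_in"

    out, w_pos = [], 0
    for mo, ir_o, dim in irreps_out:
        block = [[0.0] * dim for _ in range(mo)]  # mo copies of a dim-vector
        for (mi, ir_i, _), off in zip(irreps_in, in_offsets):
            if ir_i != ir_o:
                continue  # different irreps are never mixed
            for u in range(mi):
                for v in range(mo):
                    w = weights[w_pos + u * mo + v]  # assumed row-major (u, v) order
                    for c in range(dim):
                        # copy u feeds copy v; component c stays component c
                        block[v][c] += w * x[off + u * dim + c]
            w_pos += mi * mo
        for row in block:
            out.extend(row)
    assert w_pos == len(weights), "weights do not match weight_numel"
    return out


irreps_in = [(2, "0e", 1), (1, "1o", 3)]   # 2x0e + 1x1o
irreps_out = [(1, "0e", 1), (2, "1o", 3)]  # 1x0e + 2x1o
print(weight_numel(irreps_in, irreps_out))  # 4  (2*1 + 1*2)
print(equivariant_linear([1.0, 2.0, 0.5, -1.0, 3.0],
                         [10.0, 100.0, 1.0, -2.0], irreps_in, irreps_out))
# [210.0, 0.5, -1.0, 3.0, -1.0, 2.0, -6.0]
```

Rotating the 1o input vector rotates both 1o output vectors in the same way while the 0e output is unchanged, which is exactly the equivariance property.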

Parameters:
  • irreps_in (Irreps) – The input irreducible representations.

  • irreps_out (Irreps) – The output irreducible representations.

  • layout (IrrepsLayout, optional) – The layout of the irreducible representations, by default cue.mul_ir. This is the layout used in the e3nn library.

  • layout_in (IrrepsLayout, optional) – The layout of the input irreducible representations, by default layout.

  • layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations, by default layout.

  • shared_weights (bool, optional) – Whether the same weights are shared by all batch elements, by default True.

  • internal_weights (bool, optional) – Whether the weights are stored inside the module as trainable parameters (otherwise they must be passed to forward), by default True if shared_weights is True, otherwise False.

  • weight_classes (int, optional) – Number of weight sets, by default 1. The weight tensor has this as its leading (batch) dimension, which is the expected setup when weights are passed externally. If greater than 1, each batch element in forward uses the slice of the weight tensor selected by weight_indices.

  • device (torch.device, optional) – The device to use for the linear layer.

  • dtype (torch.dtype, optional) – The dtype to use for the linear layer weights, by default torch.float32.

  • math_dtype (torch.dtype, optional) – The dtype used for the math operations. If not given, it follows the dtype of the input tensors when possible, otherwise the torch default dtype.

  • method (str, optional) – The method to use for the linear layer, by default “naive” (using a PyTorch implementation).

  • use_fallback (bool, optional, deprecated) – Deprecated; use method instead. Whether to use a “fallback” implementation: True or None (default) selects the “naive” method, and False selects the “fused_tp” method.
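
The difference between the two layouts can be illustrated with a small sketch. It assumes that cue.mul_ir stores each irreps segment of mul copies of a dim-dimensional irrep as a (mul, dim) block flattened row-major, and that cue.ir_mul stores it as a (dim, mul) block, so converting between them is a per-segment transpose. The (mul, ir, dim) tuples and helper names are hypothetical and are not part of the library.

```python
# Sketch: per-segment reordering between the mul_ir and ir_mul layouts.
# Assumption: mul_ir = (mul, dim) block per segment, ir_mul = (dim, mul) block.


def mul_ir_to_ir_mul(segment, mul, dim):
    """Reorder one segment from (mul, dim) to (dim, mul) ordering."""
    assert len(segment) == mul * dim
    return [segment[u * dim + c] for c in range(dim) for u in range(mul)]


def ir_mul_to_mul_ir(segment, mul, dim):
    """Reorder one segment from (dim, mul) back to (mul, dim) ordering."""
    assert len(segment) == mul * dim
    return [segment[c * mul + u] for u in range(mul) for c in range(dim)]


def convert(x, irreps, to_ir_mul=True):
    """Convert a flattened vector segment by segment; irreps holds (mul, ir, dim)."""
    step = mul_ir_to_ir_mul if to_ir_mul else ir_mul_to_mul_ir
    out, pos = [], 0
    for mul, _, dim in irreps:
        out.extend(step(x[pos:pos + mul * dim], mul, dim))
        pos += mul * dim
    return out


irreps = [(1, "0e", 1), (2, "1o", 3)]  # 1x0e + 2x1o
x = ["s", "a.x", "a.y", "a.z", "b.x", "b.y", "b.z"]  # scalar s, vectors a and b
print(convert(x, irreps))
# ['s', 'a.x', 'b.x', 'a.y', 'b.y', 'a.z', 'b.z']
```

Under these assumptions, mul_ir keeps each copy's components contiguous (as in e3nn), while ir_mul keeps the same component of all copies contiguous.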

Forward Pass

forward(x, weight=None, weight_indices=None)

Forward pass of the linear layer.

Parameters:
  • x (torch.Tensor) – The input tensor.

  • weight (torch.Tensor, optional) – The weight tensor. If None, the internal weight tensor is used. Otherwise, if weights are not shared it must have shape (batch_size, weight_numel); if weights are shared it must have shape (weight_classes, weight_numel), where weight_classes defaults to 1.

  • weight_indices (torch.Tensor, optional) – The weight-slice indices, used when weight_classes > 1: an integer tensor of shape (batch_size,) giving, for each batch element, the index of the weight slice to use.
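
The role of weight_classes and weight_indices can be sketched in pure Python: a (weight_classes, weight_numel) table becomes one weight row per batch element by indexing it with weight_indices. The helpers below are hypothetical, the ValueError raised when indices are missing is an assumption of this sketch (the library's behavior may differ), and the example layer maps 2x0e to 1x0e, so weight_numel is 2 and each output is a weighted sum of the two input scalars.

```python
# Sketch of how weight_indices selects a weight slice per batch element.
# Hypothetical helpers on plain lists; the real layer works on torch tensors.


def select_weights(weight, weight_indices, batch_size):
    """Expand a (weight_classes, weight_numel) table to one row per batch element."""
    if weight_indices is None:
        # Assumption of this sketch: no indices means a single shared weight row.
        if len(weight) != 1:
            raise ValueError("weight_indices is needed when weight_classes > 1")
        return [weight[0]] * batch_size
    if len(weight_indices) != batch_size:
        raise ValueError("weight_indices must have shape (batch_size,)")
    return [weight[i] for i in weight_indices]


def scalar_linear(x, weight, weight_indices=None):
    """2x0e -> 1x0e: each output is a weighted sum of the two input scalars."""
    rows = select_weights(weight, weight_indices, len(x))
    return [sum(w * xi for w, xi in zip(row, sample)) for row, sample in zip(rows, x)]


weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # weight_classes=3, weight_numel=2
x = [[2.0, 5.0], [2.0, 5.0], [2.0, 5.0], [3.0, 4.0]]  # batch_size=4
print(scalar_linear(x, weight, weight_indices=[0, 1, 2, 2]))  # [2.0, 5.0, 7.0, 7.0]
```

Indexing a small table like this keeps the parameter count at weight_classes × weight_numel instead of batch_size × weight_numel, which is the point of weight classes when several batch elements share a weight set.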

Returns:

The output tensor after applying the linear transformation.

Return type:

torch.Tensor

Raises:

ValueError – If internal weights are used and weight is not None, or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor.

__init__(irreps_in, irreps_out, *, layout=None, layout_in=None, layout_out=None, shared_weights=True, internal_weights=None, weight_classes=1, device=None, dtype=None, math_dtype=None, use_fallback=None, method=None)

Initialize the linear layer. The arguments are the ones described in the class-level Parameters list above.
