FullyConnectedTensorProductConv

class cuequivariance_torch.layers.FullyConnectedTensorProductConv(
in_irreps: Irreps,
sh_irreps: Irreps,
out_irreps: Irreps,
batch_norm: bool = True,
mlp_channels: Sequence[int] | None = None,
mlp_activation: Module | Sequence[Module] | None = GELU(approximate='none'),
layout: IrrepsLayout = None,
use_fallback: bool | None = None,
)

Message passing layer for tensor products in DiffDock-like architectures. The left operand of the tensor product is the node features; the right operand is the spherical harmonics of the edge vectors.

Mathematical formulation:

\[\sum_{b \in \mathcal{N}_a} \mathbf{h}_b \otimes_{\psi_{a b}} Y\left(\hat{r}_{a b}\right)\]

where the path weights \(\psi_{a b}\) can be constructed from edge embeddings and scalar features using an MLP:

\[\psi_{a b} = \operatorname{MLP} \left(e_{a b}, \mathbf{h}_a^0, \mathbf{h}_b^0\right)\]

Users can either pass the tensor product weights directly or provide the MLP parameters together with scalar features from edges and nodes.

Parameters:
  • in_irreps (Irreps) – Irreps for the input node features.

  • sh_irreps (Irreps) – Irreps for the spherical harmonic representations of edge vectors.

  • out_irreps (Irreps) – Irreps for the output.

  • batch_norm (bool, optional) – If True, batch normalization is applied. Defaults to True.

  • mlp_channels (Sequence of int, optional) – A sequence of integers defining the number of neurons in each MLP layer before the output layer. If None, no MLP is added. The input layer contains the edge embeddings and node scalar features. Defaults to None.

  • mlp_activation (nn.Module or Sequence of nn.Module, optional) – A sequence of functions applied between the linear layers of the MLP, e.g., nn.Sequential(nn.ReLU(), nn.Dropout(0.4)). Defaults to nn.GELU().

  • layout (IrrepsLayout, optional) – The layout of the input and output irreps. Defaults to cue.mul_ir, the layout used by e3nn.

  • use_fallback (bool, optional) – If None (default), a CUDA kernel is used when available. If False, a CUDA kernel is required and an exception is raised if none is available. If True, the PyTorch fallback is used regardless of CUDA kernel availability.

Examples

>>> in_irreps = cue.Irreps("O3", "4x0e + 4x1o")
>>> sh_irreps = cue.Irreps("O3", "0e + 1o")
>>> out_irreps = cue.Irreps("O3", "4x0e + 4x1o")

Case 1: MLP whose input layer has 6 channels and whose two hidden layers have 16 channels each. edge_emb.size(1) must match the input layer size, 6:

>>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
...     mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), layout=cue.ir_mul)
>>> conv1
FullyConnectedTensorProductConv(...)
>>> # out = conv1(src_features, edge_sh, edge_emb, graph)
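
A minimal end-to-end sketch of Case 1 with synthetic tensors (the graph sizes, random features, and the random stand-in for the spherical harmonics are illustrative only; it assumes a CUDA kernel or the PyTorch fallback is available):

>>> import torch
>>> num_src_nodes, num_dst_nodes, num_edges = 10, 8, 20
>>> src_features = torch.randn(num_src_nodes, in_irreps.dim)
>>> edge_sh = torch.randn(num_edges, sh_irreps.dim)  # stand-in for spherical harmonics of edge vectors
>>> edge_emb = torch.randn(num_edges, 6)             # width must equal mlp_channels[0]
>>> edge_index = torch.vstack((
...     torch.randint(num_src_nodes, (num_edges,)),
...     torch.randint(num_dst_nodes, (num_edges,)),
... ))
>>> graph = (edge_index, (num_src_nodes, num_dst_nodes))
>>> out = conv1(src_features, edge_sh, edge_emb, graph)
>>> out.shape
torch.Size([8, 16])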

Case 2: If edge_emb is constructed by concatenating scalar features from edges, sources and destinations, as in DiffDock, the layer can accept each scalar component separately:

>>> # out = conv1(src_features, edge_sh, edge_emb, graph, src_scalars, dst_scalars)

This allows a smaller GEMM in the first MLP layer by performing GEMM on each component before indexing. The first-layer weights are split into sections for edges, sources and destinations, in that order. This is equivalent to

>>> # src, dst = graph[0]  # COO edge index
>>> # edge_emb = torch.hstack((edge_scalars, src_scalars[src], dst_scalars[dst]))
>>> # out = conv1(src_features, edge_sh, edge_emb, graph)
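
Reusing the tensors from the Case 1 sketch above, a hedged illustration of the separate-scalar call, splitting the 6 input channels into 2 edge, 2 source, and 2 destination scalars (the split widths are illustrative; only their sum must equal mlp_channels[0]):

>>> edge_scalars = torch.randn(num_edges, 2)
>>> src_scalars = torch.randn(num_src_nodes, 2)
>>> dst_scalars = torch.randn(num_dst_nodes, 2)
>>> out = conv1(src_features, edge_sh, edge_scalars, graph,
...             src_scalars=src_scalars, dst_scalars=dst_scalars)
>>> out.shape
torch.Size([8, 16])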

Case 3: No MLP; edge_emb is used directly as the tensor product weights:

>>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
...     mlp_channels=None, layout=cue.ir_mul)
>>> # out = conv3(src_features, edge_sh, edge_emb, graph)
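
To make Case 3 concrete, each edge needs a weight vector with tp.weight_numel entries. Assuming the tensor-product submodule is exposed as conv3.tp (as the forward documentation's tp.weight_numel suggests), a sketch reusing the tensors above:

>>> weight_dim = conv3.tp.weight_numel          # number of weights the tensor product expects per edge
>>> edge_weights = torch.randn(num_edges, weight_dim)
>>> out = conv3(src_features, edge_sh, edge_weights, graph)
>>> out.shape
torch.Size([8, 16])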

Forward Pass

forward(
src_features: Tensor,
edge_sh: Tensor,
edge_emb: Tensor,
graph: tuple[Tensor, tuple[int, int]],
src_scalars: Tensor | None = None,
dst_scalars: Tensor | None = None,
reduce: str = 'mean',
edge_envelope: Tensor | None = None,
) -> Tensor

Forward pass.

Parameters:
  • src_features (torch.Tensor) – Source node features. Shape: (num_src_nodes, in_irreps.dim)

  • edge_sh (torch.Tensor) – The spherical harmonic representations of the edge vectors. Shape: (num_edges, sh_irreps.dim)

  • edge_emb (torch.Tensor) –

    Edge embeddings that are fed into the MLP to generate the tensor product weights. Shape: (num_edges, dim), where dim should be:

    • tp.weight_numel when the layer does not contain an MLP.

    • mlp_channels[0] when edge_emb already contains all scalar features (Case 1 above, or the equivalent form of Case 2).

    • num_edge_scalars when scalar features from edges, sources and destinations are passed in separately (Case 2 above); the three widths must sum to mlp_channels[0].

  • graph (tuple) – A tuple storing the graph information: the first element is the adjacency matrix in COO format (edge index of shape (2, num_edges)), and the second element is its shape (num_src_nodes, num_dst_nodes).

  • src_scalars (torch.Tensor, optional) – Scalar features of source nodes. See examples for usage. Shape: (num_src_nodes, num_src_scalars)

  • dst_scalars (torch.Tensor, optional) – Scalar features of destination nodes. See examples for usage. Shape: (num_dst_nodes, num_dst_scalars)

  • reduce (str, optional) – Reduction operator. Choose between “mean” and “sum”. Defaults to “mean”.

  • edge_envelope (torch.Tensor, optional) – Typically used as attenuation factors to fade out messages coming from nodes close to the cutoff distance used to build the graph. This keeps the model smooth with respect to changes in node coordinates. Shape: (num_edges,)

Returns:

Output node features. Shape: (num_dst_nodes, out_irreps.dim)

Return type:

torch.Tensor
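
For illustration, a hedged sketch combining sum-reduction with a smooth edge envelope, reusing the tensors from the examples above (the cosine cutoff below is an illustrative choice, not an API provided by this layer):

>>> edge_lengths = torch.rand(num_edges) * 5.0                            # illustrative edge distances
>>> cutoff = 5.0
>>> envelope = 0.5 * (torch.cos(torch.pi * edge_lengths / cutoff) + 1.0)  # smoothly decays to 0 at the cutoff
>>> out = conv1(src_features, edge_sh, edge_emb, graph,
...             reduce="sum", edge_envelope=envelope)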