EquivariantTensorProduct

class cuequivariance.EquivariantTensorProduct(
d: SegmentedTensorProduct | Sequence[SegmentedTensorProduct],
operands: list[Rep],
symmetrize: bool = True,
)

Descriptor of an equivariant tensor product (ETP). This class wraps a list of segmented tensor products (STPs). An STP is a single homogeneous polynomial that does not specify what role each operand plays; an ETP additionally fixes, for each operand, its role (input or output), its representation (irreps), and its layout (multiplicity first or irreducible representation first).

Requirements:
  • An ETP must contain at least one STP.

  • Each STP must have at least one operand (the output).
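A minimal sketch of these two requirements as a check, assuming a toy model in which each STP is reduced to its number of operands. `check_etp_requirements` is a hypothetical helper for illustration, not part of cuequivariance.

```python
from __future__ import annotations


def check_etp_requirements(stp_num_operands: list[int]) -> None:
    """Hypothetical helper: validate the ETP requirements on a toy model.

    Each entry of `stp_num_operands` is the number of operands of one STP.
    """
    # Requirement 1: an ETP contains at least one STP.
    if not stp_num_operands:
        raise ValueError("an ETP must contain at least one STP")
    # Requirement 2: every STP has at least its output operand.
    for idx, n in enumerate(stp_num_operands):
        if n < 1:
            raise ValueError(f"STP{idx} has no operands; it needs at least the output")


# The five STPs of the examples table have 4, 3, 2, 1 and 6 operands.
check_etp_requirements([4, 3, 2, 1, 6])
```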

Examples

        Input0   Input1   Input2   Output   Comment
STP0    x        x        x        x        common case: the STP has as many operands as the ETP
STP1    x        x                 x        some inputs are not used by all STPs
STP2    x                          x        same as above
STP3                               x        same as above
STP4    x        x        x x x    x        the last input is fed multiple times
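The table can be read as a rule that maps each STP operand to an ETP role. The sketch below encodes one reading of it, assuming that the STP's last operand is the output, that its other operands take the ETP inputs in order, and that any extra operands reuse the last input. `stp_operand_roles` is a hypothetical helper, not a library function.

```python
from __future__ import annotations


def stp_operand_roles(num_stp_operands: int, num_inputs: int) -> list[str]:
    """Hypothetical helper: label each operand of one STP with its ETP role.

    Assumed rule (one reading of the table): the last STP operand is the
    output, the others take Input0, Input1, ... in order, and once the
    inputs run out the last input is fed again.
    """
    if num_stp_operands < 1:
        raise ValueError("an STP needs at least its output operand")
    if num_stp_operands > 1 and num_inputs < 1:
        raise ValueError("this STP has inputs but the ETP declares none")
    roles = [f"Input{min(k, num_inputs - 1)}" for k in range(num_stp_operands - 1)]
    roles.append("Output")
    return roles


for n in (4, 3, 2, 1, 6):  # STP0 .. STP4 from the table
    print(n, stp_operand_roles(n, num_inputs=3))
```

Under this reading, STP4 (six operands) comes out as Input0, Input1, Input2, Input2, Input2, Output, matching the "x x x" entry.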

Methods

permute_operands(
permutation: tuple[int, ...],
) → EquivariantTensorProduct

Permute the operands of the tensor product.
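A sketch of what a permutation does to the operand order, shown on a toy list of operand labels rather than on a real ETP. The convention used here (position i of the result takes old operand permutation[i]) is an assumption of this sketch, and `permute_operands` below is a stand-in function, not the method itself.

```python
from __future__ import annotations


def permute_operands(operands: list, permutation: tuple[int, ...]) -> list:
    """Toy version on a plain list of operand labels.

    Assumed convention: position i of the result holds old operand permutation[i].
    """
    if sorted(permutation) != list(range(len(operands))):
        raise ValueError("permutation must reorder every operand exactly once")
    return [operands[i] for i in permutation]


print(permute_operands(["x", "y", "out"], (1, 0, 2)))  # ['y', 'x', 'out']
```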

move_operand(
src: int,
dst: int,
) → EquivariantTensorProduct

Move an operand to a new position.

move_operand_first(
src: int,
) → EquivariantTensorProduct

Move an operand to the front.

move_operand_last(
src: int,
) → EquivariantTensorProduct

Move an operand to the back.
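The three move methods can be pictured as list operations on the operand order. This toy sketch assumes that `dst` is the final position of the moved operand; the functions operate on plain lists of labels and only mirror the method names for readability.

```python
def move_operand(operands: list, src: int, dst: int) -> list:
    """Toy version: take the operand at `src` out and reinsert it at position `dst`."""
    ops = list(operands)  # copy, so the caller's list is left unchanged
    ops.insert(dst, ops.pop(src))
    return ops


def move_operand_first(operands: list, src: int) -> list:
    return move_operand(operands, src, 0)


def move_operand_last(operands: list, src: int) -> list:
    return move_operand(operands, src, len(operands) - 1)


print(move_operand(["a", "b", "c", "d"], 0, 2))  # ['b', 'c', 'a', 'd']
```

Each move is a special case of a permutation, so the same reordering could also be expressed with permute_operands.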

squeeze_modes(
modes: str | None = None,
) → EquivariantTensorProduct

Squeeze the modes.
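A toy sketch of squeezing on subscript strings, assuming the usual meaning of the word: a mode whose extent is 1 carries no information and can be dropped from every operand. The subscript format ("uv,iu,jv"), the single `extents` dictionary shared by all segments, and the helper itself are assumptions of this sketch; real STPs can have different extents per segment.

```python
from __future__ import annotations


def squeeze_modes(subscripts: str, extents: dict[str, int], modes: str | None = None):
    """Toy squeeze: drop mode letters whose extent is 1.

    `subscripts` lists each operand's modes, comma-separated (e.g. "uv,iu,jv");
    `extents` maps each mode letter to one size shared by all segments.
    If `modes` is given, only those letters may be dropped.
    """
    candidates = extents.keys() if modes is None else set(modes)
    dropped = {m for m in candidates if extents.get(m) == 1}
    new_subscripts = ",".join(
        "".join(c for c in operand if c not in dropped) for operand in subscripts.split(",")
    )
    new_extents = {m: e for m, e in extents.items() if m not in dropped}
    return new_subscripts, new_extents


print(squeeze_modes("uv,iu,jv", {"u": 1, "v": 8, "i": 3, "j": 3}))
# ('v,i,jv', {'v': 8, 'i': 3, 'j': 3})
```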

consolidate_paths() → EquivariantTensorProduct

Consolidate the paths.
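A sketch of path consolidation under a simplified model in which a path is a tuple of segment indices (one per operand) plus a scalar coefficient. In a real STP the coefficient can be a small tensor, and whether zero results are dropped is not modeled here; `consolidate_paths` below is a toy function, not the method.

```python
from collections import defaultdict


def consolidate_paths(paths):
    """Toy consolidation on (segment_indices, coefficient) pairs.

    Paths that touch exactly the same segments are merged by summing their
    coefficients: c1 * P + c2 * P == (c1 + c2) * P for the shared product P.
    """
    merged = defaultdict(int)
    for indices, coeff in paths:
        merged[tuple(indices)] += coeff
    return sorted(merged.items())


paths = [((0, 1, 0), 2), ((1, 0, 1), 5), ((0, 1, 0), -1)]
print(consolidate_paths(paths))  # [((0, 1, 0), 1), ((1, 0, 1), 5)]
```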

canonicalize_subscripts() → EquivariantTensorProduct

Canonicalize the subscripts.

flatten_modes(
modes: str,
*,
skip_zeros: bool = True,
force: bool = False,
) → EquivariantTensorProduct

Flatten modes.
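One way to picture flattening a mode: a segment whose first mode has extent u becomes u smaller segments. This toy sketch assumes a row-major layout with the flattened mode leading, which makes the split a pure relabeling of the same buffer; how flatten_modes handles non-leading modes, skip_zeros and force is not modeled.

```python
from __future__ import annotations


def flatten_leading_mode(data: list, shape: tuple[int, ...]):
    """Toy flatten: view one row-major segment of shape (u, *rest) as u segments of shape rest.

    The flat buffer is untouched; only the segment boundaries change, and each
    path that used the big segment would be replaced by u paths on the small ones.
    """
    u, rest = shape[0], shape[1:]
    size = 1
    for d in rest:
        size *= d
    if len(data) != u * size:
        raise ValueError("data does not match shape")
    segments = [data[k * size:(k + 1) * size] for k in range(u)]
    return segments, rest


print(flatten_leading_mode(list(range(6)), (2, 3)))
# ([[0, 1, 2], [3, 4, 5]], (3,))
```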

all_same_segment_shape() → bool

Whether all the segments have the same shape.

flatten_coefficient_modes() → EquivariantTensorProduct

Flatten the coefficient modes.

flop_cost(batch_size: int) → int

Compute the number of floating-point operations (flops) of the tensor product for the given batch size.

memory_cost(
batch_sizes: tuple[int, ...],
itemsize: int | tuple[int, ...],
) → int

Compute the number of memory accesses of the tensor product for the given per-operand batch sizes and item sizes.
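A back-of-the-envelope model of the quantity memory_cost estimates, assuming each operand is read or written exactly once per batch element, so the traffic in bytes is size × batch size × itemsize summed over operands. It also shows how an int itemsize can be broadcast to every operand. `toy_memory_cost` and its `operand_sizes` argument (elements per batch entry) are hypothetical; the library's own accounting may differ.

```python
from __future__ import annotations


def toy_memory_cost(
    operand_sizes: tuple[int, ...],
    batch_sizes: tuple[int, ...],
    itemsize: int | tuple[int, ...],
) -> int:
    """Bytes moved if every operand is touched once per batch element."""
    if isinstance(itemsize, int):
        itemsize = (itemsize,) * len(operand_sizes)  # same dtype width for all operands
    if not len(operand_sizes) == len(batch_sizes) == len(itemsize):
        raise ValueError("need one size, batch size and itemsize per operand")
    return sum(s * b * i for s, b, i in zip(operand_sizes, batch_sizes, itemsize))


# Two float32 inputs of 4 and 9 elements, a float16 output of 12 elements, batch 1000.
print(toy_memory_cost((4, 9, 12), (1000, 1000, 1000), (4, 4, 2)))  # 76000
```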

backward(
input: int,
) → tuple[EquivariantTensorProduct, tuple[int, ...]]

Compute the backward pass of the equivariant tensor product with respect to the input at index input.
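Why a backward pass can again be expressed as tensor products: an STP is linear in each operand it uses once, so the gradient with respect to such an input is the same coefficient tensor contracted with the remaining inputs and the output gradient. This pure-Python sketch checks that on one dense bilinear toy STP with integer data. It does not use cuequivariance and does not model what the returned tuple of integers contains.

```python
import random


def forward(C, x, y):
    """out[k] = sum_{i,j} C[i][j][k] * x[i] * y[j]: a dense toy bilinear STP."""
    I, J, K = len(C), len(C[0]), len(C[0][0])
    return [sum(C[i][j][k] * x[i] * y[j] for i in range(I) for j in range(J)) for k in range(K)]


def backward_x(C, y, g):
    """grad_x[i] = sum_{j,k} C[i][j][k] * y[j] * g[k].

    Same coefficients as `forward`; y and the output gradient g are now the
    inputs, and x plays the role of the output.
    """
    I, J, K = len(C), len(C[0]), len(C[0][0])
    return [sum(C[i][j][k] * y[j] * g[k] for j in range(J) for k in range(K)) for i in range(I)]


def check(I=2, J=3, K=4, seed=0):
    rng = random.Random(seed)

    def rand(n):
        return [rng.randint(-3, 3) for _ in range(n)]

    C = [[rand(K) for _ in range(J)] for _ in range(I)]
    x, y, g = rand(I), rand(J), rand(K)

    def loss(xs):  # scalar loss whose gradient w.r.t. the output is g
        return sum(gk * ok for gk, ok in zip(g, forward(C, xs, y)))

    grad = backward_x(C, y, g)
    for i in range(I):
        # loss is linear in x, so a unit step gives the exact partial derivative
        x_step = [v + (1 if n == i else 0) for n, v in enumerate(x)]
        assert loss(x_step) - loss(x) == grad[i]
    return grad


print(check())
```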

classmethod stack(
es: Sequence[EquivariantTensorProduct],
stacked: list[bool],
) → EquivariantTensorProduct

Stack multiple equivariant tensor products.

symmetrize_operands() → EquivariantTensorProduct

Symmetrize the operands of the ETP.
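When the same input fills several operand slots (as in STP4 of the examples), only the symmetric part of the coefficients affects the result. A toy check with a quadratic form illustrates this: `C` is a hypothetical 2×2 coefficient matrix for an input `x` used twice, and replacing C with (C + Cᵀ)/2 leaves the value unchanged. This shows the idea only, not the library's implementation.

```python
from fractions import Fraction


def quad(C, x):
    """Value of sum_{i,j} C[i][j] * x[i] * x[j]: input x feeds both operand slots."""
    n = len(x)
    return sum(C[i][j] * x[i] * x[j] for i in range(n) for j in range(n))


def symmetrize(C):
    """Average C over the two orderings of the identical operands: (C + C^T) / 2."""
    n = len(C)
    return [[Fraction(C[i][j] + C[j][i], 2) for j in range(n)] for i in range(n)]


C = [[1, 4], [0, 3]]
x = [2, -1]
print(quad(C, x), quad(symmetrize(C), x))  # -1 -1
```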

sort_indices_for_identical_operands() → EquivariantTensorProduct

Sort the indices for identical operands.
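A complementary way to handle identical operands is to put each path's indices for those operands in sorted order and merge the paths that collide. A toy sketch on a quadratic form in which input `x` fills two operand slots shows the value is preserved; the {(i, j): coefficient} path dictionary and both helpers are assumptions of this sketch.

```python
from collections import defaultdict


def sort_identical_indices(paths: dict) -> dict:
    """Toy: `paths` maps (i, j) segment indices of a repeated input x to a coefficient.

    x[i] * x[j] == x[j] * x[i], so each path can use sorted indices; coefficients
    of paths that collide after sorting are summed.
    """
    merged = defaultdict(int)
    for (i, j), c in paths.items():
        merged[tuple(sorted((i, j)))] += c
    return dict(merged)


def evaluate(paths: dict, x: list) -> int:
    return sum(c * x[i] * x[j] for (i, j), c in paths.items())


paths = {(0, 0): 1, (0, 1): 4, (1, 0): 2, (1, 1): 3}
x = [2, -1]
print(sort_identical_indices(paths))  # {(0, 0): 1, (0, 1): 6, (1, 1): 3}
print(evaluate(paths, x), evaluate(sort_identical_indices(paths), x))  # -5 -5
```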