Operation

class cuequivariance.Operation(buffers: tuple[int, ...] | Operation)

Descriptor mapping input/output buffers to tensor product operands.

Each buffer is identified by its index (0, 1, 2, …). The i-th entry of buffers is the index of the buffer bound to the i-th operand, so the order of the buffers follows the order of the operands.

Example

The following list of operations is typical of a symmetric contraction:

>>> from cuequivariance import Operation
>>> ops = [
...     Operation((0, 1, 2)),
...     Operation((0, 1, 1, 2)),
...     Operation((0, 1, 1, 1, 2)),
... ]
>>> print(Operation.list_to_string(ops, 2, 1))
(a, b) -> (C)
  a b C
  a b b C
  a b b b C
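To make the notation above concrete, here is a minimal pure-Python sketch that reproduces this rendering from plain buffer tuples. It assumes, as the call list_to_string(ops, 2, 1) suggests, that buffers 0 to num_inputs - 1 are inputs (printed as lowercase letters) and the remaining buffers are outputs (printed as uppercase letters). The helper render_operations is hypothetical and is not part of cuequivariance.

```python
import string


def render_operations(
    ops: list[tuple[int, ...]], num_inputs: int, num_outputs: int
) -> str:
    """Hypothetical re-creation of the layout printed by Operation.list_to_string."""

    def name(buffer: int) -> str:
        # Assumption: inputs get lowercase letters, outputs uppercase letters.
        letter = string.ascii_lowercase[buffer]
        return letter if buffer < num_inputs else letter.upper()

    inputs = ", ".join(name(b) for b in range(num_inputs))
    outputs = ", ".join(name(b) for b in range(num_inputs, num_inputs + num_outputs))
    lines = [f"({inputs}) -> ({outputs})"]
    # One line per operation: the buffer bound to each operand, in operand order.
    lines += ["  " + " ".join(name(b) for b in op) for op in ops]
    return "\n".join(lines)


print(render_operations([(0, 1, 2), (0, 1, 1, 2), (0, 1, 1, 1, 2)], 2, 1))
```

Reading a line such as "a b b C" as "operand 0 reads a, operands 1 and 2 both read b, operand 3 writes C" is exactly the buffer-to-operand mapping described above.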
transpose(is_undefined_primal: list[bool], has_cotangent: list[bool]) → Operation | None
Parameters:
  • is_undefined_primal (list[bool]) – for each input buffer, whether its primal is undefined

  • has_cotangent (list[bool]) – for each output buffer, whether its cotangent is defined

Returns:

The transposed operation, or None if no transposed operation exists.
In the returned operation, the buffers are:
  • new inputs: the defined primals followed by the cotangents (those with has_cotangent=True)

  • new outputs: the undefined primals

Return type:

Operation | None
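The buffer layout of the transposed operation can be illustrated with a small sketch that only computes the renumbering from old buffer indices to new ones. It assumes that the original buffers are numbered inputs first, then outputs, and that the relative order inside each group (defined primals, cotangents, undefined primals) is preserved. transpose_buffer_map is a hypothetical helper; it does not model when the real method returns None.

```python
def transpose_buffer_map(
    is_undefined_primal: list[bool], has_cotangent: list[bool]
) -> dict[int, int]:
    """Hypothetical sketch: old buffer index -> new buffer index after transposition.

    Assumed old layout: inputs 0..num_inputs-1, then outputs.
    New layout: defined primals, then cotangents (new inputs), then undefined
    primals (new outputs). Outputs without a cotangent are left out of the map.
    """
    num_inputs = len(is_undefined_primal)
    new_index: dict[int, int] = {}
    j = 0
    # New inputs, part 1: the primals that are defined.
    for i, undefined in enumerate(is_undefined_primal):
        if not undefined:
            new_index[i] = j
            j += 1
    # New inputs, part 2: the cotangents of outputs that have one.
    for k, has_ct in enumerate(has_cotangent):
        if has_ct:
            new_index[num_inputs + k] = j
            j += 1
    # New outputs: the undefined primals.
    for i, undefined in enumerate(is_undefined_primal):
        if undefined:
            new_index[i] = j
            j += 1
    return new_index


mapping = transpose_buffer_map([False, True], [True])
print(mapping)  # {0: 0, 2: 1, 1: 2}
# Assuming operand order is kept, Operation((0, 1, 2)) would be rebound to (0, 2, 1).
print(tuple(mapping[b] for b in (0, 1, 2)))  # (0, 2, 1)
```

In this example the defined input a stays buffer 0, the cotangent of the output becomes the second input, and the undefined primal b becomes the only output.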

jvp(has_tangent: list[bool]) → list[Operation]
Parameters:

has_tangent (list[bool]) – for each input buffer, whether it has a tangent

Returns:

The JVPs of the operation.
In the returned operations, the buffers are:
  • new inputs: the original inputs followed by the tangents (those with has_tangent=True)

  • new outputs: the original outputs

Return type:

list[Operation]
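One way to see why jvp returns a list is the product rule: assuming the operation is multilinear in its operands, differentiating it yields one term per operand that reads a buffer with a tangent. The sketch below models this on plain buffer tuples. jvp_buffers is a hypothetical helper, and it assumes that buffers are numbered inputs first, then outputs, that each term rebinds the differentiated operand to its tangent buffer, and that outputs are shifted past the new tangent buffers, following the buffer layout listed above.

```python
def jvp_buffers(
    buffers: tuple[int, ...], has_tangent: list[bool]
) -> list[tuple[int, ...]]:
    """Hypothetical sketch of the product rule at the level of buffer indices.

    Assumed old layout: inputs 0..num_inputs-1, then outputs.
    New layout: original inputs, then one tangent per input with
    has_tangent=True, then the original outputs.
    """
    num_inputs = len(has_tangent)
    num_tangents = sum(has_tangent)
    # The tangent of input i is placed right after the original inputs.
    tangent_index: dict[int, int] = {}
    for i, has_t in enumerate(has_tangent):
        if has_t:
            tangent_index[i] = num_inputs + len(tangent_index)

    def shift(b: int) -> int:
        # Inputs keep their index; outputs move past the new tangent buffers.
        return b if b < num_inputs else b + num_tangents

    base = [shift(b) for b in buffers]
    result = []
    # Product rule: one term per operand that reads an input with a tangent.
    for operand, b in enumerate(buffers):
        if b < num_inputs and has_tangent[b]:
            term = list(base)
            term[operand] = tangent_index[b]
            result.append(tuple(term))
    return result


print(jvp_buffers((0, 1, 1, 2), [False, True]))  # [(0, 2, 1, 3), (0, 1, 2, 3)]
```

Because buffer b appears twice in (0, 1, 1, 2), the sketch produces two terms, one with each occurrence replaced by the tangent of b.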

operands_with_identical_buffers() → frozenset[frozenset[int]]

Return the groups of operand indices that are bound to the same buffer.
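A stand-alone sketch of this grouping on a plain buffer tuple is shown below. The function is hypothetical, not the library method, and it assumes that operands with a buffer of their own appear as singleton groups; the real method may treat singletons differently.

```python
from collections import defaultdict


def operands_with_identical_buffers(
    buffers: tuple[int, ...],
) -> frozenset[frozenset[int]]:
    """Hypothetical stand-alone version: group operand indices by their buffer."""
    operands_by_buffer: defaultdict[int, list[int]] = defaultdict(list)
    for operand, buffer in enumerate(buffers):
        operands_by_buffer[buffer].append(operand)
    # Assumption: singleton groups are kept.
    return frozenset(frozenset(group) for group in operands_by_buffer.values())


groups = operands_with_identical_buffers((0, 1, 1, 2))
print(sorted(sorted(g) for g in groups))  # [[0], [1, 2], [3]]
```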

static group_by_idential_buffers(operations: list[Operation]) → list[tuple[frozenset[frozenset[int]], list[Operation]]]
Parameters:
  • operations (list[Operation]) – the operations to group

Returns:

Each tuple contains:
  • the groups of operands bound to identical buffers, as a frozenset of frozensets of operand indices

  • the list of operations that share this grouping

Return type:

list of tuples
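The grouping can be sketched in pure Python by using each operation's operand-sharing pattern as a dictionary key. Operations are modeled as plain buffer tuples, bucket_by_identical_buffers and identical_buffer_groups are hypothetical helpers, and the output order (order of first appearance) is an assumption of this sketch rather than documented behavior.

```python
def identical_buffer_groups(buffers: tuple[int, ...]) -> frozenset[frozenset[int]]:
    """Groups of operand indices bound to the same buffer (singletons included)."""
    by_buffer: dict[int, set[int]] = {}
    for operand, buffer in enumerate(buffers):
        by_buffer.setdefault(buffer, set()).add(operand)
    return frozenset(frozenset(g) for g in by_buffer.values())


def bucket_by_identical_buffers(
    operations: list[tuple[int, ...]],
) -> list[tuple[frozenset[frozenset[int]], list[tuple[int, ...]]]]:
    """Hypothetical sketch: bucket operations by their operand-sharing pattern."""
    buckets: dict[frozenset[frozenset[int]], list[tuple[int, ...]]] = {}
    for op in operations:
        buckets.setdefault(identical_buffer_groups(op), []).append(op)
    # Buckets keep the order in which each pattern first appears.
    return list(buckets.items())


ops = [(0, 1, 2), (0, 1, 1, 2), (1, 0, 2)]
for pattern, members in bucket_by_identical_buffers(ops):
    print(sorted(sorted(g) for g in pattern), members)
# [[0], [1], [2]] [(0, 1, 2), (1, 0, 2)]
# [[0], [1, 2], [3]] [(0, 1, 1, 2)]
```

Here (0, 1, 2) and (1, 0, 2) land in the same bucket because in both every operand has a buffer of its own, even though the buffers themselves differ.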

static group_by_operational_symmetries(symmetries: list[tuple[int, ...]], operations: list[Operation]) → list[tuple[int, Operation]]
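Parameters:
  • symmetries (list[tuple[int, ...]]) – permutations of the operand indices under which operations are considered equivalent

  • operations (list[Operation]) – the operations to group
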
Returns:

Each tuple contains:
  • the multiplicity (int), i.e. the number of operations represented by the group

  • a representative operation of the group

Return type:

list of tuples
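The idea of merging operations that are equivalent under operand permutations can be sketched as follows. group_by_symmetries is a hypothetical helper working on plain buffer tuples. It assumes that permuting an operation by perm gives new_buffers[i] = buffers[perm[i]], that the symmetry list includes the identity, and it picks the lexicographically smallest permuted tuple as representative; the library may choose its representative differently.

```python
def group_by_symmetries(
    symmetries: list[tuple[int, ...]],
    operations: list[tuple[int, ...]],
) -> list[tuple[int, tuple[int, ...]]]:
    """Hypothetical sketch: merge operations equivalent under operand permutations."""
    counts: dict[tuple[int, ...], int] = {}
    for buffers in operations:
        # Canonical form: the smallest buffer tuple over all allowed permutations.
        canonical = min(tuple(buffers[p] for p in perm) for perm in symmetries)
        counts[canonical] = counts.get(canonical, 0) + 1
    # (multiplicity, representative) pairs, in order of first appearance.
    return [(multiplicity, rep) for rep, multiplicity in counts.items()]


# Operands 0 and 1 are interchangeable: the identity plus the swap.
symmetries = [(0, 1, 2), (1, 0, 2)]
print(group_by_symmetries(symmetries, [(0, 1, 2), (1, 0, 2), (0, 0, 2)]))
# [(2, (0, 1, 2)), (1, (0, 0, 2))]
```

With the swap symmetry, (0, 1, 2) and (1, 0, 2) collapse into one representative with multiplicity 2, so the computation can be done once and scaled instead of being repeated.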