Repeats

class cuequivariance_jax.Repeats

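A class to represent a sequence of repeated elements. It stores the per-element repeat counts (`repeats`) together with an optional `total_repeat_length`, both of which can be passed to `jnp.repeat` as in the example below.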

Example

>>> import jax.numpy as jnp
>>> from cuequivariance_jax import Repeats
>>> a = Repeats(jnp.array([1, 0, 2]), 3)
>>> jnp.repeat(
...     jnp.array([0.1, 0.2, 0.3], dtype=jnp.float32),
...     a.repeats,
...     total_repeat_length=a.total_repeat_length,
... )
Array([0.1, 0.3, 0.3], dtype=float32)
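
The example above passes `total_repeat_length` alongside `repeats`. A minimal sketch of why that pairing is useful, using only JAX (not `Repeats` itself): under `jax.jit` the repeat counts are traced, so `jnp.repeat` needs `total_repeat_length` to know the output shape at trace time. The names `TOTAL` and `expand` are hypothetical and chosen for illustration only.

```python
import jax
import jax.numpy as jnp

# Hypothetical constant playing the role of `total_repeat_length`.
TOTAL = 3


@jax.jit
def expand(values, repeats):
    # `repeats` is a traced array here, so its sum is unknown at trace
    # time; `total_repeat_length` fixes the output shape to (TOTAL,).
    return jnp.repeat(values, repeats, total_repeat_length=TOTAL)


values = jnp.array([0.1, 0.2, 0.3], dtype=jnp.float32)

# Same counts as the example above: [0.1, 0.3, 0.3]
out_a = expand(values, jnp.array([1, 0, 2]))

# Different counts, same total length: [0.1, 0.1, 0.2]
out_b = expand(values, jnp.array([2, 1, 0]))

print(out_a.shape, out_b.shape)
```

Both calls reuse the same compiled function because the output shape depends only on `TOTAL`, not on the values in `repeats`.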

__init__(repeats, total_repeat_length=None)
Parameters:
  • repeats (Array)

  • total_repeat_length (int, optional; defaults to None)

Return type:

None