Repeats
- class cuequivariance_jax.Repeats
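A class representing a sequence of repeated elements: repeats holds the number of times each element is repeated, and total_repeat_length holds the total length of the resulting sequence.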
Example
>>> import jax.numpy as jnp
>>> from cuequivariance_jax import Repeats
>>> a = Repeats(jnp.array([1, 0, 2]), 3)
>>> jnp.repeat(
...     jnp.array([0.1, 0.2, 0.3], dtype=jnp.float32),
...     a.repeats,
...     total_repeat_length=a.total_repeat_length,
... )
Array([0.1, 0.3, 0.3], dtype=float32)
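One plausible reason to carry the total length alongside the repeat counts is that jnp.repeat needs total_repeat_length to produce a fixed output shape under jax.jit. The source does not state this motivation, so treat the following as a minimal sketch under that assumption. It uses plain arrays rather than Repeats itself, and the helper name expand is hypothetical.

```python
import jax
import jax.numpy as jnp


def expand(values, repeats, total_repeat_length):
    # Hypothetical helper: total_repeat_length fixes the output shape,
    # so the repeat counts themselves can be traced (dynamic) values.
    return jnp.repeat(values, repeats, total_repeat_length=total_repeat_length)


# The length must be static because it determines the output shape.
expand_jit = jax.jit(expand, static_argnames="total_repeat_length")

values = jnp.array([0.1, 0.2, 0.3], dtype=jnp.float32)

# Same counts as the example above: 0.1 once, 0.2 zero times, 0.3 twice.
out_a = expand_jit(values, jnp.array([1, 0, 2]), total_repeat_length=3)

# Different counts with the same total length give the same output shape.
out_b = expand_jit(values, jnp.array([0, 3, 0]), total_repeat_length=3)
```

Both calls return arrays of shape (3,). Only the counts differ between them, while the static length stays the same.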