MLP

class cuequivariance_jax.nnx.MLP

__init__(
    layer_sizes,
    activation,
    output_activation=False,
    *,
    precision=None,
    dtype=jax.numpy.float32,
    rngs,
)
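To show roughly what these constructor arguments control, here is a plain-JAX sketch of a fully connected network. It does not call cuequivariance_jax. The helper names `init_mlp` and `apply_mlp` are hypothetical, and several behaviors are assumptions rather than documented library behavior:

- `layer_sizes` lists the input width, the hidden widths and the output width.
- The layers have no bias.
- Each layer is scaled by 1/sqrt(fan_in).
- `activation` is applied after every layer except the last, unless `output_activation=True`.
- A plain PRNG key stands in for `rngs`.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, layer_sizes, dtype=jnp.float32):
    # Hypothetical helper: one bias-free weight matrix per consecutive pair
    # of sizes, e.g. [4, 16, 2] -> shapes (4, 16) and (16, 2).
    keys = jax.random.split(key, len(layer_sizes) - 1)
    return [
        jax.random.normal(k, (n_in, n_out), dtype=dtype)
        for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:])
    ]


def apply_mlp(weights, x, activation, output_activation=False, precision=None):
    # Hypothetical forward pass mirroring the documented argument names.
    for i, w in enumerate(weights):
        # Assumed normalization: divide by sqrt(fan_in) to keep activations
        # at roughly unit scale; `precision` is forwarded to the matmul.
        x = jnp.dot(x, w, precision=precision) / (w.shape[0] ** 0.5)
        is_last = i == len(weights) - 1
        # The activation is skipped on the final layer unless requested.
        if not is_last or output_activation:
            x = activation(x)
    return x


# Usage: 4 input features -> two hidden layers of 16 -> 2 outputs.
key = jax.random.PRNGKey(0)
weights = init_mlp(key, [4, 16, 16, 2])
x = jnp.ones((8, 4))
y = apply_mlp(weights, x, jax.nn.silu)
y_act = apply_mlp(weights, x, jax.nn.relu, output_activation=True)
```

Leaving the last layer linear by default (`output_activation=False`) is a common choice. It lets the network produce unbounded outputs, such as scalar weights or energies.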
Parameters: