# MLP

*class* `cuequivariance_jax.nnx.MLP`

```
__init__(
    layer_sizes,
    activation,
    output_activation=False,
    *,
    precision=None,
    dtype=jax.numpy.float32,
    rngs,
)
```

Parameters:
- **layer_sizes** (`list[int]`)
- **activation** (`Callable`)
- **output_activation** (`bool`)
- **precision** (`Precision | None`)
- **dtype** (`Any`)
- **rngs** (`Rngs`)
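The signature above suggests a standard multilayer perceptron: one linear layer per entry in `layer_sizes`, with `activation` applied between layers and `output_activation` controlling whether it is also applied after the final layer. The sketch below illustrates that presumed forward-pass semantics in plain NumPy; it is an assumption based on the parameter names, not the library's actual implementation, and the helper `mlp_forward` is hypothetical.

```python
import numpy as np


def mlp_forward(x, weights, biases, activation, output_activation=False):
    """Sketch of an MLP forward pass (assumed semantics, not the real API).

    One (weight, bias) pair per layer. `activation` is applied after every
    hidden layer; after the final layer only when `output_activation` is
    True, mirroring the `output_activation` flag documented above.
    """
    n_layers = len(weights)
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = x @ w + b
        if i < n_layers - 1 or output_activation:
            x = activation(x)
    return x


# Example: input dim 4, hypothetical layer_sizes = [8, 3]
rng = np.random.default_rng(0)
dims = [4, 8, 3]
weights = [rng.normal(size=(dims[i], dims[i + 1])) for i in range(len(dims) - 1)]
biases = [np.zeros(d) for d in dims[1:]]

x = rng.normal(size=(2, 4))
y = mlp_forward(x, weights, biases, np.tanh)
print(y.shape)  # (2, 3)
```

With `output_activation=False` (the documented default) the last layer stays linear, which is the usual choice when the network's output feeds a loss such as mean squared error; passing `output_activation=True` would bound the output through the activation as well.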