cuda.tile.jax.cutile_call
- cuda.tile.jax.cutile_call(grid, kernel, args)
Launch a cuTile kernel from a JAX-traced graph.
- Parameters:
grid (tuple[int, ...]) – Tuple of up to 3 grid dimensions to execute the kernel over. Shorter tuples are padded with 1s on the right.
kernel – The kernel to execute.
args (tuple[Any, ...]) – Positional arguments to pass to the kernel. Each entry must match the kernel’s corresponding parameter:
  - jax.Array: read-only input buffer.
  - OutputPlaceholder: output buffer with the given shape and dtype; allocated by JAX and returned from this call.
  - InputOutput: input buffer aliased to an output slot, enabling in-place updates.
  - bool, int, or float: a scalar argument to the kernel. Because JAX treats a scalar as a 0-D array, a scalar passed to the kernel must be a static argument of the JAX function, declared with static_argnums or static_argnames.
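The static-argument requirement for scalars can be seen without any cuTile code. The following is a minimal sketch using only JAX; the function names `traced_scale` and `static_scale` and the dictionary `saw_python_scalar` are illustrative and not part of cuda.tile. It records, at trace time, whether the scalar reaches the function body as a plain Python value or as a traced 0-D array.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Filled in while JAX traces each function (illustrative bookkeeping only).
saw_python_scalar = {}


@jax.jit
def traced_scale(x, c):
    # Without static_argnums, `c` is a tracer standing in for a 0-D array.
    saw_python_scalar["traced"] = isinstance(c, (bool, int, float))
    return x * c


@partial(jax.jit, static_argnums=(1,))
def static_scale(x, c):
    # With static_argnums, `c` stays a concrete Python float during tracing.
    saw_python_scalar["static"] = isinstance(c, (bool, int, float))
    return x * c


x = jnp.arange(4, dtype=jnp.float32)
traced_result = traced_scale(x, 2.0)
static_result = static_scale(x, 2.0)
```

Only in the second form is the value available as a Python scalar while the graph is built, which is the situation the scalar rule above asks for. Note that a static argument must be hashable, and each new value triggers a recompilation.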
- Returns:
For a kernel with one output, the output array. For multiple outputs (multiple OutputPlaceholder/InputOutput args), a tuple of arrays in declaration order.
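The return convention can be pictured with a small pure-Python sketch. This is not the library's implementation: `OutSlot` and `InOutSlot` are hypothetical stand-ins for OutputPlaceholder and InputOutput, and `pack_outputs` only models which arguments produce results and how they are packed. The zero-output case is not described above, so the sketch simply rejects it.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OutSlot:
    """Hypothetical stand-in for OutputPlaceholder."""
    shape: tuple
    dtype: str


@dataclass(frozen=True)
class InOutSlot:
    """Hypothetical stand-in for InputOutput."""
    name: str


def pack_outputs(args):
    """Model the documented rule: one output is returned bare, several as a tuple."""
    # Keep only output-producing arguments, preserving declaration order.
    outputs = [a for a in args if isinstance(a, (OutSlot, InOutSlot))]
    if not outputs:
        raise ValueError("no output slots; this case is not covered by the text above")
    return outputs[0] if len(outputs) == 1 else tuple(outputs)


ph = OutSlot((10,), "float32")
single = pack_outputs(("x", ph, 3.14, 4))
several = pack_outputs((InOutSlot("y"), "x", ph, 4))
```

Here `single` is the lone slot itself, while `several` is a tuple whose first element corresponds to the InputOutput argument because it was declared first.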
Notes
1. Arrays passed to cutile_call use the default XLA row-major layout. Support for custom layouts is planned for the future.
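As a reminder of what row-major order implies for indexing, here is a minimal sketch in plain Python. The helpers `row_major_strides` and `flat_offset` are illustrative names, not cuda.tile or XLA functions, and the strides are counted in elements rather than bytes.

```python
def row_major_strides(shape):
    """Element strides for a row-major layout: the last axis is contiguous."""
    strides = []
    step = 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))


def flat_offset(index, shape):
    """Linear element offset of a multi-dimensional index in row-major order."""
    return sum(i * s for i, s in zip(index, row_major_strides(shape)))


strides = row_major_strides((2, 3, 4))
offset = flat_offset((1, 2, 3), (2, 3, 4))
```

For a shape of (2, 3, 4), moving one step along the last axis advances by one element, while moving one step along the first axis advances by 12 elements.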
Example:
```python
import cuda.tile as ct
from cuda.tile.jax import OutputPlaceholder, InputOutput, cutile_call
import jax
import jax.numpy as jnp
import numpy as np

jnp.set_printoptions(precision=1)


@ct.kernel
def scale(x, y, c: ct.Constant, TILE_SIZE: ct.Constant):
    bid = ct.bid(0)
    tx = ct.load(x, bid, TILE_SIZE) * c
    ct.store(y, bid, tx)


@ct.kernel
def inplace(x, TILE_SIZE: ct.Constant):
    bid = ct.bid(0)
    tx = ct.load(x, bid, TILE_SIZE)
    ct.store(x, bid, tx / 2)


@ct.kernel
def sincos(x, y1, y2, TILE_SIZE: ct.Constant):
    bid = ct.bid(0)
    tx = ct.load(x, bid, TILE_SIZE)
    ct.store(y1, bid, ct.sin(tx))
    ct.store(y2, bid, ct.cos(tx))


@jax.jit(static_argnums=[1, 2])
def graph(x, c, tile_size):
    grid = (ct.cdiv(x.shape[0], tile_size),)
    ph = OutputPlaceholder(x.shape, x.dtype)
    y = cutile_call(grid, scale, (x, ph, c, tile_size))
    # inplace update
    y = cutile_call(grid, inplace, (InputOutput(y), tile_size))
    # multiple outputs
    ysin, ycos = cutile_call(grid, sincos, (y, ph, ph, tile_size))
    return ysin + ycos


x = jnp.arange(10, dtype=jnp.float32)
y = graph(x, jnp.pi, 4)
print(y)
```
Output
[ 1. 1. -1. -1. 1. 1. -1. -1. 1. 1.]
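The grid size and the printed values can be cross-checked in plain Python, without a GPU. In the sketch below, `cdiv` and `pad_grid` are local stand-ins written for illustration (they are not `ct.cdiv` or any cuda.tile function), and the loop reproduces the arithmetic of the three kernels: multiply by pi, halve, then add sine and cosine.

```python
import math


def cdiv(a, b):
    """Ceiling division for positive integers (local stand-in)."""
    return -(-a // b)


def pad_grid(grid, rank=3):
    """Pad a grid tuple with 1s on the right up to `rank` dimensions."""
    return tuple(grid) + (1,) * (rank - len(grid))


n, tile_size, c = 10, 4, math.pi

# Ten elements in tiles of four need three blocks; the last tile is partial.
grid = pad_grid((cdiv(n, tile_size),))

# scale: x * c, inplace: / 2, sincos: sin + cos, rounded like precision=1.
expected = []
for k in range(n):
    v = k * c / 2
    expected.append(round(math.sin(v) + math.cos(v), 1))

print(grid)
print(expected)
```

Each input k maps to sin(k·π/2) + cos(k·π/2), which cycles through 1, 1, -1, -1 and agrees with the output shown above.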