cuda.tile.jax.cutile_call

cuda.tile.jax.cutile_call(grid, kernel, args)

Launch a cuTile kernel from a JAX-traced graph.

Parameters:
  • grid (tuple[int, ...]) – Grid dimensions to launch the kernel over (at most 3). A shorter tuple is padded with 1s on the right, so (n,) is equivalent to (n, 1, 1).

  • kernel – The kernel to execute, i.e. a function decorated with @ct.kernel.

  • args (tuple[Any, ...]) –

    Positional arguments to pass to the kernel. Each entry must match the kernel’s corresponding parameter:

    • jax.Array: read-only input buffer.

    • OutputPlaceholder: output buffer with the given shape and dtype; allocated by JAX and returned from this call.

    • InputOutput: input buffer aliased to an output slot, enabling in-place updates.

    • bool, int, or float: a scalar argument to the kernel. Because JAX traces scalars as 0-D arrays, a scalar must be a static argument of the enclosing jitted function (declared with static_argnums or static_argnames).
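The static-argument rule follows from how jax.jit treats its inputs. The sketch below uses only JAX (no cuTile, no GPU) to show that a Python float passed as an ordinary argument reaches the function as a tracer, while the same value listed in static_argnums arrives as a plain Python float that could be forwarded as a kernel constant. The function name `inspect_scalar` is hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp

seen = []  # records the Python type of `c` observed while jit traces the function


def inspect_scalar(x, c):
    seen.append(type(c))
    return x * c


x = jnp.ones(3, dtype=jnp.float32)

# Without static_argnums, jit traces `c` as a 0-D array (a tracer object).
jax.jit(inspect_scalar)(x, 2.0)

# With static_argnums=1, `c` keeps its Python type and becomes part of the compile cache key.
jax.jit(inspect_scalar, static_argnums=1)(x, 2.0)

print(seen[0].__name__, seen[1].__name__)
```

Only the second form gives a value that can be used where a compile-time constant is expected, which is why scalars reach cutile_call through static arguments.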

Returns:

For a kernel with one output, the output array. For multiple outputs (more than one OutputPlaceholder or InputOutput argument), a tuple of arrays in the order those arguments appear in args.

Notes

1. Arrays passed to cutile_call use XLA's default row-major layout. Custom layouts are not yet supported; support is planned for a future release.
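As a reminder of what the default layout means for kernel indexing: in row-major (C) order the last axis varies fastest, so element (i, j) of an array with shape (R, C) sits at flat offset i * C + j. The sketch below checks this on the host with jax.numpy; the helper `row_major_offset` is hypothetical and only illustrates the arithmetic.

```python
import jax.numpy as jnp


def row_major_offset(index, shape):
    """Flat offset of `index` in a row-major array of the given shape."""
    offset = 0
    for i, n in zip(index, shape):
        offset = offset * n + i  # Horner-style accumulation: the last axis varies fastest
    return offset


a = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
flat = a.ravel()  # ravel flattens in row-major (C) order by default

# Element (2, 1) lives at offset 2 * 4 + 1 = 9 in the flat buffer.
print(row_major_offset((2, 1), a.shape), flat[9], a[2, 1])
```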

Example:

import cuda.tile as ct
from cuda.tile.jax import OutputPlaceholder, InputOutput, cutile_call

import jax
import jax.numpy as jnp

jnp.set_printoptions(precision=1)

@ct.kernel
def scale(x, y, c: ct.Constant, TILE_SIZE: ct.Constant):
    # Each block handles one tile of TILE_SIZE elements: y = x * c.
    bid = ct.bid(0)
    tx = ct.load(x, bid, TILE_SIZE) * c
    ct.store(y, bid, tx)


@ct.kernel
def inplace(x, TILE_SIZE: ct.Constant):
    # Halve x in place; x is passed as InputOutput, so the store writes to the aliased buffer.
    bid = ct.bid(0)
    tx = ct.load(x, bid, TILE_SIZE)
    ct.store(x, bid, tx / 2)


@ct.kernel
def sincos(x, y1, y2, TILE_SIZE: ct.Constant):
    # Two outputs from one launch: y1 = sin(x), y2 = cos(x).
    bid = ct.bid(0)
    tx = ct.load(x, bid, TILE_SIZE)
    ct.store(y1, bid, ct.sin(tx))
    ct.store(y2, bid, ct.cos(tx))



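# c and tile_size are Python scalars, so they are declared static.
@jax.jit(static_argnums=[1, 2])
def graph(x, c, tile_size):
    # One block per tile, rounding up so the tail elements are covered.
    grid = (ct.cdiv(x.shape[0], tile_size),)
    # Output spec with the same shape and dtype as x; reused for every output below.
    ph = OutputPlaceholder(x.shape, x.dtype)

    # single output: y = x * c
    y = cutile_call(grid, scale, (x, ph, c, tile_size))

    # in-place update: y is aliased to the output and halved
    y = cutile_call(grid, inplace, (InputOutput(y), tile_size))

    # multiple outputs, returned as a tuple in argument order
    ysin, ycos = cutile_call(grid, sincos, (y, ph, ph, tile_size))

    return ysin + ycos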

x = jnp.arange(10, dtype=jnp.float32)
y = graph(x, jnp.pi, 4)
print(y)

Output

[ 1.  1. -1. -1.  1.  1. -1. -1.  1.  1.]
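To sanity-check these results without a GPU, the example's pipeline can be mirrored in plain jax.numpy. The sketch below is a hypothetical reference (the name `reference` is not part of cuda.tile.jax): it sizes the grid with ceiling division, pads it to three dimensions as described for the grid parameter, and applies the same scale, halve, and sin + cos steps. Because x * pi / 2 lands on multiples of pi/2, the sum cycles through 1, 1, -1, -1 up to float32 rounding.

```python
import math

import jax.numpy as jnp


def reference(x, c, tile_size):
    # Ceiling division, playing the role of ct.cdiv in the example.
    num_blocks = -(-x.shape[0] // tile_size)
    grid = (num_blocks,)
    grid = grid + (1,) * (3 - len(grid))  # pad with 1s on the right, as documented for `grid`
    y = x * c                             # scale kernel
    y = y / 2                             # inplace kernel
    return grid, jnp.sin(y) + jnp.cos(y)  # sincos kernel, then the sum


x = jnp.arange(10, dtype=jnp.float32)
grid, out = reference(x, math.pi, 4)
print(grid)  # (3, 1, 1)
print(out)
```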