12. Appendix#
12.1. Data Type Support#
12.2. Programming Model Example Programs#
12.2.1. Hello Tile Block#
cuda_tile.module @hello_world_module {
entry @hello_world_kernel() {
print "Hello World!\n"
}
}
12.2.2. Vector Addition Block 128x1#
// A basic implementation of 128 sized vector addition using unstructured load/stores.
//
// This implements addition over a 1-d tensor (vector) with size 128.
//
// 128x1 + 128x1 => 128x1
cuda_tile.module @vector_block_add_128x1 {
entry @vector_block_add_128x1_kernel(
%a_ptr_base_scalar : !cuda_tile.tile<ptr<f32>>,
%b_ptr_base_scalar : !cuda_tile.tile<ptr<f32>>,
%c_ptr_base_scalar : !cuda_tile.tile<ptr<f32>>)
{
// Create an offset on the inclusive (0, 127) interval.
%offset = iota : tile<128xi32>
// We need a tile<ptr<T>> in order to perform a load or store.
//
// We will now convert each raw base pointer into such a pointer.
//
// First reshape the scalar pointer ptr<f32> to tile<1xptr<f32>> so it has the correct rank.
%a_ptr_base_tensor = reshape %a_ptr_base_scalar :
tile<ptr<f32>> -> tile<1xptr<f32>>
// Next broadcast the pointer so we have a tensor of (base, ..., base) containing 128 elements.
%a_ptr = broadcast %a_ptr_base_tensor : tile<1xptr<f32>> -> tile<128xptr<f32>>
// Finally add the offset tensor to the tensor of pointers to obtain a tile<128xptr<f32>> that contains
// pointers of (base + 0, ..., base + 127) as its values.
%a_tensor = offset %a_ptr, %offset :
tile<128xptr<f32>>, tile<128xi32> -> tile<128xptr<f32>>
// Now we do the same for B.
%b_ptr_base_tensor =reshape %b_ptr_base_scalar :
tile<ptr<f32>> -> tile<1xptr<f32>>
%b_ptr = broadcast %b_ptr_base_tensor : tile<1xptr<f32>> -> tile<128xptr<f32>>
%b_tensor = offset %b_ptr, %offset :
tile<128xptr<f32>>, tile<128xi32> -> tile<128xptr<f32>>
// And the same for C.
%c_ptr_base_tensor = reshape %c_ptr_base_scalar :
tile<ptr<f32>> -> tile<1xptr<f32>>
%c_ptr = broadcast %c_ptr_base_tensor : tile<1xptr<f32>> -> tile<128xptr<f32>>
%c_tensor = offset %c_ptr, %offset :
tile<128xptr<f32>>, tile<128xi32> -> tile<128xptr<f32>>
// Now that we have prepared all the pointers we can do the real work.
//
// First we load A, and B into %a_val and %b_val.
%a_val, %token_a = load_ptr_tko weak %a_tensor : tile<128xptr<f32>> -> tile<128xf32>, token
%b_val, %token_b = load_ptr_tko weak %b_tensor : tile<128xptr<f32>> -> tile<128xf32>, token
// We then compute floating-point vector addition using addf
%c_val = addf %a_val, %b_val rounding<nearest_even> : tile<128xf32>
// Finally we store the result to C.
store_ptr_tko weak %c_tensor, %c_val : tile<128xptr<f32>>, tile<128xf32> -> token
}
}
12.2.3. Hello Tile Grid#
cuda_tile.module @hello_world_module {
// TileIR kernel function
entry @hello_world_kernel() {
// Step 1. Get the tile block ID
%block_x_index, %block_y_index, %block_z_index = cuda_tile.get_tile_block_id : tile<i32>
// Step 2. Get the tile block dimensions
%block_dim_x, %block_dim_y, %block_dim_z = cuda_tile.get_num_tile_blocks : tile<i32>
// Step 3. Print the tile block ID and dimensions. Each tile executes the
// following print statement and prints a single line.
cuda_tile.print "Hello, I am tile <%, %, %> in a kernel with <%, %, %> tiles.\n",
%block_x_index, %block_y_index, %block_z_index, %block_dim_x, %block_dim_y, %block_dim_z
: tile<i32>, tile<i32>, tile<i32>,
tile<i32>, tile<i32>, tile<i32>
}
}
12.2.4. GEMM Single 64x64 Block#
// An implementation of GEMM for a single statically shaped square 64x64 block.
cuda_tile.module @gemm_block_64x64_module {
entry @gemm_block_64x64_kernel(
%a_ptr_base_scalar: !cuda_tile.tile<!cuda_tile.ptr<f32>>,
%b_ptr_base_scalar: !cuda_tile.tile<!cuda_tile.ptr<f32>>,
%c_ptr_base_scalar: !cuda_tile.tile<!cuda_tile.ptr<f32>>
) {
%offset_flat = iota : tile<4096xi32>
%offset = reshape %offset_flat :
tile<4096xi32> -> tile<64x64xi32>
// Can we have iota support producing tensors directly?
// %offset = iota : tile<64x64xi32>
%a_ptr_base_tensor = reshape %a_ptr_base_scalar :
tile<ptr<f32>> -> tile<1x1xptr<f32>>
%a_ptr = broadcast %a_ptr_base_tensor : tile<1x1xptr<f32>> -> tile<64x64xptr<f32>>
%a_tensor = offset %a_ptr, %offset :
tile<64x64xptr<f32>>, tile<64x64xi32> -> tile<64x64xptr<f32>>
// Now we do the same for B.
%b_ptr_base_tensor = reshape %b_ptr_base_scalar :
tile<ptr<f32>> -> tile<1x1xptr<f32>>
%b_ptr = broadcast %b_ptr_base_tensor : tile<1x1xptr<f32>> -> tile<64x64xptr<f32>>
%b_tensor = offset %b_ptr, %offset :
tile<64x64xptr<f32>>, tile<64x64xi32> -> tile<64x64xptr<f32>>
// And the same for C.
%c_ptr_base_tensor = reshape %c_ptr_base_scalar :
tile<ptr<f32>> -> tile<1x1xptr<f32>>
%c_ptr = broadcast %c_ptr_base_tensor : tile<1x1xptr<f32>> -> tile<64x64xptr<f32>>
%c_tensor = offset %c_ptr, %offset :
tile<64x64xptr<f32>>, tile<64x64xi32> -> tile<64x64xptr<f32>>
// Load a single 64x64 matrix from the tile.
%A_block, %token_a = load_ptr_tko weak %a_tensor :
tile<64x64xptr<f32>> -> tile<64x64xf32>, token
// Load a single 64x64 matrix from the tile.
%B_block, %token_b = load_ptr_tko weak %b_tensor :
tile<64x64xptr<f32>> -> tile<64x64xf32>, token
%init_accum = cuda_tile.constant <f32: 0.000000e+00> : !cuda_tile.tile<64x64xf32>
// WHy did this type check? %C_block = cuda_tile.dot %A_frag, %B_frag, %init_accum: tile<64x64xf16>, tile<64x64xf16>, tile<64x64xf32>
//
// I feel like we should seriously reconsider naming this `dot` it is super confusing because it doesn't actually implement true dot-product.
%C_block = mmaf %A_block, %B_block, %init_accum: tile<64x64xf32>, tile<64x64xf32>, tile<64x64xf32>
store_ptr_tko weak %c_tensor, %C_block :
tile<64x64xptr<f32>>, tile<64x64xf32> -> token
}
}
12.2.5. GEMM 4096x4096 Block#
// An implementation of GEMM for a square 4096x4096 @ 4096x4096 multiplication.
//
// This would be launched on a 64x64 grid where each tile block computes 64x64 output
// tiles of C.
cuda_tile.module @gemm_square_4096_tile_64x64_module {
entry @gemm_square_4096_tile_64x64_kernel(
%a_ptr_base_scalar: tile<ptr<f32>>,
%b_ptr_base_scalar: tile<ptr<f32>>,
%c_ptr_base_scalar: tile<ptr<f32>>
) {
// We first setup up some state for the kernel.
//
// Keep this line of space for the spec right now.
// Read Tile block id's.
%block_x_index, %block_y_index, %block_z_index = get_tile_block_id : tile<i32>
// We assume we have tiled a 4096x4096 @ 4096x4096 matrix mupltiplication split into
// // 64x64 tiles so the tile m, n, k are all 64.
// %stride_A_m = constant <i32: 64> : tile<i32>
// %stride_ = constant <i32: 64> : tile<i32>
// %k_tile_size = constant <i32: 64> : tile<i32>
// We assume we have tiled a 4096x4096 @ 4096x4096 matrix mupltiplication split into
// 64x64 tiles so the tile m, n, k are all 64.
%m_tile_size = constant <i32: 64> : tile<i32>
%m_stride_factor = cuda_tile.constant <i32: 64> : tile<64x64xi32> // todo fix the line # and restore this %n_tile_size = cuda_tile.constant <i32: 64> : tile<i32>
%k_tile_size = cuda_tile.constant <i32: 64> : tile<i32>
%range_start = cuda_tile.constant <i32: 0> : tile<i32>
%range_step = cuda_tile.constant <i32: 1> : tile<i32>
%init_accum = cuda_tile.constant <f32: 0.000000e+00> : tile<64x64xf32>
// The shared range from (0, 63).
%tile_size_range = cuda_tile.iota : tile<64xi32>
// We must first compute the tensors of initial offsets for A and B so that we can obtain a tensor of
// pointers for each to load them.
//
// First we compute the starting indices of A's tile in this case this will be the "top-corner"
// of a row-major tile specified by block_x_index.
//
// The only way to contruct the offset matrix for a tile is by building up a tensor step by step.
//
// Conceptually we start by computing the offsets of the M dimension:
//
// m_offsets = block_x_index * k_tile_size + arange(0, k_tile_size)
//
// This produces a vector starting from the "top-corner" of the tile at (block_x_index * tile_size, block_x_index * tile_size + tile_size).
%a_tile_base = cuda_tile.muli %block_x_index, %k_tile_size : tile<i32>
%a_tile_base_reshape = cuda_tile.reshape %a_tile_base : tile<i32> -> tile<1xi32>
%a_tile_base_tensor = cuda_tile.broadcast %a_tile_base_reshape :
tile<1xi32> -> tile<64xi32>
%m_offsets_vec = cuda_tile.addi %a_tile_base_tensor, %tile_size_range : tile<64xi32>
// The striding of the A matrix is (64, 1) meaning that we don't need to do anything special for the K
// dimension we just want each row of the offset matrix to be sequential.
//
// We can reuse %tile_size_range as k_offs = torch.arange(0, k_tile_size).
//
// a_tile = reshape(m_offs, (64, 1)) * m_stride + reshape(k_offs, (1, 64)) * k_stride
//
// We first broadcast the m_offsets into a matrix where each column is identical and scaled by stride.
%m_offsets_matrix = cuda_tile.reshape %m_offsets_vec :
tile<64xi32> -> tile<64x1xi32>
%m_offsets_broadcast = cuda_tile.broadcast %m_offsets_matrix :
tile<64x1xi32> -> tile<64x64xi32>
%m_offsets = cuda_tile.muli %m_offsets_broadcast, %m_stride_factor : tile<64x64xi32>
// We then broadcast the k_offsets into a matrix where row is identical and scaled by stride.
%ak_offsets_matrix = cuda_tile.reshape %tile_size_range :
tile<64xi32> -> tile<1x64xi32>
%ak_offsets_broadcast = cuda_tile.broadcast %ak_offsets_matrix :
tile<1x64xi32> -> tile<64x64xi32>
%ak_offsets = cuda_tile.muli %ak_offsets_broadcast, %m_stride_factor : tile<64x64xi32>
// Finally we add them together resulting in the final offset matrix for A.
%a_tile_offsets = cuda_tile.addi %m_offsets, %ak_offsets : tile<64x64xi32>
// Now we do the same for B, first prepare the set of n_offsets.
// n_offs = j * n_tile_size + torch.arange(0, n_tile_size)
%b_tile_base = cuda_tile.muli %block_y_index, %k_tile_size : tile<i32>
%b_tile_base_reshape = cuda_tile.reshape %b_tile_base :
tile<i32> -> tile<1xi32>
%b_tile_base_tensor = cuda_tile.broadcast %b_tile_base_reshape :
tile<1xi32> -> tile<64xi32>
%n_offsets_vec = cuda_tile.addi %b_tile_base_tensor, %tile_size_range : tile<64xi32>
// b_tile = k_offs[:, None] * k_stride + n_offs[None, :] * n_stride
%bk_offsets_matrix = cuda_tile.reshape %tile_size_range : tile<64xi32> -> tile<64x1xi32>
// Stride is one.
%bk_offsets = cuda_tile.broadcast %bk_offsets_matrix : tile<64x1xi32> -> tile<64x64xi32>
// %bk_offsets = cuda_tile.muli %bk_offsets_broadcast, %k_tile_size : tile<i32>
%n_offsets_matrix = cuda_tile.reshape %n_offsets_vec : tile<64xi32> -> tile<1x64xi32>
%n_offsets_broadcast = cuda_tile.broadcast %n_offsets_matrix : tile<1x64xi32> -> tile<64x64xi32>
%n_offsets = cuda_tile.muli %n_offsets_broadcast, %m_stride_factor : tile<64x64xi32>
%b_tile_offsets = cuda_tile.muli %bk_offsets, %n_offsets : tile<64x64xi32>
// Now the rest of the kernel looks like what we did before we simply convert the base pointer
// to a tensor, add the offset matrix, and continue.
%a_ptr_base_tensor = cuda_tile.reshape %a_ptr_base_scalar :
tile<ptr<f32>> -> tile<1x1xptr<f32>>
%a_ptr = cuda_tile.broadcast %a_ptr_base_tensor : tile<1x1xptr<f32>> -> tile<64x64xptr<f32>>
%a_tile_ptr = offset %a_ptr, %a_tile_offsets :
tile<64x64xptr<f32>>, tile<64x64xi32> -> tile<64x64xptr<f32>>
// And the same for B.
%b_ptr_tile_tensor = reshape %b_ptr_base_scalar :
tile<ptr<f32>> -> tile<1x1xptr<f32>>
%b_ptr = broadcast %b_ptr_tile_tensor : tile<1x1xptr<f32>> -> tile<64x64xptr<f32>>
%b_tile_ptr = offset %b_ptr, %b_tile_offsets :
tile<64x64xptr<f32>>, tile<64x64xi32> -> tile<64x64xptr<f32>>
%C_tile, %a_ptr_final, %b_ptr_final = for %k in (%range_start to %k_tile_size, step %range_start) : tile<i32>
iter_values(
%acc_prev = %init_accum,
%a_tile_ptr_prev = %a_tile_ptr,
%b_tile_ptr_prev = %b_tile_ptr
) -> (tile<64x64xf32>, tile<64x64xptr<f32>>, tile<64x64xptr<f32>>)
{
// Load a single 64x64 matrix from the tile.
%A_tile, %token_a = load_ptr_tko weak %a_tile_ptr :
tile<64x64xptr<f32>> -> tile<64x64xf32>, token
// Load a single 64x64 matrix from the tile.
%B_tile, %token_b = load_ptr_tko weak %b_tile_ptr :
tile<64x64xptr<f32>> -> tile<64x64xf32>, token
%C_tile_acc = mmaf %A_tile, %B_tile, %acc_prev: tile<64x64xf32>, tile<64x64xf32>, tile<64x64xf32>
// Advance by K block size.
%block_size = constant <i32: 64> : tile<64x64xi32>
%a_tile_ptr_next = offset %a_tile_ptr_prev, %block_size
: tile<64x64xptr<f32>>, tile<64x64xi32>
-> tile<64x64xptr<f32>>
%b_tile_ptr_next = offset %b_tile_ptr_prev, %block_size
: tile<64x64xptr<f32>>, tile<64x64xi32>
-> tile<64x64xptr<f32>>
// Store the partial sum to the 64x64 accumulator.
continue %C_tile_acc, %a_tile_ptr_next, %b_tile_ptr_next : tile<64x64xf32>, tile<64x64xptr<f32>>, tile<64x64xptr<f32>>
}
// We now need to do the thing for the offsets in C, but not inside the loop.
//
// The equivalent Triton code for this computation is:
//
// offs_cm = block_x_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
//
// We first start by computing the offset at which the tile starts on the x coordinate of the matrix.
%c_tile_x_start = muli %block_x_index, %k_tile_size :
tile<i32>
%c_tile_x_start_reshape = reshape %c_tile_x_start :
tile<i32> -> tile<1xi32>
%c_tile_x_start_tensor = broadcast %c_tile_x_start_reshape :
tile<1xi32> -> tile<64xi32>
// We now have a vector which goes from (block_x_index * tile_size_m, block_x_index * tile_size_m + 63)
%c_tile_x_offsets_vec = addi %c_tile_x_start_tensor, %tile_size_range : tile<64xi32>
// We do the same for this computation:
//
// offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
%c_tile_y_start = muli %block_x_index, %k_tile_size : tile<i32>
%c_tile_y_start_reshape = reshape %c_tile_y_start : tile<i32> -> tile<1xi32>
%c_tile_y_start_tensor = broadcast %c_tile_y_start_reshape :
tile<1xi32> -> tile<64xi32>
%c_tile_y_offsets_vec = addi %c_tile_y_start_tensor, %tile_size_range : tile<64xi32>
// We now want to do broadcating addition to get the file tensor of offsets which represent the tile.
//
// c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
//
// We first prepare: stride_cm * offs_cm[:, None]
%c_tile_x_offsets_matrix = reshape %c_tile_x_offsets_vec : tile<64xi32> -> tile<64x1xi32>
%c_tile_x_offsets_broadcast = broadcast %c_tile_x_offsets_matrix : tile<64x1xi32> -> tile<64x64xi32>
%c_tile_x_offsets = muli %c_tile_x_offsets_broadcast, %m_stride_factor : tile<64x64xi32>
// We then prepare: stride_cn * offs_cn[None, :]
%c_tile_y_offsets_matrix = reshape %c_tile_y_offsets_vec : tile<64xi32> -> tile<1x64xi32>
%c_tile_y_offsets_broadcast = broadcast %c_tile_y_offsets_matrix : tile<1x64xi32> -> tile<64x64xi32>
// Finally we add the X and the Y coordinates together to get a complete matrix.
%c_tile_y_offsets = muli %c_tile_y_offsets_broadcast, %m_stride_factor : tile<64x64xi32>
%c_tile_offsets = muli %c_tile_x_offsets, %c_tile_y_offsets : tile<64x64xi32>
// And the same for C.
%c_ptr_base_tensor = reshape %c_ptr_base_scalar :
tile<ptr<f32>> -> tile<1x1xptr<f32>>
%c_ptr = broadcast %c_ptr_base_tensor :
tile<1x1xptr<f32>> -> tile<64x64xptr<f32>>
%c_tile_ptr = offset %c_ptr, %c_tile_offsets :
tile<64x64xptr<f32>>, tile<64x64xi32> -> tile<64x64xptr<f32>>
store_ptr_tko weak %c_tile_ptr, %C_tile :
tile<64x64xptr<f32>>, tile<64x64xf32> -> token
}
}
12.2.6. Vector Addition with tensor_view#
// Tiled SAXPY is an optimized implementation of the SAXPY operation.
// This kernel uses memref abstractions for data load and store operations that allow structured load and store and can map accelerator memory engines in our hardware.
// The program divides X and Y into smaller tiles to enable parallelism on multiple tiles.
// Each Tile Block computes a tile of X and Y and stores the result back.
// This example can also be added in the blog post
// "Six Ways to SAXPY": https://developer.nvidia.com/blog/six-ways-saxpy/
cuda_tile.module @saxpy {
// TileIR kernel function
entry @saxpy_memref(%X: tile<ptr<f32>>,
%Y: tile<ptr<f32>>,
%alpha: tile<f32>,
%M : tile<i32>,
%N : tile<i32>) {
// Step 1. Get the tile block ID
%tileIdX, %tileIdY, %tileIdZ = get_tile_block_id : tile<i32>
// Step 2. Reshape and broadcast the alpha scalar
%alpha_reshaped = reshape %alpha : tile<f32> -> tile<1x1xf32>
%alpha_tensor = broadcast %alpha_reshaped : tile<1x1xf32> -> tile<128x256xf32>
// Step 3. Create tensor_view for X and Y
%x_memref = make_tensor_view %X, shape = [%M, %N], strides = [%M, 1] : tile<i32> -> tensor_view<?x?xf32, strides=[?,1]>
%y_memref = make_tensor_view %Y, shape = [%M, %N], strides = [%M, 1] : tile<i32> -> tensor_view<?x?xf32, strides=[?,1]>
// Step 4. Create partition view for X and Y
%x_view = make_partition_view %x_memref : partition_view<tile=(128x256), tensor_view<?x?xf32, strides=[?,1]>>
%y_view = make_partition_view %y_memref : partition_view<tile=(128x256), tensor_view<?x?xf32, strides=[?,1]>>
// Step 5. Load tile from X and Y
%x_tile, %token_x = load_view_tko weak %x_view[%tileIdX, %tileIdY] :
partition_view<tile=(128x256), tensor_view<?x?xf32, strides=[?,1]>>, tile<i32> -> tile<128x256xf32>, token
%y_tile, %token_y = load_view_tko weak %y_view[%tileIdX, %tileIdY] :
partition_view<tile=(128x256), tensor_view<?x?xf32, strides=[?,1]>>, tile<i32> -> tile<128x256xf32>, token
// Step 6. Compute sAXPY: y = alpha * A + y
%9 = mulf %alpha_tensor, %x_tile rounding<nearest_even> : tile<128x256xf32>
%result_tile = addf %9, %y_tile rounding<nearest_even> : tile<128x256xf32>
// Step 7. Store the result tile to Y
store_view_tko weak %result_tile, %y_view[%tileIdX, %tileIdY] :
tile<128x256xf32>, partition_view<tile=(128x256), tensor_view<?x?xf32, strides=[?,1]>>, tile<i32> -> token
}
}
12.2.7. GEMM Tiled with tensor_view#
// An implementation of GEMM in cuda_tile.
//
// Kernel computes MxNxK with 128x128x64 Tile Size.
// Computes F32 += f16 * f16 + 0.0
//
// This implementation does tiling, and reduction over
// K for dynamic sizes.
cuda_tile.module @gemm_kloop_module {
entry @gemm_kloop_kernel(
%A_ptr: !cuda_tile.tile<!cuda_tile.ptr<f16>>,
%B_ptr: !cuda_tile.tile<!cuda_tile.ptr<f16>>,
%C_ptr: !cuda_tile.tile<!cuda_tile.ptr<f32>>,
%M: !cuda_tile.tile<i32>, %N: !cuda_tile.tile<i32>, %K: !cuda_tile.tile<i32>,
%stride_ak: !cuda_tile.tile<i32>, %stride_bn: !cuda_tile.tile<i32>, %stride_cm: !cuda_tile.tile<i32>
) {
// First we need to prepare the inputs for the actual computation.
//
// Assume the preconditions of this kernel (i.e., the stride are all divisible by 8)
%A_ptr_assume = assume #cuda_tile.div_by<16>, %A_ptr : tile<ptr<f16>>
%B_ptr_assume = assume #cuda_tile.div_by<16>, %B_ptr : tile<ptr<f16>>
%C_ptr_assume = assume #cuda_tile.div_by<16>, %C_ptr : tile<ptr<f32>>
%stride_ak_assume = assume #cuda_tile.div_by<8>, %stride_ak : tile<i32>
%stride_bn_assume = assume #cuda_tile.div_by<8>, %stride_bn : tile<i32>
%stride_cm_assume = assume #cuda_tile.div_by<8>, %stride_cm : tile<i32>
// Constants must be allocated explicitly in the program, below we allocate scalar `0`, `1`,
// and the zero'd tensor used for accumulation.
%i0 = constant <i32: 0> : !cuda_tile.tile<i32>
%i1 = constant <i32: 1> : !cuda_tile.tile<i32>
%cst = constant <f32: 0.000000e+00> : !cuda_tile.tile<128x128xf32>
// Convert the unstructured pointers `ptr` to `tensor_view`.
//
// A reference to the A tensor pointed to by A_ptr, (K x M)
%A = make_tensor_view %A_ptr_assume, shape = [%K, %M], strides = [%stride_ak, 1] : tile<i32> -> tensor_view<?x?xf16, strides=[?,1]>
// A reference to the B tensor pointed to by B_ptr, (N x K)
%B = make_tensor_view %B_ptr_assume, shape = [%N, %K], strides = [%stride_bn, 1] : tile<i32> -> tensor_view<?x?xf16, strides=[?,1]>
// A reference to the C tensor pointed to by C_ptr, (M x N)
%C = make_tensor_view %C_ptr_assume, shape = [%M, %N], strides = [%stride_cm, 1] : tile<i32> -> tensor_view<?x?xf32, strides=[?,1]>
// Now we have all the inputs as structured pointers each associated with layouts.
//
// Next we will tile the problem.
//
// Our matrix multiplication is (M*K) @ (K*N) = M*N but our input tensors are transposed.
//
// In order to handle this we create partition view where we flip the 0th and 1st dims.
// We are blocking A (K x M) -> block_m x block_k.
%A_block = make_partition_view %A : partition_view<tile=(128x64), tensor_view<?x?xf16, strides=[?,1]>, dim_map=[1, 0]>
// We are blocking B (N x K) -> block_k x block_n.
%B_block = make_partition_view %B : partition_view<tile=(64x128), tensor_view<?x?xf16, strides=[?,1]>, dim_map=[1, 0]>
// We are blocking C (M xN) -> block_m x block_n.
%C_block = make_partition_view %C : partition_view<tile=(128x128), tensor_view<?x?xf32, strides=[?,1]>, dim_map=[0, 1]>
// Read Tile block id's.
%bidx, %bidy, %bidz = get_tile_block_id : tile<i32>
// Because we allow for dynamic dimensions we must get the reduction dimension `K` dynamically.
%mk_len_i32:2 = get_index_space_shape %A_block : partition_view<tile=(128x64), tensor_view<?x?xf16, strides=[?,1]>, dim_map=[1, 0]> -> tile<i32>
// Now that we have done all the setup, we can finally perform the computation itself.
//
// We simply loop over the K dimension computing: dot(A_block[0, k], B_block[k, 0]).
%result = for %k in (%i0 to %mk_len_i32#1, step %i1) : tile<i32>
iter_values(%acc_prev = %cst) -> (tile<128x128xf32>)
{
// Load a single 128x64 matrix from the tile.
%A_frag, %t1 = load_view_tko weak %A_block[%bidx, %k] : partition_view<tile=(128x64), tensor_view<?x?xf16, strides=[?,1]>, dim_map=[1, 0]>, tile<i32> -> tile<128x64xf16>, token
// Load a single 64x128 matrix from the tile.
%B_frag, %t2 = load_view_tko weak %B_block [%k, %bidy] : partition_view<tile=(64x128), tensor_view<?x?xf16, strides=[?,1]>, dim_map=[1, 0]>, tile<i32> -> tile<64x128xf16>, token
// Compute the mma(A_frag, B_frag) + acc_prev.
%acc = mmaf %A_frag, %B_frag, %acc_prev: tile<128x64xf16>, tile<64x128xf16>, tile<128x128xf32>
// Store the partial sum to the 128x128 accumulator.
continue %acc : tile<128x128xf32>
}
// Finally store the complete 128x128 tile to the view of C.
%t3 = store_view_tko weak %result, %C_block[%bidx, %bidy] : tile<128x128xf32>, partition_view<tile=(128x128), tensor_view<?x?xf32, strides=[?,1]>, dim_map=[0, 1]>, tile<i32> -> token
}
}
12.3. Operation Examples#
12.3.1. cuda_tile.cat_0#
cuda_tile.module @module {
entry @example() {
%arg0 = constant <f32: 0.0> : tile<2x4xf32>
%arg1 = constant <f32: 1.0> : tile<2x4xf32>
// A valid invocation of cat.
%0 = cat %arg0, %arg1 dim = 1
: tile<2x4xf32>, tile<2x4xf32> -> tile<2x8xf32>
// >>> %arg0 = tile([[ A, B, C ],
// [ D, E, F ]])
// >>> %arg1 = tile([[ 1, 2, 3 ],
// [ 4, 5, 6 ]])
// >>> %0 = tile([[ A, B, C, 1, 2, 3 ],
// [ D, E, F, 4, 5, 6 ]])
// A valid invocation of cat.
%1 = cat %arg0, %arg1 dim = 0
: tile<2x4xf32>, tile<2x4xf32> -> tile<4x4xf32>
// >>> %arg0 = tile([[ A, B, C ],
// [ D, E, F ]])
//
// >>> %arg1 = tile([[ 1, 2, 3 ],
// [ 4, 5, 6 ]])
//
// >>> %1 = tile([[ A, B, C ],
// [ D, E, F ],
// [ 1, 2, 3 ],
// [ 4, 5, 6 ]])
}
}
12.3.2. cuda_tile.cmpf_0#
cuda_tile.module @ex_module {
entry @example() {
%lhs0 = constant <f16: 0.0> : tile<f16>
%rhs0 = constant <f16: 0.0> : tile<f16>
// Custom form of scalar "ordered equal" comparison.
%x0 = cmpf equal ordered %lhs0, %rhs0 : tile<f16> -> tile<i1>
%lhs1 = constant <f16: 0.0> : tile<2x2xf16>
%rhs1 = constant <f16: 0.0> : tile<2x2xf16>
// Custom form of scalar "unordered less than" comparison.
%x2 = cmpf less_than unordered %lhs1, %rhs1 : tile<2x2xf16> -> tile<2x2xi1>
%lhs2 = constant <f64: 0.0> : tile<2x2xf64>
%rhs2 = constant <f64: 0.0> : tile<2x2xf64>
}
}
12.3.3. cuda_tile.cmpi_0#
cuda_tile.module @module {
entry @example() {
%lhs0 = constant <i16: 0> : tile<i16>
%rhs0 = constant <i16: 0> : tile<i16>
// Scalar "signed less than" comparison.
%x0 = cmpi less_than %lhs0, %rhs0, signed : tile<i16> -> tile<i1>
%lhs1 = constant <i64: 0> : tile<2x2xi64>
%rhs1 = constant <i64: 0> : tile<2x2xi64>
// Tile equality comparison.
// There is no difference between "signed" and "unsigned" when performing equality and inequality comparison.
%x1 = cmpi equal %lhs1, %rhs1, signed : tile<2x2xi64> -> tile<2x2xi1>
}
}
12.3.4. cuda_tile.constant_0#
cuda_tile.module @module {
entry @example() {
%c0 = constant <i32: 0> : tile<i32>
%c1 = constant <i64: 1> : tile<i64>
%c2 = constant <i32: [0, 1, 2, 3]> : tile<4xi32>
%c3 = constant <f32: 0.0> : tile<2x4xf32>
%c4 = constant <f64: [0.0, 1.0, 2.0, 3.0]> : tile<4xf64>
}
}
12.3.5. cuda_tile.extract_0#
cuda_tile.module @module {
entry @example() {
// Extract a subtile from %t at dim_0 = [4;8) and dim_1 = [4;6).
%c1 = constant <i32: 1> : tile<i32>
%c2 = constant <i32: 2> : tile<i32>
%t = constant <f32: 0.0> : tile<32x8xf32>
// Valid indices are: [ {0, 1, 2, 3, 4, 5, 6, 7}, {0, 1, 2, 3} ]
%0 = extract %t[%c1, %c2]
: tile<32x8xf32> -> tile<4x2xf32>
}
}
12.3.6. cuda_tile.get_global_0#
cuda_tile.module @module {
global @val <f32: [0.1, 0.2, 0.3, 0.4]> : tile<4xf32>
entry @example() {
%ptr = get_global @val : tile<ptr<f32>>
return
}
}
12.3.7. cuda_tile.get_num_tile_blocks_0#
cuda_tile.module @module {
entry @example() {
%x, %y, %z = get_num_tile_blocks : tile<i32>
// print "x: %, y: %, z: %\n", %x, %y, %z : tile<i32>, tile<i32>, tile<i32>
}
}
12.3.8. cuda_tile.global_0#
cuda_tile.module @module {
global @val alignment = 128 <f32: [0.1, 0.2, 0.3, 0.4]> : tile<4xf32>
entry @example() {}
}
12.3.9. cuda_tile.mmaf_0#
cuda_tile.module @module {
entry @example() {
%lhs0 = constant <f16: 0.0> : tile<4x8xf16>
%rhs0 = constant <f16: 0.0> : tile<8x2xf16>
%acc0 = constant <f32: 0.0> : tile<4x2xf32>
%0 = mmaf %lhs0, %rhs0, %acc0
: tile<4x8xf16>, tile<8x2xf16>,
tile<4x2xf32>
%lhs1 = constant <f16: 0.0> : tile<2x4x8xf16>
%rhs1 = constant <f16: 0.0> : tile<2x8x2xf16>
%acc1 = constant <f32: 0.0> : tile<2x4x2xf32>
%1 = mmaf %lhs1, %rhs1, %acc1
: tile<2x4x8xf16>, tile<2x8x2xf16>,
tile<2x4x2xf32>
}
}
12.3.10. cuda_tile.mmai_0#
cuda_tile.module @module {
entry @example() {
%lhs0 = cuda_tile.constant <i8: 0> : tile<4x8xi8>
%rhs0 = cuda_tile.constant <i8: 0> : tile<8x2xi8>
%acc0 = cuda_tile.constant <i32: 0> : tile<4x2xi32>
%0 = mmai %lhs0, %rhs0, %acc0 signed signed
: tile<4x8xi8>, tile<8x2xi8>,
tile<4x2xi32>
%lhs1 = cuda_tile.constant <i8: 0> : tile<2x4x8xi8>
%rhs1 = cuda_tile.constant <i8: 0> : tile<2x8x2xi8>
%acc1 = cuda_tile.constant <i32: 0> : tile<2x4x2xi32>
%1 = mmai %lhs1, %rhs1, %acc1 unsigned unsigned
: tile<2x4x8xi8>, tile<2x8x2xi8>,
tile<2x4x2xi32>
}
}
12.3.11. cuda_tile.pack_0#
cuda_tile.module @module {
entry @example() {
%arg0 = constant <f16: 0.0> : tile<64xf16>
%0 = pack %arg0 : tile<64xf16> -> tile<128xi8>
}
}
12.3.12. cuda_tile.pack_1#
cuda_tile.module @module {
entry @example() {
%arg0 = constant <f4E2M1FN: 0.0> : tile<64xf4E2M1FN>
%0 = pack %arg0 : tile<64xf4E2M1FN> -> tile<32xi8>
}
}
12.3.13. cuda_tile.permute_0#
cuda_tile.module @module {
entry @example() {
%arg0 = constant <f16: 0.0> : tile<2x4x8xf16>
%0 = permute %arg0 [2, 0, 1] : tile<2x4x8xf16> -> tile<8x2x4xf16>
}
}
12.3.14. cuda_tile.reduce_0#
cuda_tile.module @module {
entry @example() {
%input = constant <f32: 0.0> : tile<8xf32>
%0 = reduce %input dim=0 identities=[0.000000e+0 : f32] : tile<8xf32> -> tile<f32>
(%input_arg: tile<2xf32>, %input_accum: tile<f32>) {
%add_result = addf %input_arg, %input_accum : tile<f32>
yield %add_result : tile<f32>
}
}
}
12.3.15. cuda_tile.reduce_1#
cuda_tile.module @module {
entry @example() {
%input = constant <f32: 0.0> : tile<8x64xf32>
%0 = reduce %input dim=0 identities=[0.000000e+0 : f32] : tile<8x64xf32> -> tile<8xf32>
(%input_arg: tile<f32>, %input_accum: tile<f32>) {
%add_result = addf %input_arg, %input_accum : tile<f32>
yield %add_result : tile<f32>
}
}
}
12.3.16. cuda_tile.reshape_0#
cuda_tile.module @module {
entry @example() {
%cst = constant <i8: 0> : tile<i8>
%0 = reshape %cst
: tile<i8> -> tile<1x1x1xi8>
%t = constant <f32: 0.0> : tile<8x2xf32>
%1 = reshape %t
: tile<8x2xf32> -> tile<2x2x4x1xf32>
}
}
12.3.17. cuda_tile.reshape_1#
cuda_tile.module @module {
entry @example() {
%cst = constant <i32: [[0, 1, 2, 3], [4, 5, 6, 7]]>
: tile<2x4xi32>
%r0 = reshape %cst
: tile<2x4xi32> -> tile<2x2x2xi32>
// Step 1: Turn source into 1D tile. Use row-major by convention.
// %tmp: [0, 1, 2, 3, 4, 5, 6, 7]
%tmp = reshape %cst
: tile<2x4xi32> -> tile<8xi32>
// Step 2: Turn 1D tile into result tile. Use row-major by convention.
// %r: [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
%r1 = reshape %tmp
: tile<8xi32> -> tile<2x2x2xi32>
}
}
12.3.18. cuda_tile.scan_0#
cuda_tile.module @module {
entry @example() {
%input = constant <f32: 0.0> : tile<8x16xf32>
%result = scan %input dim=1 reverse=false identities=[1.0 : f32] : tile<8x16xf32> -> tile<8x16xf32>
(%acc: tile<f32>, %elem: tile<f32>) {
%prod = mulf %acc, %elem rounding<nearest_even>: tile<f32>
yield %prod : tile<f32>
}
}
}
12.3.19. cuda_tile.unpack_0#
cuda_tile.module @module {
entry @example() {
%arg0 = constant <i8: 0> : tile<64xi8>
%0 = unpack %arg0 : tile<64xi8> -> tile<32xf16>
}
}
12.3.20. cuda_tile.unpack_1#
cuda_tile.module @module {
entry @example() {
%arg0 = constant <i8: 0> : tile<64xi8>
%0 = unpack %arg0 : tile<64xi8> -> tile<128xF4E2M1FN>
}
}
12.3.21. cuda_tile.assert_0#
cuda_tile.module @module {
entry @example(%arg0: tile<i1>) {
assert %arg0, "assertion failed" : tile<i1>
}
}
12.3.22. cuda_tile.break_0#
cuda_tile.module @module {
entry @example() {
// Break from the body of a loop.
loop {
break
}
// Break from an if nested within the loop.
loop {
%condition = constant <i1: 1> : tile<i1>
if %condition {
break
}
// ...
}
%initValue0 = constant <f32: 0.0> : tile<f32>
// Break from an if nested within the loop, while yielding values.
%results = loop iter_values(%var0 = %initValue0): tile<f32> -> tile<f32> {
%condition = constant <i1: 1> : tile<i1>
if %condition {
// ...
yield
} else {
// %if.loopValue0 = ...
%loopValue0 = constant <f32: 1.0> : tile<f32>
break %loopValue0 : tile<f32>
}
%loopValue1 = constant <f32: 1.0> : tile<f32>
continue %loopValue1 : tile<f32>
}
}
}
12.3.23. cuda_tile.continue_0#
cuda_tile.module @module {
entry @example() {
%lowerBound = constant <i32: 0> : tile<i32>
%upperBound = constant <i32: 10> : tile<i32>
%step = constant <i32: 1> : tile<i32>
%condition = constant <i1: 1> : tile<i1>
// Continue from the body of a loop.
for %iv in (%lowerBound to %upperBound, step %step) : tile<i32> {
continue
}
// Continue from an if nested within the loop.
for %iv in (%lowerBound to %upperBound, step %step) : tile<i32> {
if %condition {
continue
}
// ...
}
// Continue from an if nested within the loop, while yielding values.
%initVar0 = constant <f32: 0.0> : tile<f32>
%results = for %iv in (%lowerBound to %upperBound, step %step) : tile<i32>
iter_values(%var0 = %initVar0) -> (tile<f32>)
{
if %condition {
// ...
yield
} else {
%loopValue0 = constant <f32: 1.0> : tile<f32>
continue %loopValue0 : tile<f32>
}
%loopValue1 = constant <f32: 1.0> : tile<f32>
continue %loopValue1 : tile<f32>
}
}
}
12.3.24. cuda_tile.for_0#
cuda_tile.module @module {
entry @example() {
%lowerBound = constant <i32: 0> : tile<i32>
%upperBound = constant <i32: 10> : tile<i32>
%step = constant <i32: 1> : tile<i32>
// A simple loop iterating over an i32 range.
for %iv in (%lowerBound to %upperBound, step %step) : tile<i32> {
continue
}
%initVal0 = constant <f32: 0.0> : tile<f32>
// A similar loop to the above, but with a loop carried value, val0.
%results = for %iv in (%lowerBound to %upperBound, step %step) : tile<i32>
iter_values(%val00 = %initVal0) -> (tile<f32>) {
%loopVal0 = constant <f32: 1.0> : tile<f32>
continue %loopVal0 : tile<f32>
}
}
}
12.3.25. cuda_tile.if_0#
cuda_tile.module @module {
entry @example() {
%condition = constant <i1: 1> : tile<i1>
// A simple if operation that conditionally executes a region.
if %condition {
// ...
}
// An if operation with an "else" branch.
if %condition {
// ...
} else {
// ...
}
// An if operation that returns mixed types (f32,i32)
%x, %y = if %condition -> (tile<f32>, tile<i32>) {
%x_then = constant <f32: 1.0> : tile<f32>
%y_then = constant <i32: 2> : tile<i32>
yield %x_then, %y_then : tile<f32>, tile<i32>
} else {
%x_then = constant <f32: 1.0> : tile<f32>
%y_then = constant <i32: 42> : tile<i32>
yield %x_then, %y_then : tile<f32>, tile<i32>
}
}
}
12.3.26. cuda_tile.loop_0#
cuda_tile.module @module {
entry @example() {
// A simple "while-do" loop.
loop {
%cond = constant <i1: 1> : tile<i1>
if %cond {
continue
}
break
}
}
}
12.3.27. cuda_tile.loop_1#
cuda_tile.module @module {
entry @example() {
// A simple "do-while" loop.
loop {
//... body of the loop.
%cond = constant <i1: 1> : tile<i1>
if %cond {
continue
}
break
}
}
}
12.3.28. cuda_tile.loop_2#
cuda_tile.module @module {
entry @example() {
%initValue0 = constant <f32: 0.0> : tile<f32>
// A loop that yields carried-iteration values, returning the final values.
%results = loop iter_values(%value0 = %initValue0) : tile<f32> -> tile<f32> {
%cond = constant <i1: 1> : tile<i1>
if %cond {
%loopValue0 = constant <f32: 0.0> : tile<f32>
continue %loopValue0 : tile<f32>
}
break %value0 : tile<f32>
}
}
}
12.3.29. cuda_tile.loop_3#
cuda_tile.module @module {
entry @example() {
%initValue0 = constant <i32: 0> : tile<i32>
// A loop that uses loop-carried values and returns a different type.
%results = loop iter_values(%value0 = %initValue0) : tile<i32> -> tile<f32> {
%cond = constant <i1: 1> : tile<i1>
if %cond {
%newLoopValue = constant <i32: 0> : tile<i32>
continue %newLoopValue : tile<i32>
}
%finalReturnValue = constant <f32: 0.0> : tile<f32>
break %finalReturnValue : tile<f32>
}
}
}
12.3.30. cuda_tile.return_0#
cuda_tile.module @module {
experimental$func @foo() -> (tile<i32>, tile<f16>) {
%0 = constant <i32: 0> : tile<i32>
%1 = constant <f16: 0.0> : tile<f16>
// ...
return %0, %1 : tile<i32>, tile<f16>
}
}
12.3.31. cuda_tile.return_1#
cuda_tile.module @module {
entry @foo() {
%0 = constant <i32: 0> : tile<i32>
%1 = constant <f16: 0.0> : tile<f16>
// ...
return
}
}
12.3.32. cuda_tile.yield_0#
cuda_tile.module @module {
entry @example() {
%condition = constant <i1: true> : tile<i1>
// Yield from the body of an if conditional.
if %condition {
yield
}
// Yield values from within an if conditional.
%x, %y = if %condition -> (tile<f32>, tile<f32>) {
%x_then = constant <f32: 0.0> : tile<f32>
%y_then = constant <f32: 1.0> : tile<f32>
yield %x_then, %y_then : tile<f32>, tile<f32>
} else {
%x_else = constant <f32: 2.0> : tile<f32>
%y_else = constant <f32: 3.0> : tile<f32>
yield %x_else, %y_else : tile<f32>, tile<f32>
}
}
}
12.3.33. cuda_tile.load_ptr_tko_0#
cuda_tile.module @module {
entry @example(%ptr: tile<ptr<f32>>) {
%mask = constant <i1: 1> : tile<i1>
%padding = constant <f32: 0.0> : tile<f32>
// Load without token.
%result0, %res_token0 = load_ptr_tko weak %ptr, %mask, %padding
: tile<ptr<f32>>, tile<i1>, tile<f32> -> tile<f32>, token
// Load with token.
%token0 = make_token : token
%result1, %res_token1 = load_ptr_tko weak %ptr, %mask, %padding token=%token0
: tile<ptr<f32>>, tile<i1>, tile<f32> -> tile<f32>, token
return
}
}
12.3.34. cuda_tile.atan2_0#
cuda_tile.module @ex_module {
entry @example_atan2() {
%x = constant <f32: [1.0, -1.0, 0.0, 2.0]> : tile<4xf32>
%y = constant <f32: [1.0, 1.0, 1.0, 0.0]> : tile<4xf32>
%res = atan2 %x, %y : tile<4xf32>
}
}
12.3.35. cuda_tile.ceil_0#
cuda_tile.module @module {
entry @example() {
%source = constant <f32: 0.5> : tile<f32>
%result = ceil %source : tile<f32>
}
}
12.3.36. cuda_tile.cos_0#
cuda_tile.module @ex_module {
entry @example_cos() {
%in = constant <f32: [0.0, 1.0, 2.0, 3.0]> : tile<4xf32>
%res = cos %in : tile<4xf32>
}
}
12.3.37. cuda_tile.exp2_0#
cuda_tile.module @ex_module {
entry @example_exp2() {
%in = constant <f32: [0.0, 1.0, 2.0, 3.0]> : tile<4xf32>
%res = exp2 %in : tile<4xf32>
}
}
12.3.38. cuda_tile.exp_0#
cuda_tile.module @ex_module {
entry @example_exp() {
%in = constant <f32: [0.0, 1.0, 2.0, 3.0]> : tile<4xf32>
%res = exp %in : tile<4xf32>
}
}
12.3.39. cuda_tile.floor_0#
cuda_tile.module @module {
entry @example() {
%source = constant <f32: 1.5> : tile<f32>
%result = floor %source : tile<f32>
}
}
12.3.40. cuda_tile.log2_0#
cuda_tile.module @ex_module {
entry @example_log2() {
%in = constant <f32: [0.0, 1.0, 2.0, 3.0]> : tile<4xf32>
%res = log2 %in : tile<4xf32>
}
}
12.3.41. cuda_tile.maxf_0#
cuda_tile.module @module {
entry @example_maxf(%arg0: tile<ptr<f32>>, %arg1: tile<ptr<f32>>) {
// Create tensor view from a pointer to global memory
%0 = make_tensor_view %arg0, shape = [2, 4], strides = [4, 1] : tensor_view<2x4xf32, strides=[4,1]>
%1 = make_tensor_view %arg1, shape = [2, 4], strides = [4, 1] : tensor_view<2x4xf32, strides=[4,1]>
// Convert tensor views to partition views and load tiles from partition views.
%p0 = make_partition_view %0 : partition_view<tile=(2x4), tensor_view<2x4xf32, strides=[4,1]>>
%p1 = make_partition_view %1 : partition_view<tile=(2x4), tensor_view<2x4xf32, strides=[4,1]>>
%c0 = constant <i32: 0> : tile<i32>
%2, %token0 = load_view_tko weak %p0[%c0, %c0] : partition_view<tile=(2x4), tensor_view<2x4xf32, strides=[4,1]>>, tile<i32> -> tile<2x4xf32>, token
%3, %token1 = load_view_tko weak %p1[%c0, %c0] : partition_view<tile=(2x4), tensor_view<2x4xf32, strides=[4,1]>>, tile<i32> -> tile<2x4xf32>, token
// IEEE 754-2019's maximum
%4 = maxf %2, %3 propagate_nan : tile<2x4xf32>
// IEEE 754-2019's maximumNumber
%5 = maxf %2, %3 : tile<2x4xf32>
// flush denormal to positive zero
%6 = maxf %2, %3 flush_to_zero : tile<2x4xf32>
}
}
12.3.42. cuda_tile.minf_0#
cuda_tile.module @module {
entry @example_minf(%arg0: tile<ptr<f32>>, %arg1: tile<ptr<f32>>) {
// Create tensor view from a pointer to global memory
%0 = make_tensor_view %arg0, shape = [2, 4], strides = [4, 1] : tensor_view<2x4xf32, strides=[4,1]>
%1 = make_tensor_view %arg1, shape = [2, 4], strides = [4, 1] : tensor_view<2x4xf32, strides=[4,1]>
// Convert tensor views to partition views and load tiles from partition views.
%p0 = make_partition_view %0 : partition_view<tile=(2x4), tensor_view<2x4xf32, strides=[4,1]>>
%p1 = make_partition_view %1 : partition_view<tile=(2x4), tensor_view<2x4xf32, strides=[4,1]>>
%c0 = constant <i32: 0> : tile<i32>
%2, %token0 = load_view_tko weak %p0[%c0, %c0] : partition_view<tile=(2x4), tensor_view<2x4xf32, strides=[4,1]>>, tile<i32> -> tile<2x4xf32>, token
%3, %token1 = load_view_tko weak %p1[%c0, %c0] : partition_view<tile=(2x4), tensor_view<2x4xf32, strides=[4,1]>>, tile<i32> -> tile<2x4xf32>, token
// IEEE 754-2019's minimum
%4 = minf %2, %3 propagate_nan : tile<2x4xf32>
// IEEE 754-2019's minimumNumber
%5 = minf %2, %3 : tile<2x4xf32>
// flush denormal to positive zero
%6 = minf %2, %3 flush_to_zero : tile<2x4xf32>
}
}
12.3.43. cuda_tile.negf_0#
cuda_tile.module @module {
entry @example() {
%source = constant <f32: 0.0> : tile<4xf32>
%result = negf %source : tile<4xf32>
}
}
12.3.44. cuda_tile.pow_0#
cuda_tile.module @module {
entry @example() {
%source = constant <f32: 0.0> : tile<4xf32>
%exponent = constant <f32: 2.0> : tile<4xf32>
%result = pow %source, %exponent : tile<4xf32>
}
}
12.3.45. cuda_tile.rsqrt_0#
cuda_tile.module @ex_module {
entry @example_rsqrt() {
%in = constant <f32: [0.0, 1.0, 2.0, 3.0]> : tile<4xf32>
%res = rsqrt %in : tile<4xf32>
// Rsqrt op with flush to zero modifier
%ftz_res = rsqrt %in flush_to_zero : tile<4xf32>
}
}
12.3.46. cuda_tile.sin_0#
cuda_tile.module @ex_module {
entry @example_sin() {
%in = constant <f32: [0.0, 1.0, 2.0, 3.0]> : tile<4xf32>
%res = sin %in : tile<4xf32>
}
}
12.3.47. cuda_tile.tanh_0#
cuda_tile.module @ex_module {
entry @example_tanh() {
%in = constant <f32: [0.0, 1.0, 2.0, 3.0]> : tile<4xf32>
%res0 = tanh %in : tile<4xf32>
// tanh with approx modifier
%res1 = tanh %in rounding<approx> : tile<4xf32>
}
}
12.3.48. cuda_tile.maxi_0#
cuda_tile.module @module {
entry @example_maxi(%arg0: tile<ptr<i32>>, %arg1: tile<ptr<i32>>) {
// Create tensor view from a pointer to global memory
%0 = make_tensor_view %arg0, shape = [2, 4], strides = [4, 1] : tensor_view<2x4xi32, strides=[4,1]>
%1 = make_tensor_view %arg1, shape = [2, 4], strides = [4, 1] : tensor_view<2x4xi32, strides=[4,1]>
// Convert tensor views to partition views and load tiles from them.
%p0 = make_partition_view %0 : partition_view<tile=(2x4), tensor_view<2x4xi32, strides=[4,1]>>
%p1 = make_partition_view %1 : partition_view<tile=(2x4), tensor_view<2x4xi32, strides=[4,1]>>
%c0 = constant <i32: 0> : tile<i32>
%2, %token0 = load_view_tko weak %p0[%c0, %c0] : partition_view<tile=(2x4), tensor_view<2x4xi32, strides=[4,1]>>, tile<i32> -> tile<2x4xi32>, token
%3, %token1 = load_view_tko weak %p1[%c0, %c0] : partition_view<tile=(2x4), tensor_view<2x4xi32, strides=[4,1]>>, tile<i32> -> tile<2x4xi32>, token
// Signless i32 treated as unsigned
%4 = maxi %2, %3 unsigned : tile<2x4xi32>
// Signless i32 treated as signed
%5 = maxi %2, %3 signed : tile<2x4xi32>
}
}
12.3.49. cuda_tile.mini_0#
cuda_tile.module @module {
entry @example_mini(%arg0: tile<ptr<i32>>, %arg1: tile<ptr<i32>>) {
// Create tensor view from a pointer to global memory
%0 = make_tensor_view %arg0, shape = [2, 4], strides = [4, 1] : tensor_view<2x4xi32, strides=[4,1]>
%1 = make_tensor_view %arg1, shape = [2, 4], strides = [4, 1] : tensor_view<2x4xi32, strides=[4,1]>
// Convert tensor views to partition views and load tiles from partition views.
%p0 = make_partition_view %0 : partition_view<tile=(2x4), tensor_view<2x4xi32, strides=[4,1]>>
%p1 = make_partition_view %1 : partition_view<tile=(2x4), tensor_view<2x4xi32, strides=[4,1]>>
%c0 = constant <i32: 0> : tile<i32>
%2, %token0 = load_view_tko weak %p0[%c0, %c0] : partition_view<tile=(2x4), tensor_view<2x4xi32, strides=[4,1]>>, tile<i32> -> tile<2x4xi32>, token
%3, %token1 = load_view_tko weak %p1[%c0, %c0] : partition_view<tile=(2x4), tensor_view<2x4xi32, strides=[4,1]>>, tile<i32> -> tile<2x4xi32>, token
// Signless i32 treated as unsigned
%4 = mini %2, %3 unsigned : tile<2x4xi32>
// Signless i32 treated as signed
%5 = mini %2, %3 signed : tile<2x4xi32>
}
}
12.3.50. cuda_tile.mulhii_0#
cuda_tile.module @module {
entry @example() {
// 2^31 * 2 = 2^32, or 0x100000000.
// The most significant 32 bits of the product are 0x00000001.
// The lower 32 bits of the product are 0x00000000.
%a = constant <i32: 2147483648> : tile<i32> // %a = 2^31
%b = constant <i32: 2> : tile<i32> // %b = 2
%res_hi = mulhii %a, %b : tile<i32> // %res_hi = 1
%res_lo = muli %a, %b : tile<i32> // %res_lo = 0
}
}
12.3.51. cuda_tile.negi_0#
cuda_tile.module @module {
entry @example() {
%source = constant <i16: [0, 1, 2, 3]> : tile<4xi16>
%result = negi %source : tile<4xi16>
// %result = [0, -1, -2, -3]
}
}
12.3.52. cuda_tile.xori_0#
cuda_tile.module @module {
entry @example() {
%lhs = constant <i32: [0, 1, 2, 3]> : tile<4xi32>
%rhs = constant <i32: [4, 5, 6, 7]> : tile<4xi32>
// This computes the bitwise XOR of each element in `%lhs` and `%rhs`, which
// are tiles of shape `4xi32`, and returns the result as `%result`.
%result = xori %lhs, %rhs : tile<4xi32>
}
}
12.3.53. cuda_tile.atomic_cas_tko_0#
cuda_tile.module @ex_module {
entry @example(%ptr: tile<ptr<i32>>) {
%ptr_1x = reshape %ptr : tile<ptr<i32>> -> tile<1xptr<i32>>
%ptr_vec = broadcast %ptr_1x : tile<1xptr<i32>> -> tile<8xptr<i32>>
%offsets = iota : tile<8xi32>
%ptrs = offset %ptr_vec, %offsets : tile<8xptr<i32>>, tile<8xi32> -> tile<8xptr<i32>>
%cmp = constant <i32: [0, 1, 2, 3, 4, 5, 6, 7]> : tile<8xi32>
%val = constant <i32: [7, 6, 5, 4, 3, 2, 1, 0]> : tile<8xi32>
%mask = constant <i1: [0, 1, 0, 1, 0, 1, 0, 1]> : tile<8xi1>
// Atomic CAS without input token.
%0, %token = atomic_cas_tko relaxed device %ptrs, %cmp, %val :
tile<8xptr<i32>>, tile<8xi32> -> tile<8xi32>, token
// Atomic CAS without input token.
%1, %token1 = atomic_cas_tko relaxed device %ptrs, %cmp, %val, %mask :
tile<8xptr<i32>>, tile<8xi32>, tile<8xi1> -> tile<8xi32>, token
// Atomic CAS with input token.
%token2 = make_token : token
%2, %token3 = atomic_cas_tko relaxed device %ptrs, %cmp, %val token=%token2 :
tile<8xptr<i32>>, tile<8xi32> -> tile<8xi32>, token
return
}
}
12.3.54. cuda_tile.atomic_rmw_tko_0#
cuda_tile.module @ex_module {
entry @example_rmw(%ptr: tile<ptr<f32>>) {
// Reshape the input pointer tile to have a 1d shape
%ptr_1x = reshape %ptr : tile<ptr<f32>> -> tile<1xptr<f32>>
// Broadcast the reshaped tile to a tile with 8 rows, effectively replicating the pointer 8 times
%ptr_vec = broadcast %ptr_1x : tile<1xptr<f32>> -> tile<8xptr<f32>>
// Create a tile of offsets [0, 1, 2, ..., 7] to index into memory
%offsets = iota : tile<8xi32>
// Add the offsets to each pointer in the vector to create 8 unique pointers
%ptrs = offset %ptr_vec, %offsets : tile<8xptr<f32>>, tile<8xi32> -> tile<8xptr<f32>>
%vals = constant <f32: [7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0]> : tile<8xf32>
// Perform atomic addf operations on the memory locations pointed by %ptrs
// without requiring an input token. Returns the original values and a result token
%0, %res_token0 = atomic_rmw_tko relaxed device %ptrs, addf, %vals :
tile<8xptr<f32>>, tile<8xf32> -> tile<8xf32>, token
// Perform atomic add operations again, this time using the explicit input token
%token = make_token : token
%1, %res_token1 = atomic_rmw_tko relaxed device %ptrs, addf, %vals, token = %token :
tile<8xptr<f32>>, tile<8xf32> -> tile<8xf32>, token
}
}
12.3.55. cuda_tile.get_index_space_shape_0#
cuda_tile.module @module {
entry @example(%base: tile<ptr<f32>>) {
%tensor_view = make_tensor_view %base,
shape = [2, 2, 4], strides = [2, 2, 1]
: tensor_view<2x2x4xf32, strides=[2,2,1]>
%partition_view = make_partition_view %tensor_view :
partition_view<tile=(2x2x4), tensor_view<2x2x4xf32, strides=[2,2,1]>>
%dim0, %dim1, %dim2 = get_index_space_shape %partition_view :
partition_view<tile=(2x2x4), tensor_view<2x2x4xf32, strides=[2,2,1]>> -> tile<i64>
}
}
12.3.56. cuda_tile.get_tensor_shape_0#
cuda_tile.module @module {
entry @example(%base: tile<ptr<f32>>) {
%tensor_view = make_tensor_view %base,
shape = [32, 32], strides = [32, 1]
: tensor_view<32x32xf32, strides=[32,1]>
%dim0, %dim1 = get_tensor_shape %tensor_view : tensor_view<32x32xf32, strides=[32,1]> -> tile<i64>
}
}
12.3.57. cuda_tile.load_view_tko_0#
cuda_tile.module @module {
entry @example(%ptr: tile<ptr<f32>>, %index: tile<i32>) {
%tensor_view = make_tensor_view %ptr, shape=[8192, 128], strides=[128, 1]
: tensor_view<8192x128xf32, strides=[128,1]>
// This example uses the PartitionView on a 8192x128xf32 tensor_view,
// dividing the tensor_view in tiles of 64x64.
%view = make_partition_view %tensor_view : partition_view<tile=(64x64), tensor_view<8192x128xf32, strides=[128,1]>>
%c0 = constant <i32: 0> : tile<i32>
%c1 = constant <i32: 1> : tile<i32>
// Load a tile at index (0, 0) in the view's index space.
// For this PartitionView, this is the rectangular tile such that
// X=[0,64) and Y=[0,64), in the coordinates of tiles.
%tile0, %res_token0 = load_view_tko weak %view[%c0, %c0]
: partition_view<tile=(64x64), tensor_view<8192x128xf32, strides=[128,1]>>, tile<i32> -> tile<64x64xf32>, token
// Load a tile at index (0, 1) in the view's index space.
// For this PartitionView, this is the rectangular tile such that
// X=[0,64) and Y=[64,128), in the coordinates of tiles.
%tile1, %res_token1 = load_view_tko weak %view[%c0, %c1]
: partition_view<tile=(64x64), tensor_view<8192x128xf32, strides=[128,1]>>, tile<i32> -> tile<64x64xf32>, token
// Same example as above but with memory token as input.
%token = make_token : token
%tile2, %res_token2 = load_view_tko weak %view[%c0, %c1] token = %token
: partition_view<tile=(64x64), tensor_view<8192x128xf32, strides=[128,1]>>, tile<i32> -> tile<64x64xf32>, token
// Loads a tile at the dynamic index (%index, %index) in the view's index space.
%tile3, %res_token3 = load_view_tko weak %view[%index, %index]
: partition_view<tile=(64x64), tensor_view<8192x128xf32, strides=[128,1]>>, tile<i32> -> tile<64x64xf32>, token
}
}
12.3.58. cuda_tile.make_partition_view_0#
cuda_tile.module @module {
entry @example(%ptr: tile<ptr<f32>>) {
%tensor_view0 = make_tensor_view %ptr, shape=[8192, 8192, 64], strides=[524288,64,1]
: tensor_view<8192x8192x64xf32, strides=[524288,64,1]>
// Creates a partition with 32-bit-indexed tiles of size (1024x1x32) over
// the provided tensor_view.
make_partition_view %tensor_view0 :
partition_view<
tile=(1024x1x32),
tensor_view<8192x8192x64xf32, strides=[524288,64,1]>
>
%s0 = constant <i32: 8192> : tile<i32>
%str0 = constant <i32: 524288> : tile<i32>
// These seems very wrong.
%tensor_view1 = make_tensor_view %ptr, shape=[%s0, 8192, 64], strides=[%str0, 64, 1]
: tile<i32> -> tensor_view<?x8192x64xf32, strides=[?,64,1]>
// Creates a partition with 32-bit-indexed tiles of size (1024x1x32) over
// the provided tensor_view, with masking. The provided tensor_view has a
// dynamically-sized dimension.
make_partition_view %tensor_view1 :
partition_view<tile=(1024x1x32), tensor_view<?x8192x64xf32, strides=[?,64,1]>>
}
}
12.3.59. cuda_tile.make_tensor_view_0#
cuda_tile.module @module {
entry @example(%base: tile<ptr<f32>>) {
// tensor_view to a scalar tile of f32
%a0 = make_tensor_view %base,
shape = [], strides = [] : tensor_view<f32>
// tensor_view to a tile of static shape and strides
%a1 = make_tensor_view %base,
shape = [32, 32], strides = [32, 1]
: tensor_view<32x32xf32, strides=[32,1]>
%sh0 = constant <i32: 32> : tile<i32>
%sh1 = constant <i32: 32> : tile<i32>
%st0 = constant <i32: 32> : tile<i32>
%st1 = constant <i32: 1> : tile<i32>
// tensor_view to a tile with partially dynamic shape and strides
// all dynamic values must be of the same type, here tile<i32>
%a2 = make_tensor_view %base,
shape = [%sh0, %sh1], strides = [%st0, %st1]
: tile<i32> -> tensor_view<?x?xf32, strides=[?,?]>
}
}
12.3.60. cuda_tile.store_view_tko_0#
cuda_tile.module @module {
entry @example(%ptr: tile<ptr<f32>>) {
%tensor_view = make_tensor_view %ptr, shape=[8192, 128], strides=[128,1] :
tensor_view<8192x128xf32, strides=[128,1]>
// This example uses the PartitionView on a 8192x128xf32 tensor_view,
// dividing the tensor_view in tiles of 64x64.
%view = make_partition_view %tensor_view :
partition_view<tile=(64x64), tensor_view<8192x128xf32, strides=[128,1]>>
%c0 = constant <i32: 0> : tile<i32>
%c1 = constant <i32: 1> : tile<i32>
%tile = constant <f32: 0.0> : tile<64x64xf32>
// Store a tile at index (0, 0) in the view's index space.
// For this TilePartitionView, this is the rectangular tile such that
// X=[0,64) and Y=[0,64), in the coordinates of tiles.
%res_token0 = store_view_tko weak %tile, %view[%c0, %c0]
: tile<64x64xf32>, partition_view<tile=(64x64), tensor_view<8192x128xf32, strides=[128,1]>>, tile<i32> -> token
// Store a tile at index (0, 1) in the view's index space.
// For this PartitionView, this is the rectangular tile such that
// X=[0,64) and Y=[64,128), in the coordinates of tiles.
%res_token1 = store_view_tko weak %tile, %view[%c0, %c1]
: tile<64x64xf32>, partition_view<tile=(64x64), tensor_view<8192x128xf32, strides=[128,1]>>, tile<i32> -> token
// Same example as above but with input token.
%token = make_token : token
%res_token2 = store_view_tko weak %tile, %view[%c0, %c1] token = %token
: tile<64x64xf32>, partition_view<tile=(64x64), tensor_view<8192x128xf32, strides=[128,1]>>, tile<i32> -> token
}
}
12.3.61. cuda_tile.assume_example_Bounded_0#
%1 = cuda_tile.assume #cuda_tile.bounded<0, ?>, %0
: !cuda_tile.tile<4x8xi16>
12.3.62. cuda_tile.assume_example_DivBy_0#
// Example 1: Each pointer is divisible by 16.
// [ 0x10, 0x20, 0x80, 0x10, 0x0, 0x120, ... ]
%0 = cuda_tile.assume #cuda_tile.div_by<16>, %ptrs
: !cuda_tile.tile<128x!cuda_tile.ptr<f32>>
// Note: Equivalent to #cuda_tile.div_by<16, every 1 along 0>.
12.3.63. cuda_tile.assume_example_DivBy_1#
// Example 2: Each integer is divisible by 4.
// [ 16, 24, 8, 4, 12, 12, 0, 16, ... ]
%0 = cuda_tile.assume #cuda_tile.div_by<4>, %t
: !cuda_tile.tile<128xi32>
12.3.64. cuda_tile.assume_example_DivBy_2#
// Example 3: Group size [4].
// [7, 8, 9, 10, 23, 24, 25, 26, 0, 1, 2, 3, ...]
%0 = cuda_tile.assume #cuda_tile.div_by<1, every 4 along 0>, %t
: !cuda_tile.tile<128xi32>
12.3.65. cuda_tile.assume_example_DivBy_3#
// Example 4: 2-d Group size [1, 4] with divisibility 4.
// [ [ 4, 5, 6, 7, 12, 13, 14, 15 ],
// [ 8, 9, 10, 11, 24, 25, 26, 27 ],
// [ 24, 25, 26, 27, 64, 65, 66, 67 ],
// [ 0, 1, 2, 3, 4, 5, 6, 7 ] ]
%0 = cuda_tile.assume #cuda_tile.div_by<4, every 4 along 1>, %t
: !cuda_tile.tile<4x8xi32>
12.3.66. cuda_tile.assume_example_DivBy_4#
// Example 5: 2-d Group size [4, 1] with divisibility 32.
// Note that the elements within each column are monotonically increasing
// by the byte width of the pointee type f32, e.g., 0x20, 0x24, 0x28, 0x2c.
// [ [ 0x20, 0x100, 0x40, 0x60, 0x40, 0x200, 0x340, 0x40 ],
// [ 0x24, 0x104, 0x44, 0x64, 0x44, 0x204, 0x344, 0x44 ],
// [ 0x28, 0x108, 0x48, 0x68, 0x48, 0x208, 0x348, 0x48 ],
// [ 0x2c, 0x10c, 0x4c, 0x6c, 0x4c, 0x20c, 0x34c, 0x4c ] ]
%0 = cuda_tile.assume #cuda_tile.div_by<32, every 4 along 0>, %ptrs
: !cuda_tile.tile<4x8x!cuda_tile.ptr<f32>>
12.3.67. cuda_tile.assume_example_SameElements_0#
// Integer tensor with same elements.
%0 = cuda_tile.constant <i16: [[0, 0, 0, 0, 10, 10, 10, 10],
[0, 0, 0, 0, 10, 10, 10, 10],
[5, 5, 5, 5, 93, 93, 93, 93],
[5, 5, 5, 5, 93, 93, 93, 93]]>
: tile<4x8xi16>
%1 = cuda_tile.assume #cuda_tile.same_elements<[2, 4]>, %0
: !cuda_tile.tile<4x8xi16>
// Pointer tensor with same elements.
%2 = cuda_tile.constant <i64: [[ 0, 0, 0, 0, 8, 8, 8, 8],
[ 0, 0, 0, 0, 8, 8, 8, 8],
[64, 64, 64, 64, 32, 32, 32, 32],
[64, 64, 64, 64, 32, 32, 32, 32]]>
: tile<4x8xi64>
%3 = cuda_tile.bitcast %2
: !cuda_tile.tile<4x8xi64>
-> !cuda_tile.tile<!cuda_tile.ptr<f32>>
%4 = cuda_tile.assume #cuda_tile.same_elements<[2, 4]>, %3
: !cuda_tile.tile<!cuda_tile.ptr<f32>>
12.3.68. cuda_tile.assume_0#
cuda_tile.module @module {
entry @example(%input: tile<ptr<f32>>) {
// Assume that all integers are divisible by 32.
%int_tile = constant <i16: [32, 64, 0, 0, 32, -32, 1024, 0]> : tile<8xi16>
%div_by_1 = assume div_by<32>, %int_tile : tile<8xi16>
// Assume that every 4th element (starting with element 0) along
// dimension 0 is divisible by 32 that and all integers are
// montonically increasing by 1 within each group of 4.
%int_tile_2 = constant <i16: [96, 97, 98, 99, 64, 65, 66, 67]> : tile<8xi16>
%div_by_2 = assume div_by<32, every 4 along 0>, %int_tile_2 : tile<8xi16>
// Assume that every rectangular chunk of size [1, 4, 2] has the same
// values.
%input_rank3 = reshape %input : tile<ptr<f32>> -> tile<1x1x1xptr<f32>>
%ptr_3d = broadcast %input_rank3 : tile<1x1x1xptr<f32>> -> tile<1x8x8xptr<f32>>
%same_elem = assume same_elements<[1, 4, 2]>, %ptr_3d : tile<1x8x8xptr<f32>>
// Assume that every value is greater or equal to 5.
%int_tile_3 = constant <i16: [5, 9, 10, 11, 6, 5, 5, 7]> : tile<8xi16>
%bounded = assume bounded<5, ?>, %int_tile_3 : tile<8xi16>
}
}
12.3.69. cuda_tile.print_tko_0#
cuda_tile.module @module {
entry @example() {
%arg = constant <f32: 0.0> : tile<4xf32>
print_tko "Hello world: %f\n", %arg : tile<4xf32> -> token
print_tko "%+08.3f", %arg : tile<4xf32> -> token
}
}
12.3.70. cuda_tile.experimental$alloca_0#
cuda_tile.module @ex_module {
entry @free_kernel() {
%0 = experimental$alloca num_elem = 64, alignment = 128 : tile<ptr<f32>>
%1 = experimental$alloca num_elem = 64, alignment = 128 global : tile<ptr<f32>>
}
}
12.3.71. cuda_tile.experimental$asin_0#
cuda_tile.module @ex_module {
entry @example_asin() {
%in = constant <f32: [0.0, 1.0, 2.0, 3.0]> : tile<4xf32>
// %res = cuda_tile.experimental.asin %in : tile<4xf32>
}
}
12.3.72. cuda_tile.experimental$call_0#
cuda_tile.module @my_module {
experimental$func @my_add(%0: !cuda_tile.tile<f32>, %1: !cuda_tile.tile<f32>) -> !cuda_tile.tile<f32> {
%2 = addf %0, %1 rounding<zero> : !cuda_tile.tile<f32>
return %2 : !cuda_tile.tile<f32>
}
entry @my_entry() {
%0 = cuda_tile.constant <f32: 1.0> : !cuda_tile.tile<f32>
%1 = cuda_tile.constant <f32: 1.0> : !cuda_tile.tile<f32>
%2 = cuda_tile.experimental$call @my_add(%0, %1)
: (!cuda_tile.tile<f32>, !cuda_tile.tile<f32>) -> !cuda_tile.tile<f32>
}
}
12.3.73. cuda_tile.experimental$create_queue_0#
// %0 = cuda_tile.experimental.create_queue : !cuda_tile.queue<[...], depth=3>
12.3.74. cuda_tile.experimental$extern_elementwise_0#
cuda_tile.module @module {
entry @example() {
%arg = constant <f32: [0.0, 1.0, 2.0, 3.0]> : tile<4xf32>
%0 = experimental$extern_elementwise %arg {libname = "libname",
libpath = "/path/to/lib", pure = true, symbol = "symbol"} : (!cuda_tile.tile<4xf32>) -> !cuda_tile.tile<4xf32>
}
}
12.3.75. cuda_tile.experimental$gather_load_0#
// %res = "cuda_tile.experimental$gather_load"(%src, %index, %offset)
// <{dim = 0 : i64, fillValue = #cuda_tile<fill_value zero>}> :
// (!cuda_tile.tensor_view<128x1024xf32, strides=[1024,1]>,
// !cuda_tile.tile<4xi32>, i32) -> !cuda_tile.tile<4x64xf32>
12.3.76. cuda_tile.experimental$generate_0#
cuda_tile.module @my_module {
entry @my_entry() {
%0 = cuda_tile.experimental$generate {
^bb0(%arg0: !cuda_tile.tile<i32>, %arg1: !cuda_tile.tile<i32>):
%1 = cuda_tile.addi %arg0, %arg1 : !cuda_tile.tile<i32>
cuda_tile.yield %1 : !cuda_tile.tile<i32>
} : !cuda_tile.tile<16x16xi32>
}
}
12.3.77. cuda_tile.experimental$scatter_store_0#
cuda_tile.module @module {
entry @example(%out_ptr: tile<ptr<f32>>) {
%src = constant <f32: 0.0> : tile<8x32xf32>
%dst = make_tensor_view %out_ptr, shape=[64, 128], strides=[128,1] : tensor_view<64x128xf32, strides=[128,1]>
%index = constant <i32: [10, 11, 12, 13, 17, 18, 19, 20]> : tile<8xi32>
%offset = constant <i32: 24> : tile<i32>
%token = make_token : token
%result_token = experimental$scatter_store %src, %dst, %index, [%offset] token = %token <{dim = 0 : i64}> :
tile<8x32xf32>, tensor_view<64x128xf32, strides=[128,1]>, tile<8xi32>, tile<i32>, token -> token
}
}
12.3.78. cuda_tile.internal$optimization_barrier_0#
cuda_tile.module @my_module {
entry @my_entry() {
%0 = cuda_tile.constant <i8: 1> : !cuda_tile.tile<i8>
// Prevent two reshapes from folding away.
%1 = cuda_tile.reshape %0
: !cuda_tile.tile<i8> -> !cuda_tile.tile<1xi8>
%2 = cuda_tile.internal$optimization_barrier %1 : !cuda_tile.tile<1xi8>
%3 = cuda_tile.reshape %2
: !cuda_tile.tile<1xi8> -> !cuda_tile.tile<i8>
// Remove attached or inferred axis analysis information.
// divisibility = [int_max], contiguity = [16]
%4 = cuda_tile.iota : !cuda_tile.tile<16xi32>
// divisibility = [1], contiguity = [1]
%5 = cuda_tile.internal$optimization_barrier %4, keep_axis_info
: !cuda_tile.tile<16xi32>
}
}