Introduction#
Overview#
CuTe DSL is a Python-based domain-specific language (DSL) designed for dynamic compilation of numeric and GPU-oriented code. Its primary goals are:
Consistent with CuTe C++, allowing users to express GPU kernels with full control of the hardware.
JIT compilation for both host and GPU execution.
DLPack integration, enabling seamless interop with frameworks (e.g., PyTorch, JAX).
JIT caching, so that repeated calls to the same function benefit from cached IR modules.
Native types and type inference to reduce boilerplate and improve performance.
Optional lower-level control, offering direct access to GPU backends or specialized IR dialects.
Decorators#
CuTe DSL provides two main Python decorators for generating optimized code via dynamic compilation:
@jit
— Host-side JIT-compiled functions@kernel
— GPU kernel functions
Both decorators can optionally use a preprocessor that automatically expands Python control flow (loops, conditionals) into operations consumable by the underlying IR.
@jit
#
Declares JIT-compiled functions that can be invoked from Python or from other CuTe DSL functions.
Decorator Parameters:
preprocessor
:True
(default) — Automatically translate Python flow control (e.g., loops, if-statements) into IR operations.False
— No automatic expansion; Python flow control must be handled manually or avoided.
Call-site Parameters:
no_cache
:True
— Disables JIT caching, forcing a fresh compilation each call.False
(default) — Enables caching for faster subsequent calls.
@kernel
#
Defines GPU kernel functions, compiled as specialized GPU symbols through dynamic compilation.
Decorator Parameters:
preprocessor
:True
(default) — Automatically expands Python loops/ifs into GPU-compatible IR operations.False
— Expects manual or simplified kernel implementations.
Kernel Launch Parameters:
grid
Specifies the grid size as a list of integers.block
Specifies the block size as a list of integers.cluster
Specifies the cluster size as a list of integers.smem
Specifies the size of shared memory in bytes (integer).
Calling Conventions#
Caller |
Callee |
Allowed |
Compilation/Runtime |
---|---|---|---|
Python function |
|
✅ |
DSL runtime |
Python function |
|
❌ |
N/A (error raised) |
|
|
✅ |
Compile-time call, inlined |
|
Python function |
✅ |
Compile-time call, inlined |
|
|
✅ |
Dynamic call via GPU driver or runtime |
|
|
✅ |
Compile-time call, inlined |
|
Python function |
✅ |
Compile-time call, inlined |
|
|
❌ |
N/A (error raised) |