Integration with Frameworks#

In order to facilitate the integration of CUTLASS Python with popular frameworks, we leverage the DLPack protocol and transform tensors originating from these frameworks to CuTe tensors. The present page documents the conventions, the API available to the user, and provide example code snippets for common usage patterns.

Implicit Conversion#

Tensors originating from frameworks supporting the DLPack protocol can be directly provided to a JIT function as a regular parameter. CuTe DSL’s runtime implicitly converts the original tensor to a CuTe tensor with a fully dynamic layout except for the stride element corresponding to the leading dimension. The example below demonstrates this use case.

import torch
import cutlass.cute as cute

@cute.jit
def foo(src):
    """
    The following lines print

    ptr<f32, generic> o (?,?,?):(?,?,1)
    <class 'cutlass.cute.core._Tensor'>
    """
    print(src)
    print(type(src))

a = torch.randn(30, 20, 32, device="cpu")
foo(a)

Explicit conversion using from_dlpack#

CuTe DSL’s runtime provides an interface for converting DLPack-compatible tensors to CuTe tensors,

b = cute.runtime.from_dlpack(a)

where a is a tensor supporting the DLPack protocol with the __dlpack__ and __dlpack_device__ methods. The resulting CuTe tensor b has a fully static layout. This conversion is performed without copying any tensor data, enabling seamless integration with major frameworks. Users can create tensors using NumPy, PyTorch, etc. and directly feed them into JIT functions writtnen using CuTe DSL.

The resulting CuTe tensor shares the same underlying memory buffer as the original tensor. This zero-copy approach maximizes performance by eliminating unnecessary data duplication. However, it is important to note that the CuTe tensor’s validity is tied to the lifetime of the original tensor. If the source tensor is destroyed or goes out of scope, the corresponding CuTe tensor becomes invalid since it references the original memory location.

The full signature of from_dlpack is as follows:

def from_dlpack(tensor, assumed_align=None):

The assumed_align integer parameter specifies the alignment of the tensor in unit of bytes. The tensor’s base address must be divisible by assumed_align. When not provided explicitly, the alignment is set to the natural alignment of the tensor’s element type. Note that the alignment information is part of the pointer type in the generated IR. Therefore, programs with different alignments have a different IR and identical IRs are required for hitting the kernel caching mechanism of CuTe DSL.

Code Example#

The following code demonstrates how to convert a PyTorch tensor to a CuTe tensor using the from_dlpack function with default parameters.

import torch
import cutlass
from cutlass.cute.runtime import from_dlpack

x = torch.randn(30, 20, device="cpu")
y = from_dlpack(x)

Once converted, we can access the tensor’s information through various attributes. The following list shows the attributes of the converted tensor:

  • tensor.shape: the tensor’s shape

  • tensor.stride: the tensor’s stride

  • tensor.memspace: the tensor’s memory space

  • tensor.element_type: the tensor’s element data type

import torch
import cutlass
from cutlass.cute.runtime import from_dlpack

x = torch.randn(30, 20, device="cpu")
y = from_dlpack(x)

print(y.shape)        # (30, 20)
print(y.stride)       # (20, 1)
print(y.memspace)     # generic (if torch tensor in on device memory, memspace will be gmem)
print(y.element_type) # Float32
print(y)              # Tensor<0x000000000875f580@generic o (30, 20):(20, 1)>

The string format of the resulting CuTe tensor is

Tensor<0x{tensor.data_ptr:016x}@{tensor.memspace} o {tensor.shape}:{tensor.stride}>

As can be seen in the example above, from_dlpack first results in a tensor with a static layout. To obtain dynamic or mixed static/dynamic layouts after calling from_dlpack, the mark_layout_dynamic and mark_compact_shape_dynamic functions are used and described in the following sections.

When to Use Explicit Conversion?#

The DLPack protocol is a widely used protocol for interoperability between different frameworks. However, there is some associated overhead. Based on our benchmark, it usually takes between 2 to 3 us per call to from_dlpack.

Explicit conversion allows for caching the converted CuTe tensors in order to avoid the overhead of repeated calls to from_dlpack.

x = torch.randn(30, 20, device="cpu")
if key not in cached_tensors:
    # Do the conversion only for cache misses
    cached_tensors[key] = cute.runtime.from_dlpack(x)
foo(cached_tensors[key])

Another use case for explicit conversion is to gain fine-grain control over which modes of a tensor are considered dynamic from the perspective of the generated program.

Mark the Tensor’s Layout as Dynamic with mark_layout_dynamic#

After calling this function, all shape modes become dynamic. The stride modes also become dynamic with the following two exceptions:

  1. the leading dimension’s stride remains fixed at 1;

  2. stride elements equal to 0 (which indicates broadcasting) are retained.

The full signature of mark_layout_dynamic is as follows:

def mark_layout_dynamic(self, leading_dim: int|None = None):

The leading_dim parameter specifies the leading dimension of the tensor. The leading dimension’s stride is set to 1 unless inconsistent with the layout of the DLPack tensor. For example,

  • For a tensor with layout (2,2,3,4):(2,1,4,12), if leading_dim is specified to be 1, the layout will be marked as (?,?,?,?):(?,1,?,?).

  • If leading_dim is specified to be 0, a deduction failure error is raised because the stride of dimension 0 is 2 (not 1).

The default value for leading_dim is None. In such case, the system automatically deduces it from the tensor’s layout using the following logic:

  1. If a dimension’s stride is 1, that dimension is marked as the leading dimension.

  2. If multiple dimensions satisfy condition 1, an error is thrown indicating deduction failure. Note that after converting a PyTorch tensor to the DLPack format, the stride for dimensions with size 1 are canonicalized to 1. This canonicalization can increase the likelihood of deduction failures. This behavior is specific to PyTorch and does not occur with NumPy for example.

  3. If no dimension satisfies condition 1, all strides are marked as dynamic.

For example:

  • For a tensor with layout (2,2,3,4):(2,1,4,12), the leading dimension is 1. The layout will be marked as (?,?,?,?):(?,1,?,?).

  • For a tensor with layout (1,5,1):(1,1,1), if leading_dim is not specified, a deduction failure error is raised.

  • For a tensor with layout (2,2):(8,2), since no dimension has stride 1, all dimensions are marked as dynamic: (?,?):(?,?).

Code Example#

The following example demonstrates how to use mark_layout_dynamic to specify dynamic tensor layouts.

  • t0 shows the usage of mark_layout_dynamic with unspecified leading_dim and the automatic deduction of leading dimension.

  • t1 & t2 shows the usage of mark_layout_dynamic with specified leading_dim.

  • t3 shows the usage of mark_layout_dynamic with no leading dimension.

  • t4 shows the usage of mark_layout_dynamic with broadcasted dimensions.

  • t5 demonstrates the deduction failure when the there’re more than one dimensions with stride equals to 1.

  • t6 & t7 demonstrates incorrect settings for leading_dim and expected errors.

import torch
from cutlass.cute.runtime import from_dlpack

# (8,4,16,2):(2,16,64,1)
a = torch.empty(16, 4, 8, 2).permute(2, 1, 0, 3)
# (1,4,1,32,1):(4,1,4,4,4) => torch tensor when dimension has shape 1, its stride is degenerated to 1,
# resulting in (1,4,1,32,1):(1,1,1,4,1)
b = torch.empty(32, 1, 1, 1, 4).permute(3, 4, 1, 0, 2)
# (2,2):(8,2)
c = torch.empty(3, 4)[::2, ::2]
# (3,1,1,5):(5,0,0,1)
d = torch.empty(3, 1, 1, 5).expand(3, 4, 2, 5)

# auto deduce the leading dimension to be 3
t0 = from_dlpack(a).mark_layout_dynamic()
print(t0)
# (?,?,?,?):(?,?,?,1)

t1 = from_dlpack(b).mark_layout_dynamic(leading_dim=0)
print(t2)
# (?,?,?,?,?):(1,?,?,?,?)

t2 = from_dlpack(b).mark_layout_dynamic(leading_dim=2)
print(t3)
# (?,?,?,?,?):(?,?,1,?,?)

t3 = from_dlpack(c).mark_layout_dynamic()
print(t3)
# (?,?):(?,?)

t4 = from_dlpack(d).mark_layout_dynamic()
print(t4)
# (?,?,?,?):(?,0,0,1)

t5 = from_dlpack(b).mark_layout_dynamic()
# Can't decude the leading dimension from layout, please specify the leading_dim explicitly.

t6 = from_dlpack(a).mark_layout_dynamic(leading_dim=1)
# Expected strides[leading_dim] == 1, but got 16

t7 = from_dlpack(b).mark_layout_dynamic(leading_dim=3)
# Expected strides[leading_dim] == 1, but got 4

Mark the Tensor’s Layout as Dynamic with mark_compact_shape_dynamic#

The mark_compact_shape_dynamic function provides fine-grain control over dynamic shapes for compact layouts. The full signature of mark_compact_shape_dynamic is as follows:

def mark_compact_shape_dynamic(self, mode: int, stride_order: tuple[int, ...]|None = None, divisibility: int = 1):

The mode parameter determines which shape dimension becomes dynamic. After calling this function, the specific shape dimension given by mode is marked as dynamic immediately. The stride will be updated accordingly but this process is delayed until the C ABI of the tensor is constructed. For modes that have a shape of size 1, their stride are canonicalized to 0.

The stride_order parameter specifies the ordering of strides in the tensor. It is consistent with torch.Tensor.dim_order() and defaults to None. The parameter indicates the order of modes (dimensions) if the current layout were to be converted to row-major order. It starts from the outermost to the innermost dimension when reading it from left to right. This parameter must be explicitly set when the stride order cannot be automatically deduced from the tensor’s layout, such as when multiple dimensions have a stride of 1.

For example:

  • Layout (4,2):(1,4) has a stride_order of (1,0) indicates the innermost dimension is 0 (4:1), the outermost dimension is 1 (2:4).

  • Layout (5,3,2,4):(3,1,15,30) has a stride_order of (3,2,0,1) indicates the innermost dimension is 1 (3:1), the outermost dimension is 3 (4:30).

If stride_order is not specified, the system automatically deduces it from the tensor’s layout using the following logic:

  1. Sort the strides in descending order.

  2. If multiple dimensions have a stride of 1, a deduction failure error is raised.

For example:

  • For a tensor with layout (2,2,3,4):(2,1,4,12), the deduced stride_order is [3,2,0,1].

  • For a tensor with layout (1,5,1):(1,1,1), stride_order’s deduction fails because all dimensions have an identical stride of 1, making it impossible to determine the correct ordering.

If stride_order is specified, the system validates that the order is consistent with the tensor’s layout.

The divisibility parameter specifies the divisibility of the dynamic shape. It could be used to represent the assumption alignment of the input. Defaults to 1.

Note that this API is only available for compact tensors. For non-compact tensors, we can use cute.assume to attach divisibility information to a specific shape mode in a host JIT function, as demonstrated in the following example:

@cute.jit
def foo(a: cute.Tensor):
    new_shape = a.shape
    # use cute.assume to set shape of mode=0 with divisibility=16
    new_shape[0] = cute.assume(new_shape[0], 16)
    new_layout = cute.make_layout(new_shape, stride=a.stride)
    new_a = cute.make_tensor(a.iterator, new_layout)

Code Example#

The following example demonstrates how to use mark_compact_shape_dynamic to specify dynamic tensor layouts.

  • t0 & t1 show the usage of mark_compact_shape_dynamic with unspecified stride_order and different mode and divisibility.

  • t2 shows the usage of consecutive mark_compact_shape_dynamic with unspecified stride_order and different mode and divisibility.

  • t3 & t4 show the usage of mark_compact_shape_dynamic with different specified stride_order.

  • t5, t6, t7, t8, t9, t10, t11, and t12 demonstrate incorrect settings for parameters and expected errors.

import torch
from cutlass.cute.runtime import from_dlpack

@cute.jit
def kernel(t: cute.Tensor):
    pass

# (8,4,16,2):(2,16,64,1)
a = torch.empty(16, 4, 8, 2).permute(2, 1, 0, 3)
# (1,4,1,32,1):(4,1,4,4,4) => torch tensor when dimension has shape 1, its stride is degenerated to 1,
# resulting in (1,4,1,32,1):(1,1,1,4,1)
# b.dim_order() is (3,2,4,0,1)
b = torch.empty(32, 1, 1, 1, 4).permute(3, 4, 1, 0, 2)

# auto deduce the stride order to be [2,1,0,3]
t0 = from_dlpack(a).mark_compact_shape_dynamic(
    mode=0, divisibility=2
)
kernel(t0)
# (?{div=2},4,16,2):(2,?{div=4},?{div=16},1)
print(t0)

t1 = from_dlpack(a).mark_compact_shape_dynamic(
    mode=1, divisibility=2
)
kernel(t1)
# (8,?{div=2},16,2):(2,16,?{div=32},1)
print(t1)

t2 = from_dlpack(a).mark_compact_shape_dynamic(
    mode=1, divisibility=2
).mark_compact_shape_dynamic(
    mode=3, divisibility=2
)
kernel(t2)
# (8,?{div=2},16,?{div=2}):(?{div=2},?{div=16},?{div=32},1)
print(t2)

t3 = from_dlpack(b).mark_compact_shape_dynamic(
    mode=2, divisibility=1, stride_order=(3, 0, 2, 4, 1)
)
kernel(t3)
# (1,4,?,32,1):(0,1,4,?{div=4},0)
print(t3)

t4 = from_dlpack(b).mark_compact_shape_dynamic(
    mode=2, divisibility=1, stride_order=(2, 3, 4, 0, 1)
)
kernel(t4)
# (1,4,?,32,1):(0,1,128,4,0)
print(t4)

t5 = t2.mark_compact_shape_dynamic(
    mode=3, divisibility=5, stride_order=(0, 1, 2, 3)
)
# The stride_order is not consistent with the last stride_order

t6 = from_dlpack(a).mark_compact_shape_dynamic(
    mode=3, divisibility=5, stride_order=(0, 1, 2, 3)
)
# The stride_order is not consistent with the deduced stride_order

t7 = from_dlpack(b).mark_compact_shape_dynamic(
    mode=0, divisibility=4
)
# The layout could not be deduced, please specify the stride_order explicitly

t8 = from_dlpack(b).mark_compact_shape_dynamic(
    mode=30, divisibility=5, stride_order=(3, 0, 2, 4, 1)
)
# Expected mode value to be in range [0, 5), but got 30

t9 = from_dlpack(b).mark_compact_shape_dynamic(
    mode=3, divisibility=5, stride_order=(2, 1, 2, 3, 4)
)
# Expected stride_order to contain all the dimensions of the tensor, but it doesn't contain 0.

t10 = from_dlpack(b).mark_compact_shape_dynamic(
    mode=3, divisibility=5, stride_order=(0, 1, 2, 3, 4, 5)
)
# Expected stride_order to have 5 elements, but got 6.

t11 = from_dlpack(b).mark_compact_shape_dynamic(
    mode=0, divisibility=4, stride_order=b.dim_order()
)
# The shape(1) of mode(0) is not divisible by the divisibility(4)

t12 = from_dlpack(b).mark_compact_shape_dynamic(
    mode=0, divisibility=1, stride_order=(2, 1, 3, 0, 4)
)
# The stride_order is not consistent with the layout