Integration with Frameworks#
To facilitate the integration of CUTLASS Python with popular frameworks, we leverage the DLPack protocol and convert tensors originating from these frameworks to CuTe tensors. This page documents the conventions, describes the API available to the user, and provides example code snippets for common usage patterns.
Implicit Conversion#
Tensors originating from frameworks supporting the DLPack protocol can be directly provided to a JIT function as a regular parameter. CuTe DSL’s runtime implicitly converts the original tensor to a CuTe tensor with a fully dynamic layout except for the stride element corresponding to the leading dimension. The example below demonstrates this use case.
import torch

import cutlass.cute as cute


@cute.jit
def foo(src):
    """
    The following lines print
    ptr<f32, generic> o (?,?,?):(?,?,1)
    <class 'cutlass.cute.core._Tensor'>
    """
    print(src)
    print(type(src))


a = torch.randn(30, 20, 32, device="cpu")
foo(a)
Explicit Conversion Using from_dlpack#
CuTe DSL’s runtime provides an interface for converting DLPack-compatible tensors to CuTe tensors:
b = cute.runtime.from_dlpack(a)
where a is a tensor supporting the DLPack protocol, i.e. it implements the __dlpack__ and __dlpack_device__ methods. The resulting CuTe tensor b has a fully static layout. This conversion is performed without copying any tensor data, enabling seamless integration with major frameworks. Users can create tensors using NumPy, PyTorch, etc. and directly feed them into JIT functions written using CuTe DSL.
The resulting CuTe tensor shares the same underlying memory buffer as the original tensor. This zero-copy approach maximizes performance by eliminating unnecessary data duplication. However, it is important to note that the CuTe tensor’s validity is tied to the lifetime of the original tensor. If the source tensor is destroyed or goes out of scope, the corresponding CuTe tensor becomes invalid since it references the original memory location.
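As a minimal sketch of this lifetime requirement (the helper functions below are hypothetical and only illustrate the point), keep a reference to the source tensor for as long as the converted tensor is in use:
import torch
from cutlass.cute.runtime import from_dlpack


def make_cute_tensor_unsafe():
    x = torch.randn(30, 20)
    # x may be garbage-collected once this function returns, so the returned
    # CuTe tensor could end up referencing freed memory.
    return from_dlpack(x)


def make_cute_tensor_safe():
    x = torch.randn(30, 20)
    # Return the source tensor together with the converted one so that the
    # caller keeps the underlying buffer alive.
    return x, from_dlpack(x)


x, y = make_cute_tensor_safe()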
The full signature of from_dlpack is as follows:
def from_dlpack(tensor, assumed_align=None):
The assumed_align integer parameter specifies the alignment of the tensor in units of bytes. The tensor’s base address must be divisible by assumed_align. When not provided explicitly, the alignment is set to the natural alignment of the tensor’s element type. Note that the alignment information is part of the pointer type in the generated IR. Therefore, programs with different alignments have different IRs, and identical IRs are required to hit the kernel caching mechanism of CuTe DSL.
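As a sketch of how assumed_align can be passed (assuming here that the tensor’s base address is in fact 16-byte aligned, which is a property of the allocator and not guaranteed by this API):
import torch
from cutlass.cute.runtime import from_dlpack

x = torch.randn(1024, 1024, device="cpu")

# Default: the natural alignment of the element type (4 bytes for float32).
y_default = from_dlpack(x)

# Assert a 16-byte alignment; x.data_ptr() must be divisible by 16 for this to hold.
y_aligned = from_dlpack(x, assumed_align=16)

# The alignment is part of the pointer type in the generated IR, so JIT functions
# compiled for y_default and y_aligned do not share kernel cache entries.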
Code Example#
The following code demonstrates how to convert a PyTorch tensor to a CuTe tensor using the from_dlpack function with default parameters.
import torch
import cutlass
from cutlass.cute.runtime import from_dlpack
x = torch.randn(30, 20, device="cpu")
y = from_dlpack(x)
Once converted, we can access the tensor’s information through various attributes. The following list shows the attributes of the converted tensor:
tensor.shape: the tensor’s shape
tensor.stride: the tensor’s stride
tensor.memspace: the tensor’s memory space
tensor.element_type: the tensor’s element data type
import torch
import cutlass
from cutlass.cute.runtime import from_dlpack
x = torch.randn(30, 20, device="cpu")
y = from_dlpack(x)
print(y.shape) # (30, 20)
print(y.stride) # (20, 1)
print(y.memspace) # generic (if the torch tensor is on device memory, memspace will be gmem)
print(y.element_type) # Float32
print(y) # Tensor<0x000000000875f580@generic o (30, 20):(20, 1)>
The string format of the resulting CuTe tensor is
Tensor<0x{tensor.data_ptr:016x}@{tensor.memspace} o {tensor.shape}:{tensor.stride}>
As can be seen in the example above, from_dlpack initially results in a tensor with a fully static layout. To obtain dynamic or mixed static/dynamic layouts after calling from_dlpack, use the mark_layout_dynamic and mark_compact_shape_dynamic functions described in the following sections.
When to Use Explicit Conversion?#
The DLPack protocol is a widely used protocol for interoperability between different frameworks. However, there is some associated overhead: based on our benchmarks, a call to from_dlpack usually takes between 2 and 3 us. Explicit conversion allows caching the converted CuTe tensors in order to avoid the overhead of repeated calls to from_dlpack.
cached_tensors = {}

x = torch.randn(30, 20, device="cpu")
key = x.data_ptr()  # any key that uniquely identifies the underlying buffer works
if key not in cached_tensors:
    # Do the conversion only for cache misses
    cached_tensors[key] = cute.runtime.from_dlpack(x)
foo(cached_tensors[key])
Another use case for explicit conversion is to gain fine-grained control over which modes of a tensor are considered dynamic from the perspective of the generated program.
Mark the Tensor’s Layout as Dynamic with mark_layout_dynamic#
After calling this function, all shape modes become dynamic. The stride modes also become dynamic with the following two exceptions:
the leading dimension’s stride remains fixed at 1;
stride elements equal to 0 (which indicates broadcasting) are retained.
The full signature of mark_layout_dynamic is as follows:
def mark_layout_dynamic(self, leading_dim: int|None = None):
The leading_dim parameter specifies the leading dimension of the tensor. The leading dimension’s stride is set to 1 unless this is inconsistent with the layout of the DLPack tensor. For example:
For a tensor with layout (2,2,3,4):(2,1,4,12), if leading_dim is specified to be 1, the layout will be marked as (?,?,?,?):(?,1,?,?).
If leading_dim is specified to be 0, an error is raised because the stride of dimension 0 is 2 (not 1).
The default value for leading_dim is None. In that case, the system automatically deduces it from the tensor’s layout using the following logic:
If a dimension’s stride is 1, that dimension is marked as the leading dimension.
If multiple dimensions satisfy condition 1, an error is thrown indicating deduction failure. Note that after converting a PyTorch tensor to the DLPack format, the strides of dimensions with size 1 are canonicalized to 1. This canonicalization can increase the likelihood of deduction failures. This behavior is specific to PyTorch and does not occur with NumPy, for example.
If no dimension satisfies condition 1, all strides are marked as dynamic.
For example:
For a tensor with layout (2,2,3,4):(2,1,4,12), the leading dimension is 1. The layout will be marked as (?,?,?,?):(?,1,?,?).
For a tensor with layout (1,5,1):(1,1,1), if leading_dim is not specified, a deduction failure error is raised.
For a tensor with layout (2,2):(8,2), since no dimension has stride 1, all dimensions are marked as dynamic: (?,?):(?,?).
Code Example#
The following example demonstrates how to use mark_layout_dynamic to specify dynamic tensor layouts.
t0 shows the usage of mark_layout_dynamic with unspecified leading_dim and the automatic deduction of the leading dimension.
t1 & t2 show the usage of mark_layout_dynamic with a specified leading_dim.
t3 shows the usage of mark_layout_dynamic with no leading dimension.
t4 shows the usage of mark_layout_dynamic with broadcasted dimensions.
t5 demonstrates the deduction failure when more than one dimension has a stride equal to 1.
t6 & t7 demonstrate incorrect settings for leading_dim and the expected errors.
import torch
from cutlass.cute.runtime import from_dlpack
# (8,4,16,2):(2,16,64,1)
a = torch.empty(16, 4, 8, 2).permute(2, 1, 0, 3)
# (1,4,1,32,1):(4,1,4,4,4) => for a torch tensor, dimensions of size 1 have their stride
# canonicalized to 1, resulting in (1,4,1,32,1):(1,1,1,4,1)
b = torch.empty(32, 1, 1, 1, 4).permute(3, 4, 1, 0, 2)
# (2,2):(8,2)
c = torch.empty(3, 4)[::2, ::2]
# (3,4,2,5):(5,0,0,1)
d = torch.empty(3, 1, 1, 5).expand(3, 4, 2, 5)
# auto deduce the leading dimension to be 3
t0 = from_dlpack(a).mark_layout_dynamic()
print(t0)
# (?,?,?,?):(?,?,?,1)
t1 = from_dlpack(b).mark_layout_dynamic(leading_dim=0)
print(t1)
# (?,?,?,?,?):(1,?,?,?,?)
t2 = from_dlpack(b).mark_layout_dynamic(leading_dim=2)
print(t2)
# (?,?,?,?,?):(?,?,1,?,?)
t3 = from_dlpack(c).mark_layout_dynamic()
print(t3)
# (?,?):(?,?)
t4 = from_dlpack(d).mark_layout_dynamic()
print(t4)
# (?,?,?,?):(?,0,0,1)
t5 = from_dlpack(b).mark_layout_dynamic()
# Can't deduce the leading dimension from layout, please specify the leading_dim explicitly.
t6 = from_dlpack(a).mark_layout_dynamic(leading_dim=1)
# Expected strides[leading_dim] == 1, but got 16
t7 = from_dlpack(b).mark_layout_dynamic(leading_dim=3)
# Expected strides[leading_dim] == 1, but got 4
Mark the Tensor’s Layout as Dynamic with mark_compact_shape_dynamic#
The mark_compact_shape_dynamic function provides fine-grained control over dynamic shapes for compact layouts. The full signature of mark_compact_shape_dynamic is as follows:
def mark_compact_shape_dynamic(self, mode: int, stride_order: tuple[int, ...]|None = None, divisibility: int = 1):
The mode parameter determines which shape dimension becomes dynamic. After calling this function, the shape dimension given by mode is marked as dynamic immediately. The stride is updated accordingly, but this update is delayed until the C ABI of the tensor is constructed. For modes that have a shape of size 1, their strides are canonicalized to 0.
The stride_order parameter specifies the ordering of strides in the tensor. It is consistent with torch.Tensor.dim_order() and defaults to None. The parameter indicates the order of modes (dimensions) if the current layout were converted to row-major order, reading from the outermost to the innermost dimension from left to right. This parameter must be explicitly set when the stride order cannot be automatically deduced from the tensor’s layout, such as when multiple dimensions have a stride of 1.
For example:
Layout (4,2):(1,4) has a stride_order of (1,0), indicating that the innermost dimension is 0 (4:1) and the outermost dimension is 1 (2:4).
Layout (5,3,2,4):(3,1,15,30) has a stride_order of (3,2,0,1), indicating that the innermost dimension is 1 (3:1) and the outermost dimension is 3 (4:30).
If stride_order is not specified, the system automatically deduces it from the tensor’s layout using the following logic:
Sort the strides in descending order.
If multiple dimensions have a stride of 1, a deduction failure error is raised.
For example:
For a tensor with layout (2,2,3,4):(2,1,4,12), the deduced stride_order is [3,2,0,1].
For a tensor with layout (1,5,1):(1,1,1), the stride_order deduction fails because all dimensions have an identical stride of 1, making it impossible to determine the correct ordering.
If stride_order is specified, the system validates that the order is consistent with the tensor’s layout.
The divisibility parameter specifies the divisibility of the dynamic shape. It can be used to convey the assumed alignment of the input. It defaults to 1.
Note that this API is only available for compact tensors. For non-compact tensors, we can use cute.assume to attach divisibility information to a specific shape mode in a host JIT function, as demonstrated in the following example:
import cutlass.cute as cute


@cute.jit
def foo(a: cute.Tensor):
    new_shape = a.shape
    # use cute.assume to set shape of mode=0 with divisibility=16
    new_shape[0] = cute.assume(new_shape[0], 16)
    new_layout = cute.make_layout(new_shape, stride=a.stride)
    new_a = cute.make_tensor(a.iterator, new_layout)
Code Example#
The following example demonstrates how to use mark_compact_shape_dynamic to specify dynamic tensor layouts.
t0 & t1 show the usage of mark_compact_shape_dynamic with unspecified stride_order and different mode and divisibility.
t2 shows the usage of consecutive mark_compact_shape_dynamic calls with unspecified stride_order and different mode and divisibility.
t3 & t4 show the usage of mark_compact_shape_dynamic with different specified stride_order.
t5, t6, t7, t8, t9, t10, t11, and t12 demonstrate incorrect settings for parameters and the expected errors.
import torch

import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack


@cute.jit
def kernel(t: cute.Tensor):
    pass


# (8,4,16,2):(2,16,64,1)
a = torch.empty(16, 4, 8, 2).permute(2, 1, 0, 3)
# (1,4,1,32,1):(4,1,4,4,4) => for a torch tensor, dimensions of size 1 have their stride
# canonicalized to 1, resulting in (1,4,1,32,1):(1,1,1,4,1)
# b.dim_order() is (3,2,4,0,1)
b = torch.empty(32, 1, 1, 1, 4).permute(3, 4, 1, 0, 2)
# auto deduce the stride order to be [2,1,0,3]
t0 = from_dlpack(a).mark_compact_shape_dynamic(
mode=0, divisibility=2
)
kernel(t0)
# (?{div=2},4,16,2):(2,?{div=4},?{div=16},1)
print(t0)
t1 = from_dlpack(a).mark_compact_shape_dynamic(
mode=1, divisibility=2
)
kernel(t1)
# (8,?{div=2},16,2):(2,16,?{div=32},1)
print(t1)
t2 = from_dlpack(a).mark_compact_shape_dynamic(
mode=1, divisibility=2
).mark_compact_shape_dynamic(
mode=3, divisibility=2
)
kernel(t2)
# (8,?{div=2},16,?{div=2}):(?{div=2},?{div=16},?{div=32},1)
print(t2)
t3 = from_dlpack(b).mark_compact_shape_dynamic(
mode=2, divisibility=1, stride_order=(3, 0, 2, 4, 1)
)
kernel(t3)
# (1,4,?,32,1):(0,1,4,?{div=4},0)
print(t3)
t4 = from_dlpack(b).mark_compact_shape_dynamic(
mode=2, divisibility=1, stride_order=(2, 3, 4, 0, 1)
)
kernel(t4)
# (1,4,?,32,1):(0,1,128,4,0)
print(t4)
t5 = t2.mark_compact_shape_dynamic(
mode=3, divisibility=5, stride_order=(0, 1, 2, 3)
)
# The stride_order is not consistent with the last stride_order
t6 = from_dlpack(a).mark_compact_shape_dynamic(
mode=3, divisibility=5, stride_order=(0, 1, 2, 3)
)
# The stride_order is not consistent with the deduced stride_order
t7 = from_dlpack(b).mark_compact_shape_dynamic(
mode=0, divisibility=4
)
# The layout could not be deduced, please specify the stride_order explicitly
t8 = from_dlpack(b).mark_compact_shape_dynamic(
mode=30, divisibility=5, stride_order=(3, 0, 2, 4, 1)
)
# Expected mode value to be in range [0, 5), but got 30
t9 = from_dlpack(b).mark_compact_shape_dynamic(
mode=3, divisibility=5, stride_order=(2, 1, 2, 3, 4)
)
# Expected stride_order to contain all the dimensions of the tensor, but it doesn't contain 0.
t10 = from_dlpack(b).mark_compact_shape_dynamic(
mode=3, divisibility=5, stride_order=(0, 1, 2, 3, 4, 5)
)
# Expected stride_order to have 5 elements, but got 6.
t11 = from_dlpack(b).mark_compact_shape_dynamic(
mode=0, divisibility=4, stride_order=b.dim_order()
)
# The shape(1) of mode(0) is not divisible by the divisibility(4)
t12 = from_dlpack(b).mark_compact_shape_dynamic(
mode=0, divisibility=1, stride_order=(2, 1, 3, 0, 4)
)
# The stride_order is not consistent with the layout