Static vs Dynamic Layouts#

Static Layout#

When integrating with popular deep learning frameworks, one question is how the layout of the converted cute.Tensor is determined. For example, when converting a torch.Tensor to a cute.Tensor, the shape of the torch.Tensor is honored in the layout of the cute.Tensor.

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def foo(tensor):
    print(f"tensor.layout: {tensor.layout}")  # Prints tensor layout at compile time
    cute.printf("tensor: {}", tensor)         # Prints tensor values at runtime

In this example, we define a JIT function foo that takes a cute.Tensor as input and prints its layout. Note that the Python built-in print is used, which prints the layout at compile time. This works for static layouts whose values are known at compile time.

Now let’s try to run the JIT function foo with different shapes of the input torch.Tensor.

a = torch.tensor([1, 2, 3], dtype=torch.uint16)
a_pack = from_dlpack(a)
compiled_func = cute.compile(foo, a_pack)
compiled_func(a_pack)

Here we first convert a 1D torch.Tensor with 3 elements to a cute.Tensor using from_dlpack. Then we compile the JIT function foo with the converted cute.Tensor and call the compiled function.

  tensor.layout: (3):(1)
  tensor: raw_ptr(0x00000000079e5100: i16, generic, align<2>) o (3):(1) =
( 1, 2, 3 )

It prints (3):(1) for the layout because the converted cute.Tensor has a static layout with shape (3), which is the shape of a.

Now if we call the compiled function with an input torch.Tensor of a different shape, we get an unexpected result at runtime due to a type mismatch: compiled_func expects a cute.Tensor with layout (3):(1), while b has shape (5).

b = torch.tensor([11, 12, 13, 14, 15], dtype=torch.uint16)
b_pack = from_dlpack(b)
compiled_func(b_pack)  # ❌ This results in an unexpected result at runtime due to type mismatch

Following is the unexpected output: only the first three elements of b are printed because the compiled function still assumes layout (3):(1).

  tensor: raw_ptr(0x00000000344804c0: i16, generic, align<2>) o (3):(1) =
( 11, 12, 13 )

To fix that, we have to trigger another code generation and compilation for the new shape of b.

compiled_func_2 = cute.compile(foo, b_pack)  # This would trigger another compilation
compiled_func_2(b_pack)                      # ✅ Now this works fine

As shown above, we can now pass b_pack to the newly compiled JIT function compiled_func_2.

  tensor.layout: (5):(1)
  tensor: raw_ptr(0x0000000034bb2840: i16, generic, align<2>) o (5):(1) =
( 11, 12, 13, 14, 15 )

The recompiled function now prints the values of b correctly.

It’s obvious that we need distinct code generated and compiled for each distinct static layout: in this case, one for layout (3):(1) and another for layout (5):(1).
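
To make this concrete, the sketch below uses a hypothetical helper (not part of the CuTe DSL API) that caches one compiled variant of foo per static layout, keyed by the input shape, and only calls cute.compile when a new shape is seen:

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def foo(tensor):
    cute.printf("tensor: {}", tensor)

# Hypothetical cache: one compiled variant per static layout, keyed by shape.
compiled_variants = {}

def run_foo(t: torch.Tensor):
    t_pack = from_dlpack(t)
    key = tuple(t.shape)
    if key not in compiled_variants:
        # First time this shape is seen: trigger code generation and compilation.
        compiled_variants[key] = cute.compile(foo, t_pack)
    compiled_variants[key](t_pack)

run_foo(torch.tensor([1, 2, 3], dtype=torch.uint16))             # compiles for (3):(1)
run_foo(torch.tensor([11, 12, 13, 14, 15], dtype=torch.uint16))  # compiles for (5):(1)
run_foo(torch.tensor([4, 5, 6], dtype=torch.uint16))             # reuses the (3):(1) variant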

Dynamic Layout#

To avoid generating and compiling code multiple times for different shapes of the input torch.Tensor, CuTe DSL provides a way to generate and compile a JIT function with a dynamic layout.

To get a dynamic layout for the cute.Tensor, a torch.Tensor object can be passed into the JIT function directly, which instructs CuTe DSL to call cute.mark_layout_dynamic automatically on the converted cute.Tensor based on the leading dimension of the layout.

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def foo(tensor):
    print(tensor.layout)  # Prints (?,?):(?,1) for dynamic layout

a = torch.tensor([[1, 2], [3, 4]], dtype=torch.uint16)
compiled_func = cute.compile(foo, a)
compiled_func(a)

b = torch.tensor([[11, 12], [13, 14], [15, 16]], dtype=torch.uint16)
compiled_func(b)  # Reuse the same compiled function for different shape

In the example above, a single compilation of the JIT function foo is reused for input torch.Tensor objects of different shapes. This is possible because the converted cute.Tensor has a dynamic layout (?,?):(?,1), which is compatible with the input shapes of both calls.

Alternatively, for compact layouts, cute.mark_compact_shape_dynamic can be called for finer-grained control: it specifies which mode of the layout is dynamic and the divisibility constraint for that dynamic dimension.
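
As a rough sketch, the annotations can also be applied explicitly on the tensor returned by from_dlpack; the method-chaining form and the argument names leading_dim, mode, and divisibility are assumptions based on the Integration with Frameworks documentation:

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def foo(tensor):
    print(tensor.layout)

a = torch.tensor([[1, 2], [3, 4]], dtype=torch.uint16)

# Explicitly mark the layout dynamic; the leading dimension (assumed keyword
# leading_dim) keeps a static stride of 1, giving a layout like (?,?):(?,1).
a_dyn = from_dlpack(a).mark_layout_dynamic(leading_dim=1)

# Finer-grained alternative for compact layouts: mark only mode 0 as dynamic
# and assert its extent is divisible by 2 (assumed keywords mode, divisibility).
a_compact = from_dlpack(a).mark_compact_shape_dynamic(mode=0, divisibility=2)

compiled_func = cute.compile(foo, a_dyn)
compiled_func(a_dyn)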

Refer to Integration with Frameworks for more details on from_dlpack, mark_layout_dynamic, and mark_compact_shape_dynamic.

Static Layout vs. Dynamic Layout#

As shown in the previous sections, a static layout requires a separate code generation per shape, while a dynamic layout allows a single compilation to serve inputs of different shapes.

That said, creating a JIT function with a static layout is useful when the use case targets input data with fixed shapes. Since more information is available at compile time, the compiler can apply optimizations that would not be possible for code generated for a dynamic layout.

On the other hand, a dynamic layout is more flexible when the input data has varying shapes, allowing the same generated code to serve inputs of different shapes.

Programming with Static and Dynamic Layout#

CuTe DSL provides an intuitive way to program with static and dynamic layouts.

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def foo(tensor, x: cutlass.Constexpr[int]):
    print(cute.size(tensor))  # Prints 3 for the 1st call
                              # Prints ? for the 2nd call
    if cute.size(tensor) > x:
        cute.printf("tensor[2]: {}", tensor[2])
    else:
        cute.printf("tensor size <= {}", x)

a = torch.tensor([1, 2, 3], dtype=torch.uint16)
foo(from_dlpack(a), 3)   # First call with static layout

b = torch.tensor([1, 2, 3, 4, 5], dtype=torch.uint16)
foo(b, 3)                # Second call with dynamic layout

In this example, the JIT function foo is compiled with a static layout (3):(1) for the first call, which means the size of the tensor is known at compile time. CuTe DSL takes advantage of this and resolves the if condition at compile time, so the generated code contains no branch at all.

For the second call, the JIT function foo is compiled with a dynamic layout (?):(1), so the tensor size is only known at runtime. CuTe DSL automatically generates code that handles the dynamic layout and evaluates the if condition at runtime.

The same applies to loops as well:

@cute.jit
def foo(tensor, x: cutlass.Constexpr[int]):
    for i in range(cute.size(tensor)):
        cute.printf("tensor[{}]: {}", i, tensor[i])

a = torch.tensor([1, 2, 3], dtype=torch.uint16)
foo(from_dlpack(a), 3)   # First call with static layout

b = torch.tensor([1, 2, 3, 4, 5], dtype=torch.uint16)
foo(b, 3)                # Second call with dynamic layout

With the static layout in the first call, CuTe DSL is able to fully unroll the loop at compile time, while in the second call the generated code executes the loop at runtime based on the dynamic layout.

With a single JIT function implementation, CuTe DSL is able to handle control-flow constructs and automatically generate optimized code for each case. This is possible because CuTe DSL walks the Python AST and converts each control-flow construct it finds accordingly.

Please refer to Control Flow for more details.