cutlass.utils

The cutlass.utils module contains utilities for developing kernels with the CuTe DSL.

cutlass.utils.get_smem_capacity_in_bytes(compute_capability: str) → int

Get the shared memory capacity in bytes for a given compute capability.

Returns the maximum shared memory capacity in bytes available for the specified GPU compute capability.

Parameters:

compute_capability (str) – The compute capability string (e.g. “70”, “75”, “80”)

Returns:

The shared memory capacity in bytes

Return type:

int

Raises:

ValueError – If the compute capability is not supported

class cutlass.utils.SmemAllocator

Bases: object

A helper class for managing shared memory allocation on GPU.

This class manages shared memory and provides APIs for allocating raw bytes, numeric types, structs, arrays, and tensors with specified layouts and alignments.

Note

  • The base pointer is aligned to 1024 bytes upon initialization.

  • The shared memory size does not need to be specified explicitly at kernel launch.

  • Only static layouts are currently supported; dynamic layouts are not.

Examples:

import cutlass
import cutlass.cute as cute
from cutlass import Int8
from cutlass.utils import SmemAllocator

# Define a struct (typically at module scope)
@cute.struct
class SharedStorage:
    alpha: cutlass.Float32
    x: cutlass.Int32

# The allocation calls below are meant to run inside a @cute.kernel body
smem = SmemAllocator()

# Allocate raw bytes
buf_ptr = smem.allocate(100)  # 100 bytes

# Allocate numeric type
int8_ptr = smem.allocate(Int8)  # 1 byte

# Allocate struct
struct_ptr = smem.allocate(SharedStorage)  # 8 bytes: Float32 (4) + Int32 (4)

# Use struct members
struct_ptr.alpha = 1.0
struct_ptr.x = 2

# Allocate array
int8_array = smem.allocate_array(Int8, 10)  # 10 bytes

# Allocate tensor
layout = cute.make_layout((16, 16))
tensor = smem.allocate_tensor(Int8, layout)  # 256 bytes: 16 x 16 elements x 1 byte

static capacity_in_bytes(compute_capability: str) → int

Get the shared memory capacity in bytes for a given compute capability.

Returns the maximum shared memory capacity in bytes available for the specified GPU compute capability.

Parameters:

compute_capability (str) – The compute capability string (e.g. “70”, “75”, “80”)

Returns:

The shared memory capacity in bytes

Return type:

int

Raises:

ValueError – If the compute capability is not supported

__init__(*, loc=None, ip=None)

Initialize a new SmemAllocator instance.

Creates a new shared memory allocator with a base pointer aligned to 1024 bytes. Tracks the allocator instance for memory management.

Parameters:
  • loc (Optional[ir.Location]) – Source location information for debugging, defaults to None

  • ip (Optional[ir.InsertionPoint]) – Insertion point for MLIR operations, defaults to None

allocate(
size_or_type: int,
byte_alignment: int,
*,
loc=None,
ip=None,
) → cutlass.cute.typing.Pointer
allocate(
size_or_type: Type[cutlass.cutlass_dsl.Numeric],
byte_alignment: int,
*,
loc=None,
ip=None,
) → cutlass.cute.typing.Pointer
allocate(
size_or_type: struct,
byte_alignment: int,
*,
loc=None,
ip=None,
) → cutlass.cute.typing.Pointer

Allocate a block of memory with specified size and alignment.

This method allocates a block of shared memory with the specified size and alignment requirements. It supports allocating raw bytes, numeric types (as a scalar value), and struct types.

Parameters:
  • size_or_type (Union[int, Type[Numeric], cute.struct]) – The allocation specification, which can be: an integer specifying the number of bytes to allocate; a Numeric type (e.g., Int8, Float32), to allocate space for one element; or a struct type, to allocate space for the entire struct.

  • byte_alignment (int, optional) – The minimum byte alignment requirement for the allocation, defaults to 1

  • loc (Optional[ir.Location]) – Source location information for debugging, defaults to None

  • ip (Optional[ir.InsertionPoint]) – Insertion point for MLIR operations, defaults to None

Returns:

For raw bytes and numeric types, returns a pointer to the allocated memory. For struct types, returns an initialized struct instance at the allocated location.

Return type:

cute.Pointer

Raises:
  • ValueError – If size is negative or alignment is less than 1

  • TypeError – If size_or_type is not an integer, Numeric type, or struct

  • RuntimeError – If allocation would exceed available shared memory
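A small pure-Python model can illustrate the offset arithmetic implied by byte_alignment. This is a sketch under stated assumptions, not the library's implementation. It assumes SmemAllocator behaves like a bump allocator that rounds its running offset up to each requested alignment. The ToySmemAllocator class and the 4096-byte capacity are hypothetical. The real allocator emits MLIR operations and returns cute.Pointer objects rather than integer offsets.

```python
class ToySmemAllocator:
    """Hypothetical bump-allocator model of aligned shared memory allocation."""

    def __init__(self, capacity_in_bytes: int) -> None:
        self.capacity = capacity_in_bytes
        self.offset = 0  # bytes handed out so far, relative to the (aligned) base pointer

    def allocate(self, size: int, byte_alignment: int = 1) -> int:
        if size < 0 or byte_alignment < 1:
            raise ValueError("size must be >= 0 and byte_alignment must be >= 1")
        # Round the running offset up to the next multiple of byte_alignment.
        start = (self.offset + byte_alignment - 1) // byte_alignment * byte_alignment
        end = start + size
        if end > self.capacity:
            raise RuntimeError("allocation would exceed shared memory capacity")
        self.offset = end
        return start  # offset of the new block from the base pointer


smem = ToySmemAllocator(capacity_in_bytes=4096)
a = smem.allocate(100)                      # starts at 0
b = smem.allocate(4, byte_alignment=4)      # 100 is already 4-byte aligned
c = smem.allocate(16, byte_alignment=128)   # 104 is rounded up to 128
print(a, b, c)  # 0 100 128
```

In this model, the padding between 104 and 128 is the cost of the 128-byte alignment request.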

allocate_array(
element_type: Type[cutlass.cutlass_dsl.Numeric],
num_elems: int = 1,
*,
loc=None,
ip=None,
)

Allocate an array of elements in shared memory.

Parameters:
  • element_type (Type[Numeric]) – The type of elements to allocate

  • num_elems (int, optional) – Number of elements to allocate, defaults to 1

Returns:

Pointer to the start of the allocated array

Return type:

cute.Pointer

Raises:
  • ValueError – If num_elems is less than 1

  • TypeError – If element_type is not a Numeric type

allocate_tensor(
element_type: Type[cutlass.cutlass_dsl.Numeric],
layout: int | cutlass.cute.typing.Layout | cutlass.cute.typing.ComposedLayout,
byte_alignment: int = 1,
swizzle: cute.Swizzle | None = None,
*,
loc=None,
ip=None,
)

Allocate a tensor in shared memory.

Note: Only static layouts are currently supported; dynamic layouts are not.

Parameters:
  • element_type (Type[Numeric]) – The type of elements in the tensor

  • layout (Union[int, cute.Layout, cute.ComposedLayout]) – The layout specification for the tensor. Must be a static layout.

  • byte_alignment (int, optional) – The byte alignment requirement, defaults to 1

  • swizzle (cute.Swizzle, optional) – Swizzle for position-dependent swizzling, defaults to None

Returns:

The allocated tensor with specified properties

Return type:

cute.Tensor

Raises:
  • TypeError – If element_type is not a Numeric type or if swizzle conflicts with layout

  • ValueError – If allocation is not byte-aligned

  • NotImplementedError – If dynamic layout is specified
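The number of bytes a static layout occupies can be sketched from its shape and stride. This reference does not confirm the following assumptions, which the model makes:

- allocate_tensor reserves cosize(layout) elements, where cosize is one more than the largest offset the layout maps to.
- The byte-alignment ValueError corresponds to a total bit count that is not a multiple of 8, as can happen with sub-byte element types.

The helper names are hypothetical. Shape (16, 16) with stride (1, 16) is the column-major layout that CuTe produces by default for cute.make_layout((16, 16)).

```python
def cosize(shape, stride):
    """Codomain size of a flat static layout: 1 + the largest offset it produces."""
    return 1 + sum((s - 1) * d for s, d in zip(shape, stride))


def toy_tensor_bytes(shape, stride, element_bits):
    """Hypothetical model of the shared memory bytes needed by a static layout."""
    bits = cosize(shape, stride) * element_bits
    if bits % 8 != 0:
        raise ValueError(f"{bits} bits is not a whole number of bytes")
    return bits // 8


print(toy_tensor_bytes((16, 16), (1, 16), 8))  # 256: 16 x 16 Int8, column-major
print(toy_tensor_bytes((16, 16), (1, 16), 4))  # 128: same layout, 4-bit elements
print(toy_tensor_bytes((4, 4), (1, 8), 8))     # 28: a padded stride leaves gaps
```

The first result agrees with the 256-byte Int8 tensor in the SmemAllocator example.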

class cutlass.utils.TmemAllocator(
alloc_result_dst_smem_ptr: cutlass.cute.typing.Pointer,
barrier_for_retrieve: NamedBarrier,
allocator_warp_id: int = 0,
is_two_cta: bool = False,
num_allocated_columns: int = 0,
two_cta_tmem_dealloc_mbar_ptr: cutlass.cute.typing.Pointer | None = None,
)

Bases: object

A class for managing tensor memory allocation on Blackwell GPUs.

This class manages allocation and deallocation of tensor memory, including the mbarrier synchronization needed in the two-CTA case.

Variables:
  • _alloc_result_dst_smem_ptr – The smem pointer that holds the base address of allocated tensor memory.

  • _barrier_for_retrieve – The barrier used when retrieving the tensor memory pointer.

  • _allocator_warp_id – The warp ID of the allocator warp.

  • _is_two_cta – Whether the allocator is used in the two-CTA case.

  • _num_allocated_columns – The number of columns allocated in the tensor memory.

  • _two_cta_tmem_dealloc_mbar_ptr – The mbarrier pointer required when deallocating tensor memory in the two-CTA case.

_init_dealloc_mbarrier()
__init__(
alloc_result_dst_smem_ptr: cutlass.cute.typing.Pointer,
barrier_for_retrieve: NamedBarrier,
allocator_warp_id: int = 0,
is_two_cta: bool = False,
num_allocated_columns: int = 0,
two_cta_tmem_dealloc_mbar_ptr: cutlass.cute.typing.Pointer | None = None,
)

Initialize the TmemAllocator instance.

Sets up the allocator state, which comprises the following:

- the SMEM pointer that holds the base address of the allocated tensor memory
- the allocator warp ID
- whether the allocator is used in the two-CTA case
- the number of allocated columns
- the barrier used to retrieve the tensor memory pointer

It also initializes the mbarrier pointer used for deallocation in the two-CTA case.

check_valid_num_columns(num_columns: int)

Check if the number of columns is valid.

This method checks whether the number of columns is valid: it must be greater than 0, no larger than 512, a multiple of 32, and a power of two.
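The column-count conditions can be restated as a plain Python predicate. This hypothetical helper is an illustration only, not the library's method. It treats 512 as an allowed value, consistent with the 512-column maximum cited for get_num_tmem_alloc_cols.

```python
def is_valid_tmem_num_columns(num_columns: int) -> bool:
    """Hypothetical restatement of the TMEM column-count rules."""
    max_columns = 512  # assumed TMEM column capacity
    return (
        0 < num_columns <= max_columns
        and num_columns % 32 == 0
        and (num_columns & (num_columns - 1)) == 0  # power of two
    )


candidates = (0, 16, 32, 48, 64, 96, 128, 256, 512, 1024)
print([n for n in candidates if is_valid_tmem_num_columns(n)])  # [32, 64, 128, 256, 512]
```

Taken together, the conditions leave only the powers of two from 32 to 512. For example, 96 is a multiple of 32 but not a power of two.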

allocate(num_columns: int)

Allocate a block of tensor memory.

This method allocates a block of tensor memory from the allocator warp and returns a handle for retrieving the allocated tensor memory address.

wait_for_alloc()

Wait for the allocator warp to finish allocation.

This method synchronizes the allocator warp with the other warps before they retrieve the TMEM pointer.

retrieve_ptr(
dtype: Type[cutlass.cutlass_dsl.Numeric] = cutlass.cutlass_dsl.Float32,
) → cutlass.cute.typing.Pointer

Retrieve the pointer to the allocated tensor memory.

This method can be called by all warps after allocation has been performed by the allocator warp.

relinquish_alloc_permit()

Relinquish the tensor memory allocation permit.

This method relinquishes the tensor memory allocation permit for the allocator warp, promising that the allocator warp will not allocate any more tensor memory.

free(
tmem_ptr: cutlass.cute.typing.Pointer,
num_columns: int = 0,
)

Deallocate the tensor memory.

This method synchronizes on the mbarrier (in the two-CTA case) and then deallocates the tensor memory from the allocator warp. Users can optionally specify the number of columns to deallocate; if none is specified, all allocated columns are deallocated.

class cutlass.utils.LayoutEnum(value)

Bases: Enum

An enumeration of tensor layouts (row-major or column-major).

ROW_MAJOR = 'row_major'
COL_MAJOR = 'col_major'
mma_major_mode()
sm90_mma_major_mode()
is_k_major_a()
is_m_major_a()
is_n_major_b()
is_k_major_b()
is_n_major_c()
is_m_major_c()
static from_tensor(
tensor: cutlass.cute.typing.Tensor,
) → LayoutEnum
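LayoutEnum's helper methods have no docstrings, so the following pure-Python model only illustrates one way to read the method names. It rests on three assumptions:

- Operands follow the usual CuTe convention: A is stored as (M, K), B as (N, K), and C/D as (M, N).
- ROW_MAJOR means the second mode is contiguous.
- from_tensor picks the layout from whichever mode has unit stride.

ToyLayout, from_strides, and major_mode are hypothetical names, not part of cutlass.utils.

```python
from enum import Enum


class ToyLayout(Enum):
    """Hypothetical stand-in for LayoutEnum, for illustration only."""

    ROW_MAJOR = "row_major"  # the second mode (mode 1) is contiguous
    COL_MAJOR = "col_major"  # the first mode (mode 0) is contiguous

    @staticmethod
    def from_strides(strides):
        # Assumed rule: the mode with unit stride decides the layout.
        if strides[1] == 1:
            return ToyLayout.ROW_MAJOR
        if strides[0] == 1:
            return ToyLayout.COL_MAJOR
        raise ValueError(f"no unit-stride leading mode in {strides}")


# Assumed CuTe operand convention: A is (M, K), B is (N, K), C/D is (M, N).
OPERAND_MODES = {"A": ("M", "K"), "B": ("N", "K"), "C": ("M", "N")}


def major_mode(layout, operand):
    """Name of the contiguous ("major") mode for an operand stored with `layout`."""
    first, second = OPERAND_MODES[operand]
    return second if layout is ToyLayout.ROW_MAJOR else first


a_layout = ToyLayout.from_strides((64, 1))   # an (M, K) tensor with K contiguous
print(a_layout, major_mode(a_layout, "A"))   # ToyLayout.ROW_MAJOR K
print(major_mode(ToyLayout.COL_MAJOR, "B"))  # N: one reading of is_n_major_b
print(major_mode(ToyLayout.ROW_MAJOR, "C"))  # N: one reading of is_n_major_c
```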
class cutlass.utils.WorkTileInfo(
tile_idx: cutlass.cute.typing.Coord,
is_valid_tile: cutlass.cutlass_dsl.Boolean,
)

Bases: object

A class to represent information about a work tile.

Variables:
  • tile_idx – The index of the tile.

  • is_valid_tile – Whether the tile is valid.

__init__(
tile_idx: cutlass.cute.typing.Coord,
is_valid_tile: cutlass.cutlass_dsl.Boolean,
)
property is_valid_tile: cutlass.cutlass_dsl.Boolean

Whether the latest tile returned by the scheduler is valid. Any scheduling request made after all tasks have completed returns an invalid tile.

Returns:

The validity of the tile.

Return type:

Boolean

property tile_idx: cutlass.cute.typing.Coord

Get the index of the tile.

Returns:

The index of the tile.

Return type:

cute.Coord

class cutlass.utils.PersistentTileSchedulerParams(
problem_shape_ntile_mnl: cutlass.cute.typing.Shape,
cluster_shape_mnk: cutlass.cute.typing.Shape,
swizzle_size: int = 1,
raster_along_m: bool = True,
*,
loc=None,
ip=None,
)

Bases: object

A class to represent parameters for a persistent tile scheduler.

This class is designed to manage and compute the layout of clusters and tiles in a batched gemm problem.

Variables:
  • cluster_shape_mn – Shape of the cluster in (m, n) dimensions (the CTA count along the K dimension must be 1).

  • problem_layout_ncluster_mnl – Layout of the problem in terms of number of clusters in (m, n, l) dimensions.

__init__(
problem_shape_ntile_mnl: cutlass.cute.typing.Shape,
cluster_shape_mnk: cutlass.cute.typing.Shape,
swizzle_size: int = 1,
raster_along_m: bool = True,
*,
loc=None,
ip=None,
)

Initializes the PersistentTileSchedulerParams with the given parameters.

Parameters:
  • problem_shape_ntile_mnl (cute.Shape) – The shape of the problem, expressed as the number of CTA (Cooperative Thread Array) tiles along the (m, n, l) dimensions.

  • cluster_shape_mnk (cute.Shape) – The shape of the cluster in (m, n, k) dimensions; the k extent must be 1.

  • swizzle_size (int) – Swizzle size, in units of clusters. 1 means no swizzling.

  • raster_along_m (bool) – Rasterization order of clusters, used only when swizzle_size > 1. True means along M; False means along N.

Raises:

ValueError – If cluster_shape_k is not 1.

get_grid_shape(
max_active_clusters: cutlass.cutlass_dsl.Int32,
*,
loc=None,
ip=None,
) → Tuple[cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer]

Computes the grid shape based on the maximum active clusters allowed.

Parameters:

max_active_clusters (Int32) – The maximum number of active clusters that can run in one wave.

Returns:

A tuple (m, n, persistent_clusters) describing the grid shape, where:

  • m is self.cluster_shape_m.

  • n is self.cluster_shape_n.

  • persistent_clusters is the number of persistent clusters that can run.
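The persistent grid shape can be modeled in a few lines of Python. In this sketch, the number of clusters along each of (m, n, l) is a ceiling division of the tile counts by the cluster shape. The persistent cluster count is assumed to be the smaller of the total cluster count and max_active_clusters. The function name is hypothetical, and the real method operates on DSL integer types rather than Python ints.

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def toy_persistent_grid_shape(problem_shape_ntile_mnl, cluster_shape_mn, max_active_clusters):
    """Hypothetical model of a persistent grid shape (m, n, persistent_clusters)."""
    ntile_m, ntile_n, ntile_l = problem_shape_ntile_mnl
    cluster_m, cluster_n = cluster_shape_mn
    # Clusters needed to cover the problem along (m, n, l); partial clusters round up.
    ncluster_mnl = (ceil_div(ntile_m, cluster_m), ceil_div(ntile_n, cluster_n), ntile_l)
    total_clusters = ncluster_mnl[0] * ncluster_mnl[1] * ncluster_mnl[2]
    # Assumption: never launch more persistent clusters than there is work for.
    persistent_clusters = min(total_clusters, max_active_clusters)
    return (cluster_m, cluster_n, persistent_clusters)


print(toy_persistent_grid_shape((10, 7, 2), (2, 1), 66))  # (2, 1, 66): 70 clusters of work
print(toy_persistent_grid_shape((4, 4, 1), (2, 2), 66))   # (2, 2, 4): only 4 clusters of work
```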

class cutlass.utils.StaticPersistentTileScheduler(
params: PersistentTileSchedulerParams,
num_persistent_clusters: cutlass.cutlass_dsl.Int32,
current_work_linear_idx: cutlass.cutlass_dsl.Int32,
cta_id_in_cluster: cutlass.cute.typing.Coord,
num_tiles_executed: cutlass.cutlass_dsl.Int32,
)

Bases: object

A scheduler for static persistent tile execution in CUTLASS/CuTe kernels.

Variables:
  • params – Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl

  • num_persistent_clusters – Number of persistent clusters that can be launched

  • cta_id_in_cluster – ID of the CTA within its cluster

  • _num_tiles_executed – Counter for executed tiles

  • _current_work_linear_idx – Current cluster index

__init__(
params: PersistentTileSchedulerParams,
num_persistent_clusters: cutlass.cutlass_dsl.Int32,
current_work_linear_idx: cutlass.cutlass_dsl.Int32,
cta_id_in_cluster: cutlass.cute.typing.Coord,
num_tiles_executed: cutlass.cutlass_dsl.Int32,
)

Initializes the StaticPersistentTileScheduler with the given parameters.

Parameters:
  • params (PersistentTileSchedulerParams) – Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl.

  • num_persistent_clusters (Int32) – Number of persistent clusters that can be launched.

  • current_work_linear_idx (Int32) – Current cluster index.

  • cta_id_in_cluster (cute.Coord) – ID of the CTA within its cluster.

  • num_tiles_executed (Int32) – Counter for executed tiles.

static create(
params: PersistentTileSchedulerParams,
block_idx: Tuple[cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer],
grid_dim: Tuple[cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer],
*,
loc=None,
ip=None,
)

Initialize the static persistent tile scheduler.

Parameters:
  • params (PersistentTileSchedulerParams) – Parameters for the persistent tile scheduler.

  • block_idx (Tuple[Integer, Integer, Integer]) – The 3d block index in the format (bidx, bidy, bidz).

  • grid_dim (Tuple[Integer, Integer, Integer]) – The 3d grid dimensions for kernel launch.

Returns:

A StaticPersistentTileScheduler object.

Return type:

StaticPersistentTileScheduler

static get_grid_shape(
params: PersistentTileSchedulerParams,
max_active_clusters: cutlass.cutlass_dsl.Int32,
*,
loc=None,
ip=None,
) → Tuple[cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer]

Calculates the grid shape to launch on the GPU from the problem and cluster shapes held in params and from the maximum number of active clusters.

Parameters:
  • params (PersistentTileSchedulerParams) – Parameters for grid shape calculation.

  • max_active_clusters (Int32) – Maximum active clusters allowed.

Returns:

The calculated 3d grid shape.

Return type:

Tuple[Integer, Integer, Integer]

_get_current_work_for_linear_idx(
current_work_linear_idx: cutlass.cutlass_dsl.Int32,
*,
loc=None,
ip=None,
) → WorkTileInfo

Compute current tile coord given current_work_linear_idx and cta_id_in_cluster.

Parameters:

current_work_linear_idx (Int32) – The linear index of the current work.

Returns:

An object containing information about the current tile coordinates and validity status.

Return type:

WorkTileInfo

get_current_work(
*,
loc=None,
ip=None,
) → WorkTileInfo
initial_work_tile_info(
*,
loc=None,
ip=None,
) → WorkTileInfo
advance_to_next_work(
*,
advance_count: int = 1,
loc=None,
ip=None,
)
property num_tiles_executed: cutlass.cutlass_dsl.Int32
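The members above (initial_work_tile_info, advance_to_next_work, and WorkTileInfo.is_valid_tile) suggest the usual persistent-kernel loop. The pure-Python model below shows which cluster tiles one persistent cluster would visit. It makes three assumptions:

- Each cluster starts at its own index along the grid's z dimension, which is the persistent_clusters extent returned by get_grid_shape.
- Each advance moves the linear index forward by the number of persistent clusters.
- The index is decoded with M varying fastest and no swizzling.

All names in the model are hypothetical.

```python
def toy_static_persistent_schedule(ncluster_mnl, num_persistent_clusters, cluster_linear_id):
    """Hypothetical model: cluster tile coordinates visited by one persistent cluster."""
    ncluster_m, ncluster_n, ncluster_l = ncluster_mnl
    total = ncluster_m * ncluster_n * ncluster_l
    visited = []
    linear_idx = cluster_linear_id  # assumed starting point: the cluster's z index
    while linear_idx < total:  # past the end, the scheduler would report an invalid tile
        # Decode with m fastest, then n, then l (swizzle_size == 1 assumed).
        m = linear_idx % ncluster_m
        n = (linear_idx // ncluster_m) % ncluster_n
        l = linear_idx // (ncluster_m * ncluster_n)
        visited.append((m, n, l))
        linear_idx += num_persistent_clusters  # assumed advance stride
    return visited


# 3 x 2 x 1 cluster tiles shared by 4 persistent clusters.
for cid in range(4):
    print(cid, toy_static_persistent_schedule((3, 2, 1), 4, cid))
```

In this example, clusters 0 and 1 each receive two cluster tiles, and clusters 2 and 3 receive one. Together they cover all six tiles exactly once.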
class cutlass.utils.TensorMapUpdateMode(value)

Bases: Enum

Enum class defining tensor map update modes.

Modes:

  • GMEM: Update the tensormap in global memory.

  • SMEM: Load the tensormap from global memory into shared memory, update it there, then store it back to global memory.

GMEM = 1
SMEM = 2
class cutlass.utils.TensorMapManager(
tensormap_update_mode: TensorMapUpdateMode,
bytes_per_tensormap: int,
)

Bases: object

Manages TensorMap operations, including initialization and updates. Provides utilities to convert tensormap pointers across different memory spaces.

tensormap_update_mode: TensorMapUpdateMode
bytes_per_tensormap: int
get_tensormap_ptr(
ptr: cutlass.cute.typing.Pointer,
address_space=cutlass._mlir.dialects.cute.AddressSpace.gmem,
*,
loc=None,
ip=None,
) → cutlass.cute.typing.Pointer
init_tensormap_from_atom(
copy_atom: CopyAtom,
dst_ptr: cutlass.cute.typing.Pointer,
warp_id: int,
) → None
fence_tensormap_initialization() → None
fence_tensormap_update(
tensormap_ptr: cutlass.cute.typing.Pointer,
) → None
update_tensormap(
tensor_gmem: Tuple[cutlass.cute.typing.Tensor, ...],
tma_copy_atom: Tuple[CopyAtom, ...],
tensormap_gmem_ptr: Tuple[cutlass.cute.typing.Pointer, ...],
warp_id: int,
tensormap_smem_ptr: Tuple[cutlass.cute.typing.Pointer, ...],
) → None
__init__(
tensormap_update_mode: TensorMapUpdateMode,
bytes_per_tensormap: int,
) → None
class cutlass.utils.GroupSearchResult(
group_idx: cutlass.cutlass_dsl.Int32,
cta_tile_idx_m: cutlass.cutlass_dsl.Int32,
cta_tile_idx_n: cutlass.cutlass_dsl.Int32,
problem_shape_m: cutlass.cutlass_dsl.Int32,
problem_shape_n: cutlass.cutlass_dsl.Int32,
problem_shape_k: cutlass.cutlass_dsl.Int32,
cta_tile_count_k: cutlass.cutlass_dsl.Int32,
)

Bases: object

The result of the group search for grouped gemm.

Parameters:
  • group_idx (Int32) – The result group index

  • cta_tile_idx_m (Int32) – CTA tile index along M dimension after rasterization

  • cta_tile_idx_n (Int32) – CTA tile index along N dimension after rasterization

  • problem_shape_m (Int32) – The M dimension of the gemm problem

  • problem_shape_n (Int32) – The N dimension of the gemm problem

  • problem_shape_k (Int32) – The K dimension of the gemm problem

  • cta_tile_count_k (Int32) – Number of tiles along K dimension

__init__(
group_idx: cutlass.cutlass_dsl.Int32,
cta_tile_idx_m: cutlass.cutlass_dsl.Int32,
cta_tile_idx_n: cutlass.cutlass_dsl.Int32,
problem_shape_m: cutlass.cutlass_dsl.Int32,
problem_shape_n: cutlass.cutlass_dsl.Int32,
problem_shape_k: cutlass.cutlass_dsl.Int32,
cta_tile_count_k: cutlass.cutlass_dsl.Int32,
) → None
class cutlass.utils.GroupedGemmGroupSearchState(
start_group_idx: cutlass.cutlass_dsl.Int32,
tile_count_prev_group: cutlass.cutlass_dsl.Int32,
tile_count_searched: cutlass.cutlass_dsl.Int32,
)

Bases: object

The state of group index search for grouped gemm.

The state will be initialized once and updated in every round of group index search.

Parameters:
  • start_group_idx (Int32) – The group idx to start the search with

  • tile_count_prev_group (Int32) – Number of tiles before the matched group

  • tile_count_searched (Int32) – Number of tiles we have searched. When the matched group is found, it records the number of tiles including the matched group

__init__(
start_group_idx: cutlass.cutlass_dsl.Int32,
tile_count_prev_group: cutlass.cutlass_dsl.Int32,
tile_count_searched: cutlass.cutlass_dsl.Int32,
) → None
cutlass.utils.create_initial_search_state() → GroupedGemmGroupSearchState

Create an initial search state for grouped gemm.

Returns:

A new search state with initial values

Return type:

GroupedGemmGroupSearchState

class cutlass.utils.GroupedGemmTileSchedulerHelper(
group_count: int,
tile_sched_params: PersistentTileSchedulerParams,
cluster_tile_shape_mnk: tuple[int, int, int],
search_state: GroupedGemmGroupSearchState,
)

Bases: object

A helper that translates the raw block index (x, y, z) from the tile scheduler into the actual CTA tile index for grouped gemm.

Parameters:
  • group_count (int) – Number of groups in current grouped gemm problem

  • tile_sched_params (PersistentTileSchedulerParams) – Parameter used to create the tile scheduler this helper works with

  • cluster_tile_shape_mnk (tuple[int, int, int]) – The shape of cluster tile as (m, n, k)

  • search_state (GroupedGemmGroupSearchState) – The initial search state

__init__(
group_count: int,
tile_sched_params: PersistentTileSchedulerParams,
cluster_tile_shape_mnk: tuple[int, int, int],
search_state: GroupedGemmGroupSearchState,
) → None
delinearize_z(
cta_tile_coord: tuple,
problem_shape_mnkl: cutlass.cute.typing.Tensor,
) → GroupSearchResult

Delinearize the linear z index and return GroupSearchResult.

This function should be used by warps that need to know the CTA tile index on M and N dimensions.

Parameters:
  • cta_tile_coord (tuple of Int32) – The raw CTA coordinate from tile scheduler

  • problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for each group

Returns:

The search result containing group index and tile coordinates

Return type:

GroupSearchResult

search_cluster_tile_count_k(
cta_tile_coord: tuple,
problem_shape_mnkl: cutlass.cute.typing.Tensor,
) → Tuple[cutlass.cutlass_dsl.Int32, cutlass.cutlass_dsl.Int32]

Search for the group that matches the given linear index and compute the number of tiles along the K dimension for that group.

This function should be used by warps that are only interested in the number of tiles along K dimension.

Parameters:
  • cta_tile_coord (tuple of Int32) – The raw CTA coordinate from tile scheduler

  • problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for all groups

Returns:

A tuple containing cluster count along K dimension and the group index

Return type:

Tuple[Int32, Int32]

_prefix_sum(
value_per_thread: cutlass.cutlass_dsl.Int32,
) → cutlass.cutlass_dsl.Int32

Perform a prefix sum across a full warp.

Parameters:

value_per_thread (Int32) – The value for this thread to contribute to the prefix sum

Returns:

The prefix sum result for this thread

Return type:

Int32
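The docstring does not say whether the scan is inclusive or exclusive. The sketch below assumes an inclusive scan built from log2(32) = 5 shuffle-up steps (the Kogge-Stone pattern), a common way to implement a warp-wide prefix sum. It simulates the 32 lanes with a Python list. The function name is hypothetical, and this is not the library's code.

```python
WARP_SIZE = 32


def toy_warp_inclusive_scan(values):
    """Hypothetical model of a warp-wide inclusive prefix sum built from shuffle-up steps."""
    assert len(values) == WARP_SIZE
    lanes = list(values)
    offset = 1
    while offset < WARP_SIZE:
        # Each lane reads the value held `offset` lanes below it (like CUDA's
        # __shfl_up_sync); lanes without a partner add nothing this step.
        shifted = [lanes[i - offset] if i >= offset else 0 for i in range(WARP_SIZE)]
        lanes = [lanes[i] + shifted[i] for i in range(WARP_SIZE)]
        offset *= 2
    return lanes


result = toy_warp_inclusive_scan([1] * WARP_SIZE)
print(result[0], result[5], result[31])  # 1 6 32
```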

_get_problem_for_group(
problem_shape_mnkl: cutlass.cute.typing.Tensor,
group_idx: cutlass.cutlass_dsl.Int32,
) → cutlass.cute.typing.Tensor

Load the gemm problem shape (m, n, k, l) for the specified group from global memory into registers.

Parameters:
  • problem_shape_mnkl (cute.Tensor) – Tensor in global memory with layout (group_count, 4):(4, 1)

  • group_idx (Int32) – The index of the group to load

Returns:

The problem shape tensor for the specified group

Return type:

cute.Tensor

_get_cluster_tile_count_mn(
problem_shape: cutlass.cute.typing.Tensor,
) → cutlass.cutlass_dsl.Int32

Compute the total cluster tile count along the M and N dimensions.

Parameters:

problem_shape (cute.Tensor) – Tensor containing problem shape (m, n, k, l)

Returns:

The total cluster tile count for M and N dimensions

Return type:

Int32

_compute_cta_tile_coord(
cluster_tile_idx: cutlass.cutlass_dsl.Int32,
cta_tile_coord_in_cluster: tuple,
cluster_tile_count_m: cutlass.cutlass_dsl.Int32,
cluster_tile_count_n: cutlass.cutlass_dsl.Int32,
) → tuple

Compute CTA tile indices along M and N dimensions based on the linear index within a group.

It uses the AlongM mode to decompose the linear index onto M and N dimensions.

Parameters:
  • cluster_tile_idx (Int32) – The linear index within a group

  • cta_tile_coord_in_cluster (tuple of Int32) – CTA indices along M and N dimensions within a cluster

  • cluster_tile_count_m (Int32) – The number of clusters along M dimension of the matched group

  • cluster_tile_count_n (Int32) – The number of clusters along N dimension of the matched group

Returns:

A tuple containing CTA tile indices along M and N dimensions

Return type:

tuple of (Int32, Int32)
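The AlongM decomposition can be sketched in plain Python. The sketch assumes that consecutive linear indices walk along M first, then N. It also assumes that a CTA tile index is the cluster coordinate scaled by the cluster shape, plus the CTA's coordinate within the cluster. The explicit cluster_shape_mn argument is a hypothetical addition. The real method presumably obtains the cluster shape from the parameters the helper was constructed with.

```python
def toy_compute_cta_tile_coord(cluster_tile_idx, cta_tile_coord_in_cluster,
                               cluster_tile_count_m, cluster_tile_count_n,
                               cluster_shape_mn):
    """Hypothetical AlongM model: linear cluster index within a group -> CTA tile (m, n)."""
    # AlongM: consecutive indices walk down M first, then step to the next N column.
    cluster_m = cluster_tile_idx % cluster_tile_count_m
    cluster_n = cluster_tile_idx // cluster_tile_count_m
    if cluster_n >= cluster_tile_count_n:
        raise ValueError("cluster_tile_idx is outside this group's cluster tiles")
    # Scale the cluster coordinate by the cluster shape and add the in-cluster offset.
    cta_m = cluster_m * cluster_shape_mn[0] + cta_tile_coord_in_cluster[0]
    cta_n = cluster_n * cluster_shape_mn[1] + cta_tile_coord_in_cluster[1]
    return cta_m, cta_n


# A group covered by 3 x 2 clusters of shape (2, 1); CTA (1, 0) of cluster index 4.
print(toy_compute_cta_tile_coord(4, (1, 0), 3, 2, (2, 1)))  # (3, 1)
```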

Search for the group that the linear index belongs to.

Parameters:
  • linear_idx (Int32) – The linear index to be decomposed

  • problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for all groups

  • init_group_idx (Int32) – The group idx to start the search with

  • init_tile_count_searched (Int32) – The number of tiles we have searched

Returns:

The updated search state

Return type:

GroupedGemmGroupSearchState
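A sequential Python model can illustrate the search-state fields (start_group_idx, tile_count_prev_group, tile_count_searched). It assumes that per-group cluster tile counts are already known. It also assumes that linear indices arrive in non-decreasing order, so each search can resume from the previously matched group. The real helper works on DSL values inside a warp. ToySearchState and toy_group_search are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class ToySearchState:
    """Same three fields as GroupedGemmGroupSearchState, as plain ints."""
    start_group_idx: int
    tile_count_prev_group: int
    tile_count_searched: int


def toy_group_search(linear_idx, tiles_per_group, state):
    """Hypothetical sequential model: find the group that owns linear_idx."""
    if linear_idx < state.tile_count_prev_group:
        raise ValueError("linear indices must not decrease between searches")
    group_idx = state.start_group_idx
    prev = state.tile_count_prev_group
    while group_idx < len(tiles_per_group):
        searched = prev + tiles_per_group[group_idx]  # tiles up to and including this group
        if linear_idx < searched:
            return ToySearchState(group_idx, prev, searched)
        prev = searched
        group_idx += 1
    raise IndexError("linear_idx is past the last group")


tiles_per_group = [4, 0, 3, 5]  # cluster tiles per group; group 1 is empty
state = ToySearchState(0, 0, 0)
for idx in (0, 5, 9):
    state = toy_group_search(idx, tiles_per_group, state)
    print(idx, "-> group", state.start_group_idx)  # groups 0, 2, 3
```

In this model, the empty group 1 is skipped naturally, because its tile count adds nothing to the running total.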

_group_search_and_load_problem_shape(
linear_idx: cutlass.cutlass_dsl.Int32,
problem_shape_mnkl: cutlass.cute.typing.Tensor,
start_group_idx: cutlass.cutlass_dsl.Int32,
tile_count_searched: cutlass.cutlass_dsl.Int32,
) → Tuple[cutlass.cutlass_dsl.Int32, cutlass.cute.typing.Tensor]

Perform group search and load problem shape for the matched group.

Parameters:
  • linear_idx (Int32) – The linear index to be decomposed

  • problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for all groups

  • start_group_idx (Int32) – The group idx to start the search with

  • tile_count_searched (Int32) – The number of tiles we have searched

Returns:

A tuple containing the final group index and the problem shape tensor

Return type:

Tuple[Int32, cute.Tensor]

class cutlass.utils.HardwareInfo(device_id: int = 0)

Bases: object

Queries hardware information for a CUDA device. device_id: the CUDA device ID to query.

__init__(device_id: int = 0)
get_max_active_clusters(cluster_size: int) → int
get_l2_cache_size_in_bytes() → int
get_device_multiprocessor_count() → int
_checkCudaErrors(result) → None
_cudaGetErrorEnum(error) → str
_cuda_driver_version_ge(major: int, minor: int) → bool
_cuda_driver_version_lt(major: int, minor: int) → bool
_empty_kernel()
_host_function()
_get_device_function(device) → None
cutlass.utils.compute_epilogue_tile_shape(
cta_tile_shape: cutlass.cute.typing.Shape,
use_2cta_instrs: bool,
layout_d: LayoutEnum,
elem_ty_d: Type[cutlass.cutlass_dsl.Numeric],
*,
layout_c: LayoutEnum | None = None,
elem_ty_c: Type[cutlass.cutlass_dsl.Numeric] | None = None,
loc=None,
ip=None,
) → cutlass.cute.typing.Tile

Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.

Parameters:
  • cta_tile_shape (cute.Shape) – A tuple or list representing the dimensions of the CTA tile, where cta_tile_shape[0] corresponds to the height (M) and cta_tile_shape[1] corresponds to the width (N) of the tile.

  • use_2cta_instrs (bool) – A flag indicating whether the configuration is for a 2SM setup.

  • layout_d (LayoutEnum) – The layout enum of the output tensor D.

  • elem_ty_d (Type[Numeric]) – The element type of output tensor D.

  • layout_c (LayoutEnum, optional) – The layout enum of the input tensor C. Defaults to None.

  • elem_ty_c (Union[Type[Numeric], None], optional) – The element type for input tensor C. Defaults to None.

Returns:

The epilog tiler, which is used to partition the epilog in subsequent steps.

Return type:

cute.Tile

Raises:

ValueError – If the computed tile size does not meet minimum requirements based on CTA dimensions.

cutlass.utils.get_smem_store_op(
layout_d: LayoutEnum,
elem_ty_d: Type[cutlass.cutlass_dsl.Numeric],
elem_ty_acc: Type[cutlass.cutlass_dsl.Numeric],
tiled_tmem_load: TiledCopy,
*,
loc=None,
ip=None,
) → CopyAtom

Selects the largest vectorized SMEM store atom available, subject to the constraints of the GMEM layout and the chosen TMEM_LOAD’s thread-value ownership.

Parameters:
  • layout_d (LayoutEnum) – The layout enum of the output tensor D.

  • elem_ty_d (Type[Numeric]) – The element type for output tensor D.

  • elem_ty_acc (Type[Numeric]) – The element type for accumulator.

  • tiled_tmem_load (cute.TiledCopy) – An instance of TiledCopy that represents the tmem load operation.

Returns:

Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters.

Return type:

cute.CopyAtom

cutlass.utils.get_tmem_load_op(
cta_tile_shape: cutlass.cute.typing.Shape,
layout_d: LayoutEnum,
elem_ty_d: Type[cutlass.cutlass_dsl.Numeric],
elem_ty_acc: Type[cutlass.cutlass_dsl.Numeric],
epi_tile: cutlass.cute.typing.Tile,
use_2cta_instrs: bool,
*,
loc=None,
ip=None,
) → CopyAtom

Finds a performant TMEM_LOAD copy op for the selected epilogue tile (epi_tile), element types, and tcgen05.mma instruction used.

Parameters:
  • cta_tile_shape (cute.Shape) – A tuple or list representing the dimensions of the CTA tile.

  • layout_d (LayoutEnum) – The layout enum of the output tensor D.

  • elem_ty_d (Type[Numeric]) – The element type for output tensor D.

  • elem_ty_acc (Type[Numeric]) – The element type for accumulation.

  • epi_tile (cute.Tile) – The epilogue tile configuration.

  • use_2cta_instrs (bool) – A flag indicating whether the configuration is for 2 SMs.

Returns:

An instance of Sm100TmemLoad with the computed configuration.

Return type:

cute.CopyAtom

Raises:

ValueError – If the function cannot handle the given combination of accumulation and dimension types, or if it cannot determine the appropriate configuration based on the input parameters.

cutlass.utils.get_num_tmem_alloc_cols(
tmem_tensors: cutlass.cute.typing.Tensor | List[cutlass.cute.typing.Tensor],
rounding=True,
) → int

Get the total number of TMEM allocation columns for the given TMEM tensors.

Parameters:
  • tmem_tensors (Union[cute.Tensor, List[cute.Tensor]]) – The TMEM tensors to get the number of allocation columns for.

  • rounding (bool) – Whether to round up the number of allocation columns to the nearest power of 2.

Returns:

The total number of TMEM allocation columns.

Return type:

int

Raises:

ValueError – If the number of TMEM allocation columns exceeds the maximum capacity of 512 or is less than 32.
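The rounding and range check can be sketched in pure Python. The sketch assumes that each TMEM tensor's column usage has already been reduced to an integer, and that rounding means rounding up to the next power of two. The function name is hypothetical.

```python
def toy_num_tmem_alloc_cols(columns_per_tensor, rounding=True):
    """Hypothetical model: total TMEM columns for several tensors, optionally rounded."""
    total = sum(columns_per_tensor)
    if rounding and total > 1:
        # Round up to the next power of two; exact powers of two are kept as-is.
        total = 1 << (total - 1).bit_length()
    if total > 512 or total < 32:
        raise ValueError(f"{total} TMEM columns is outside the range [32, 512]")
    return total


print(toy_num_tmem_alloc_cols([128, 128]))                 # 256
print(toy_num_tmem_alloc_cols([128, 64]))                  # 256 (192 rounded up)
print(toy_num_tmem_alloc_cols([128, 64], rounding=False))  # 192
```

Rounding makes the total satisfy the power-of-two condition listed for TmemAllocator.check_valid_num_columns. Without rounding, a total such as 192 would not.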

cutlass.utils.make_smem_layout_a(
tiled_mma: TiledMma,
mma_tiler_mnk: cutlass.cute.typing.Tile,
a_dtype: Type[cutlass.cutlass_dsl.Numeric],
num_stages: int,
*,
loc=None,
ip=None,
) → cutlass.cute.typing.Layout | cutlass.cute.typing.ComposedLayout

This function helps with:

  1. Get the partitioned shape of the A tensor based on the tiled_mma & MMA tiler.

  2. Select the heuristic SMEM layout atom based on the A tensor’s majorness, the data type, and the major mode size.

  3. Tile the SMEM layout atom to the MMA tile shape.

  4. Stage the SMEM layout based on the number of stages.

Parameters:
  • tiled_mma (cute.TiledMma) – The tiled MMA used to partition tensor A

  • mma_tiler_mnk (cute.Tile) – The MMA tile shape

  • a_dtype (Type[Numeric]) – The element type for tensor A

  • num_stages (int) – The number of pipeline stages for tensor A

Returns:

SMEM layout for tensor A

Return type:

Union[cute.Layout, cute.ComposedLayout]

cutlass.utils.make_smem_layout_b(
tiled_mma: TiledMma,
mma_tiler_mnk: cutlass.cute.typing.Tile,
b_dtype: Type[cutlass.cutlass_dsl.Numeric],
num_stages: int,
*,
loc=None,
ip=None,
) → cutlass.cute.typing.Layout | cutlass.cute.typing.ComposedLayout

This function helps with:

  1. Get the partitioned shape of the B tensor based on the tiled_mma & MMA tiler.

  2. Select the heuristic SMEM layout atom based on the B tensor’s majorness, the data type, and the major mode size.

  3. Tile the SMEM layout atom to the MMA tile shape.

  4. Stage the SMEM layout based on the number of stages.

Parameters:
  • tiled_mma (cute.TiledMma) – The tiled MMA which is used to partition the B tensor.

  • mma_tiler_mnk (cute.Tile) – The MMA tile shape.

  • b_dtype (Type[Numeric]) – The element type for the B tensor.

  • num_stages (int) – The number of pipeline stages for the B tensor.

Returns:

SMEM layout for the B tensor.

Return type:

Union[cute.Layout, cute.ComposedLayout]

cutlass.utils.make_smem_layout_epi(
epi_dtype: Type[cutlass.cutlass_dsl.Numeric],
epi_layout: LayoutEnum,
epi_tile: cutlass.cute.typing.Tile,
epi_stage: int,
*,
loc=None,
ip=None,
) → cutlass.cute.typing.Layout | cutlass.cute.typing.ComposedLayout

This function helps with:

  1. Select the heuristic SMEM layout atom based on the epilog tile shape, the epilog tensor’s majorness, and the element type.

  2. Tile the SMEM layout atom to the epilog tile shape.

  3. Stage the SMEM layout based on the number of stages.

Parameters:
  • epi_dtype (Type[Numeric]) – The element type for the epilog tensor.

  • epi_layout (LayoutEnum) – The layout enum for the epilog tensor.

  • epi_tile (cute.Tile) – The epilogue tile shape.

  • epi_stage (int) – The number of pipeline stages for the epilog tensor.

Returns:

SMEM layout for epilog tensors (usually C & D which are processed in the epilog)

Return type:

Union[cute.Layout, cute.ComposedLayout]

cutlass.utils.make_trivial_tiled_mma(
ab_dtype: Type[cutlass.cutlass_dsl.Numeric],
a_leading_mode: OperandMajorMode,
b_leading_mode: OperandMajorMode,
acc_dtype: Type[cutlass.cutlass_dsl.Numeric],
cta_group: CtaGroup,
mma_tiler_mn: Tuple[int, int],
a_source: OperandSource = cutlass._mlir.dialects.cute.MmaFragKind.smem_desc,
*,
loc=None,
ip=None,
) → TiledMma

Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. By default, the MMA atom is created with SMEM operand source for A.

Parameters:
  • ab_dtype (type[Numeric]) – Data type of operands A and B.

  • a_leading_mode (tcgen05.OperandMajorMode) – Leading dimension of operand A (1 for K, 0 for M/N).

  • b_leading_mode (tcgen05.OperandMajorMode) – Leading dimension of operand B (1 for K, 0 for M/N).

  • acc_dtype (type[Numeric]) – Data type of the accumulator.

  • cta_group (tcgen05.CtaGroup) – The CTA group to use.

  • mma_tiler_mn (Tuple[int, int]) – The (M, N) shape of the MMA tiler.

  • a_source (cutlass.cute.nvgpu.tcgen05.OperandSource) – The source of operand A (SMEM by default or TMEM).

Returns:

A tiled MMA atom.

Return type:

cute.TiledMma

Raises:

TypeError – If the data type is not supported.

cutlass.utils.make_blockscaled_trivial_tiled_mma(
ab_dtype: Type[cutlass.cutlass_dsl.Numeric],
a_leading_mode: OperandMajorMode,
b_leading_mode: OperandMajorMode,
sf_dtype: Type[cutlass.cutlass_dsl.Numeric],
sf_vec_size: int,
cta_group: CtaGroup,
mma_tiler_mn: Tuple[int, int],
a_source: OperandSource = cutlass._mlir.dialects.cute.MmaFragKind.smem_desc,
*,
loc=None,
ip=None,
) → TiledMma

Make a BlockScaled tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. By default, the MMA atom is created with SMEM operand source for A.

Parameters:
  • ab_dtype (type[Numeric]) – Data type of operands A and B.

  • a_leading_mode (tcgen05.OperandMajorMode) – Leading dimension of operand A (1 for K, 0 for M/N).

  • b_leading_mode (tcgen05.OperandMajorMode) – Leading dimension of operand B (1 for K, 0 for M/N).

  • sf_dtype (type[Numeric]) – Data type of the Scale Factor.

  • sf_vec_size (int) – The vector size of the Scale Factor.

  • cta_group (tcgen05.CtaGroup) – The CTA group to use.

  • mma_tiler_mn (Tuple[int, int]) – The (M, N) shape of the MMA tiler.

  • a_source (cutlass.cute.nvgpu.tcgen05.OperandSource) – The source of operand A (SMEM by default or TMEM).

Returns:

A tiled MMA atom.

Return type:

cute.TiledMma

Raises:

TypeError – If the data type is not supported.