cutlass.utils#

class cutlass.utils.WorkTileInfo(
tile_idx: cutlass.cute.typing.Coord,
is_valid_tile: cutlass.cutlass_dsl.Boolean,
)#

Bases: object

A class to represent information about a work tile.

Variables:
  • tile_idx – The index of the tile.

  • is_valid_tile – Whether the tile is valid.

__init__(
tile_idx: cutlass.cute.typing.Coord,
is_valid_tile: cutlass.cutlass_dsl.Boolean,
)#
property is_valid_tile: cutlass.cutlass_dsl.Boolean#

Check whether the latest tile returned by the scheduler is valid. Any scheduling request made after all tasks have completed will return an invalid tile.

Returns:

The validity of the tile.

Return type:

Boolean

property tile_idx: cutlass.cute.typing.Coord#

Get the index of the tile.

Returns:

The index of the tile.

Return type:

cute.Coord

class cutlass.utils.PersistentTileSchedulerParams(
problem_shape_ntile_mnl: cutlass.cute.typing.Shape,
cluster_shape_mnk: cutlass.cute.typing.Shape,
*,
loc=None,
ip=None,
)#

Bases: object

A class to represent parameters for a persistent tile scheduler.

This class is designed to manage and compute the layout of clusters and tiles in a batched gemm problem.

Variables:
  • cluster_shape_mn – Shape of the cluster in (m, n) dimensions (K dimension cta count must be 1).

  • problem_layout_ncluster_mnl – Layout of the problem in terms of number of clusters in (m, n, l) dimensions.

__init__(
problem_shape_ntile_mnl: cutlass.cute.typing.Shape,
cluster_shape_mnk: cutlass.cute.typing.Shape,
*,
loc=None,
ip=None,
)#

Initializes the PersistentTileSchedulerParams with the given parameters.

Parameters:
  • problem_shape_ntile_mnl (cute.Shape) – The shape of the problem in terms of the number of CTA (Cooperative Thread Array) tiles in (m, n, l) dimensions.

  • cluster_shape_mnk (cute.Shape) – The shape of the cluster in (m, n, k) dimensions; the k extent must be 1.

Raises:

ValueError – If cluster_shape_k is not 1.

get_grid_shape(
max_active_clusters: cutlass.cutlass_dsl.Int32,
*,
loc=None,
ip=None,
) Tuple[cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer]#

Computes the grid shape based on the maximum active clusters allowed.

Parameters:

max_active_clusters (Int32) – The maximum number of active clusters that can run in one wave.

Returns:

A tuple containing the grid shape in (m, n, persistent_clusters):
  • m: self.cluster_shape_m.

  • n: self.cluster_shape_n.

  • persistent_clusters: Number of persistent clusters that can run.

class cutlass.utils.StaticPersistentTileScheduler(
params: PersistentTileSchedulerParams,
num_persistent_clusters: cutlass.cutlass_dsl.Int32,
current_work_linear_idx: cutlass.cutlass_dsl.Int32,
cta_id_in_cluster: cutlass.cute.typing.Coord,
num_tiles_executed: cutlass.cutlass_dsl.Int32,
)#

Bases: object

A scheduler for static persistent tile execution in CUTLASS/CuTe kernels.

Variables:
  • params – Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl

  • num_persistent_clusters – Number of persistent clusters that can be launched

  • cta_id_in_cluster – ID of the CTA within its cluster

  • _num_tiles_executed – Counter for executed tiles

  • _current_work_linear_idx – Current cluster index

__init__(
params: PersistentTileSchedulerParams,
num_persistent_clusters: cutlass.cutlass_dsl.Int32,
current_work_linear_idx: cutlass.cutlass_dsl.Int32,
cta_id_in_cluster: cutlass.cute.typing.Coord,
num_tiles_executed: cutlass.cutlass_dsl.Int32,
)#

Initializes the StaticPersistentTileScheduler with the given parameters.

Parameters:
  • params (PersistentTileSchedulerParams) – Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl.

  • num_persistent_clusters (Int32) – Number of persistent clusters that can be launched.

  • current_work_linear_idx (Int32) – Current cluster index.

  • cta_id_in_cluster (cute.Coord) – ID of the CTA within its cluster.

  • num_tiles_executed (Int32) – Counter for executed tiles.

create(
params: PersistentTileSchedulerParams,
block_idx: Tuple[cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer],
grid_dim: Tuple[cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer],
*,
loc=None,
ip=None,
)#

Initialize the static persistent tile scheduler.

Parameters:
  • params (PersistentTileSchedulerParams) – Parameters for the persistent tile scheduler.

  • block_idx (Tuple[Integer, Integer, Integer]) – The 3d block index in the format (bidx, bidy, bidz).

  • grid_dim (Tuple[Integer, Integer, Integer]) – The 3d grid dimensions for kernel launch.

Returns:

A StaticPersistentTileScheduler object.

Return type:

StaticPersistentTileScheduler

static get_grid_shape(
params: PersistentTileSchedulerParams,
max_active_clusters: cutlass.cutlass_dsl.Int32,
*,
loc=None,
ip=None,
) Tuple[cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer]#

Calculates the grid shape to be launched on GPU using problem shape, threadblock shape, and active cluster size.

Parameters:
  • params (PersistentTileSchedulerParams) – Parameters for grid shape calculation.

  • max_active_clusters (Int32) – Maximum active clusters allowed.

Returns:

The calculated 3d grid shape.

Return type:

Tuple[Integer, Integer, Integer]

_get_current_work_for_linear_idx(
current_work_linear_idx: cutlass.cutlass_dsl.Int32,
*,
loc=None,
ip=None,
) WorkTileInfo#

Compute current tile coord given current_work_linear_idx and cta_id_in_cluster.

Parameters:

current_work_linear_idx (Int32) – The linear index of the current work.

Returns:

An object containing information about the current tile coordinates and validity status.

Return type:

WorkTileInfo

get_current_work(
*,
loc=None,
ip=None,
) WorkTileInfo#
initial_work_tile_info(
*,
loc=None,
ip=None,
) WorkTileInfo#
advance_to_next_work(
*,
advance_count: int = 1,
loc=None,
ip=None,
)#
property num_tiles_executed: cutlass.cutlass_dsl.Int32#
class cutlass.utils.TensorMapUpdateMode(value)#

Bases: Enum

Enum class defining tensor map update modes.

Modes:
  • GMEM: Update the tensormap directly in global memory.

  • SMEM: Load the tensormap from global memory to shared memory, update it in shared memory, then store it back to global memory.

GMEM = 1#
SMEM = 2#
class cutlass.utils.TensorMapManager(
tensormap_update_mode: TensorMapUpdateMode,
bytes_per_tensormap: int,
)#

Bases: object

Manages TensorMap operations including initialization and updates. Provides utilities to convert tensormap pointers across different memory spaces.

tensormap_update_mode: TensorMapUpdateMode#
bytes_per_tensormap: int#
get_tensormap_ptr(
ptr: cutlass.cute.typing.Pointer,
address_space=cutlass._mlir.dialects.cute.AddressSpace.gmem,
) cutlass.cute.typing.Pointer#
init_tensormap_from_atom(
copy_atom: CopyAtom,
dst_ptr: cutlass.cute.typing.Pointer,
warp_id: int,
) None#
fence_tensormap_initialization() None#
fence_tensormap_update(
tensormap_ptr: cutlass.cute.typing.Pointer,
) None#
update_tensormap(
tensor_gmem: Tuple[cutlass.cute.typing.Tensor, ...],
tma_copy_atom: Tuple[CopyAtom, ...],
tensormap_gmem_ptr: Tuple[cutlass.cute.typing.Pointer, ...],
warp_id: int,
tensormap_smem_ptr: Tuple[cutlass.cute.typing.Pointer, ...],
) None#
__init__(
tensormap_update_mode: TensorMapUpdateMode,
bytes_per_tensormap: int,
) None#
class cutlass.utils.GroupSearchResult(
group_idx: cutlass.cutlass_dsl.Int32,
cta_tile_idx_m: cutlass.cutlass_dsl.Int32,
cta_tile_idx_n: cutlass.cutlass_dsl.Int32,
problem_shape_m: cutlass.cutlass_dsl.Int32,
problem_shape_n: cutlass.cutlass_dsl.Int32,
problem_shape_k: cutlass.cutlass_dsl.Int32,
cta_tile_count_k: cutlass.cutlass_dsl.Int32,
)#

Bases: object

The result of the group search for grouped gemm.

Parameters:
  • group_idx (Int32) – The result group index

  • cta_tile_idx_m (Int32) – CTA tile index along M dimension after rasterization

  • cta_tile_idx_n (Int32) – CTA tile index along N dimension after rasterization

  • problem_shape_m (Int32) – The M dimension of the gemm problem

  • problem_shape_n (Int32) – The N dimension of the gemm problem

  • problem_shape_k (Int32) – The K dimension of the gemm problem

  • cta_tile_count_k (Int32) – Number of tiles along K dimension

__init__(
group_idx: cutlass.cutlass_dsl.Int32,
cta_tile_idx_m: cutlass.cutlass_dsl.Int32,
cta_tile_idx_n: cutlass.cutlass_dsl.Int32,
problem_shape_m: cutlass.cutlass_dsl.Int32,
problem_shape_n: cutlass.cutlass_dsl.Int32,
problem_shape_k: cutlass.cutlass_dsl.Int32,
cta_tile_count_k: cutlass.cutlass_dsl.Int32,
) None#
class cutlass.utils.GroupedGemmGroupSearchState(
start_group_idx: cutlass.cutlass_dsl.Int32,
tile_count_prev_group: cutlass.cutlass_dsl.Int32,
tile_count_searched: cutlass.cutlass_dsl.Int32,
)#

Bases: object

The state of group index search for grouped gemm.

The state will be initialized once and updated in every round of group index search.

Parameters:
  • start_group_idx (Int32) – The group idx to start the search with

  • tile_count_prev_group (Int32) – Number of tiles before the matched group

  • tile_count_searched (Int32) – Number of tiles we have searched. When the matched group is found, it records the number of tiles including the matched group

__init__(
start_group_idx: cutlass.cutlass_dsl.Int32,
tile_count_prev_group: cutlass.cutlass_dsl.Int32,
tile_count_searched: cutlass.cutlass_dsl.Int32,
) None#
cutlass.utils.create_initial_search_state() GroupedGemmGroupSearchState#

Create an initial search state for grouped gemm.

Returns:

A new search state with initial values

Return type:

GroupedGemmGroupSearchState

class cutlass.utils.GroupedGemmTileSchedulerHelper(
group_count: int,
tile_sched_params: PersistentTileSchedulerParams,
cluster_tile_shape_mnk: tuple[int, int, int],
search_state: GroupedGemmGroupSearchState,
)#

Bases: object

A helper to translate the raw block index (x, y, z) from tile scheduler to real CTA tile index for grouped gemm.

Parameters:
  • group_count (int) – Number of groups in current grouped gemm problem

  • tile_sched_params (PersistentTileSchedulerParams) – Parameter used to create the tile scheduler this helper works with

  • cluster_tile_shape_mnk (tuple[int, int, int]) – The shape of cluster tile as (m, n, k)

  • search_state (GroupedGemmGroupSearchState) – The initial search state

__init__(
group_count: int,
tile_sched_params: PersistentTileSchedulerParams,
cluster_tile_shape_mnk: tuple[int, int, int],
search_state: GroupedGemmGroupSearchState,
) None#
delinearize_z(
cta_tile_coord: tuple,
problem_shape_mnkl: cutlass.cute.typing.Tensor,
) GroupSearchResult#

Delinearize the linear z index and return GroupSearchResult.

This function should be used by warps that need to know the CTA tile index on M and N dimensions.

Parameters:
  • cta_tile_coord (tuple of Int32) – The raw CTA coordinate from tile scheduler

  • problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for each group

Returns:

The search result containing group index and tile coordinates

Return type:

GroupSearchResult

search_cluster_tile_count_k(
cta_tile_coord: tuple,
problem_shape_mnkl: cutlass.cute.typing.Tensor,
) Tuple[cutlass.cutlass_dsl.Int32, cutlass.cutlass_dsl.Int32]#

Search the matched group for given linear index and compute the number of tiles along K dimension for the matched group.

This function should be used by warps that are only interested in the number of tiles along K dimension.

Parameters:
  • cta_tile_coord (tuple of Int32) – The raw CTA coordinate from tile scheduler

  • problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for all groups

Returns:

A tuple containing cluster count along K dimension and the group index

Return type:

Tuple[Int32, Int32]

_prefix_sum(
value_per_thread: cutlass.cutlass_dsl.Int32,
) cutlass.cutlass_dsl.Int32#

Perform prefix sum within a full warp.

Parameters:

value_per_thread (Int32) – The value for this thread to contribute to the prefix sum

Returns:

The prefix sum result for this thread

Return type:

Int32

_get_problem_for_group(
problem_shape_mnkl: cutlass.cute.typing.Tensor,
group_idx: cutlass.cutlass_dsl.Int32,
) cutlass.cute.typing.Tensor#

Load gemm problem (m,n,k,l) for the specified group from global memory to register.

Parameters:
  • problem_shape_mnkl (cute.Tensor) – Tensor in global memory with layout (group_count, 4):(4, 1)

  • group_idx (Int32) – The index of the group to load

Returns:

The problem shape tensor for the specified group

Return type:

cute.Tensor

_get_cluster_tile_count_mn(
problem_shape: cutlass.cute.typing.Tensor,
) cutlass.cutlass_dsl.Int32#

Compute the total cluster tile count across the M and N dimensions.

Parameters:

problem_shape (cute.Tensor) – Tensor containing problem shape (m, n, k, l)

Returns:

The total cluster tile count for M and N dimensions

Return type:

Int32

_compute_cta_tile_coord(
cluster_tile_idx: cutlass.cutlass_dsl.Int32,
cta_tile_coord_in_cluster: tuple,
cluster_tile_count_m: cutlass.cutlass_dsl.Int32,
cluster_tile_count_n: cutlass.cutlass_dsl.Int32,
) tuple#

Compute CTA tile indices along M and N dimensions based on the linear index within a group.

It uses the AlongM mode to decompose the linear index onto M and N dimensions.

Parameters:
  • cluster_tile_idx (Int32) – The linear index within a group

  • cta_tile_coord_in_cluster (tuple of Int32) – CTA indices along M and N dimensions within a cluster

  • cluster_tile_count_m (Int32) – The number of clusters along M dimension of the matched group

  • cluster_tile_count_n (Int32) – The number of clusters along N dimension of the matched group

Returns:

A tuple containing CTA tile indices along M and N dimensions

Return type:

tuple of (Int32, Int32)

Search which group the linear index belongs to.

(NOTE, review: the signature line for this entry appears to be missing from the extracted text; judging by the parameters below and the surrounding entries, it is likely a private group-search method of GroupedGemmTileSchedulerHelper — confirm against the source module.)

Parameters:
  • linear_idx (Int32) – The linear index to be decomposed

  • problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for all groups

  • init_group_idx (Int32) – The group idx to start the search with

  • init_tile_count_searched (Int32) – The number of tiles we have searched

Returns:

The updated search state

Return type:

GroupedGemmGroupSearchState

_group_search_and_load_problem_shape(
linear_idx: cutlass.cutlass_dsl.Int32,
problem_shape_mnkl: cutlass.cute.typing.Tensor,
start_group_idx: cutlass.cutlass_dsl.Int32,
tile_count_searched: cutlass.cutlass_dsl.Int32,
) Tuple[cutlass.cutlass_dsl.Int32, cutlass.cute.typing.Tensor]#

Perform group search and load problem shape for the matched group.

Parameters:
  • linear_idx (Int32) – The linear index to be decomposed

  • problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for all groups

  • start_group_idx (Int32) – The group idx to start the search with

  • tile_count_searched (Int32) – The number of tiles we have searched

Returns:

A tuple containing the final group index and the problem shape tensor

Return type:

Tuple[Int32, cute.Tensor]

class cutlass.utils.HardwareInfo(device_id: int = 0)#

Bases: object

device_id: CUDA device ID used to query the hardware info.

__init__(device_id: int = 0)#
get_max_active_clusters(cluster_size: int) int#
get_l2_cache_size_in_bytes() int#
get_device_multiprocessor_count() int#
_checkCudaErrors(result) None#
_cudaGetErrorEnum(error) str#
_cuda_driver_version_ge(major: int, minor: int) bool#
_cuda_driver_version_lt(major: int, minor: int) bool#
_empty_kernel()#
_host_function()#
_get_device_function() None#