cutlass.utils
- class cutlass.utils.WorkTileInfo(
- tile_idx: cutlass.cute.typing.Coord,
- is_valid_tile: cutlass.cutlass_dsl.Boolean,
Bases:
object
A class to represent information about a work tile.
- Variables:
tile_idx – The index of the tile.
is_valid_tile – Whether the tile is valid.
- __init__(
- tile_idx: cutlass.cute.typing.Coord,
- is_valid_tile: cutlass.cutlass_dsl.Boolean,
- property is_valid_tile: cutlass.cutlass_dsl.Boolean
Check whether the latest tile returned by the scheduler is valid. Any scheduling request made after all tasks have completed returns an invalid tile.
- Returns:
The validity of the tile.
- Return type:
Boolean
- property tile_idx: cutlass.cute.typing.Coord
Get the index of the tile.
- Returns:
The index of the tile.
- Return type:
cute.Coord
- class cutlass.utils.PersistentTileSchedulerParams(
- problem_shape_ntile_mnl: cutlass.cute.typing.Shape,
- cluster_shape_mnk: cutlass.cute.typing.Shape,
- *,
- loc=None,
- ip=None,
Bases:
object
A class to represent parameters for a persistent tile scheduler.
This class is designed to manage and compute the layout of clusters and tiles in a batched gemm problem.
- Variables:
cluster_shape_mn – Shape of the cluster in (m, n) dimensions (the CTA count along the K dimension must be 1).
problem_layout_ncluster_mnl – Layout of the problem in terms of number of clusters in (m, n, l) dimensions.
- __init__(
- problem_shape_ntile_mnl: cutlass.cute.typing.Shape,
- cluster_shape_mnk: cutlass.cute.typing.Shape,
- *,
- loc=None,
- ip=None,
Initializes the PersistentTileSchedulerParams with the given parameters.
- Parameters:
problem_shape_ntile_mnl (cute.Shape) – The shape of the problem in terms of the number of CTAs (cooperative thread arrays) in (m, n, l) dimensions.
cluster_shape_mnk (cute.Shape) – The shape of the cluster in (m, n, k) dimensions. The k extent must be 1.
- Raises:
ValueError – If cluster_shape_k is not 1.
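The constructor's contract can be illustrated with a plain-Python, host-side sketch. This is not the DSL implementation: the function name `cluster_counts_mnl` is hypothetical, and rounding up when the tile count is not a multiple of the cluster shape is an assumption.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer division rounded up."""
    return -(-a // b)


def cluster_counts_mnl(problem_shape_ntile_mnl, cluster_shape_mnk):
    """Hypothetical model: number of clusters along (m, n, l).

    Assumption: partial clusters at the edge are rounded up.
    """
    cluster_m, cluster_n, cluster_k = cluster_shape_mnk
    if cluster_k != 1:
        # Mirrors the documented ValueError for a K extent other than 1.
        raise ValueError(f"cluster_shape_k must be 1, got {cluster_k}")
    ntile_m, ntile_n, ntile_l = problem_shape_ntile_mnl
    return (ceil_div(ntile_m, cluster_m), ceil_div(ntile_n, cluster_n), ntile_l)


# 8 x 6 CTA tiles in 2 batches, grouped into 2 x 1 clusters.
print(cluster_counts_mnl((8, 6, 2), (2, 1, 1)))  # (4, 6, 2)
```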
- get_grid_shape(
- max_active_clusters: cutlass.cutlass_dsl.Int32,
- *,
- loc=None,
- ip=None,
Computes the grid shape based on the maximum active clusters allowed.
- Parameters:
max_active_clusters (Int32) – The maximum number of active clusters that can run in one wave.
- Returns:
A tuple containing the grid shape in (m, n, persistent_clusters):
- m: self.cluster_shape_m.
- n: self.cluster_shape_n.
- persistent_clusters: Number of persistent clusters that can run.
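A minimal host-side sketch of the returned tuple, in plain Python rather than the DSL. The function name `persistent_grid_shape` is hypothetical, and capping the persistent-cluster count by the total number of clusters in the problem is an assumption not stated in the description above.

```python
def persistent_grid_shape(cluster_shape_mn, num_clusters_mnl, max_active_clusters):
    """Hypothetical model of the (m, n, persistent_clusters) grid shape."""
    total_clusters = num_clusters_mnl[0] * num_clusters_mnl[1] * num_clusters_mnl[2]
    # Assumption: never launch more persistent clusters than there is work for.
    persistent_clusters = min(total_clusters, max_active_clusters)
    return (cluster_shape_mn[0], cluster_shape_mn[1], persistent_clusters)


# 4 * 6 * 2 = 48 clusters of work, but only 32 can be resident in one wave.
print(persistent_grid_shape((2, 1), (4, 6, 2), 32))  # (2, 1, 32)
```

The first two entries span one cluster, so each slot along the third grid dimension corresponds to one persistent cluster.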
- class cutlass.utils.StaticPersistentTileScheduler(
- params: PersistentTileSchedulerParams,
- num_persistent_clusters: cutlass.cutlass_dsl.Int32,
- current_work_linear_idx: cutlass.cutlass_dsl.Int32,
- cta_id_in_cluster: cutlass.cute.typing.Coord,
- num_tiles_executed: cutlass.cutlass_dsl.Int32,
Bases:
object
A scheduler for static persistent tile execution in CUTLASS/CuTe kernels.
- Variables:
params – Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl
num_persistent_clusters – Number of persistent clusters that can be launched
cta_id_in_cluster – ID of the CTA within its cluster
_num_tiles_executed – Counter for executed tiles
_current_work_linear_idx – Current cluster index
- __init__(
- params: PersistentTileSchedulerParams,
- num_persistent_clusters: cutlass.cutlass_dsl.Int32,
- current_work_linear_idx: cutlass.cutlass_dsl.Int32,
- cta_id_in_cluster: cutlass.cute.typing.Coord,
- num_tiles_executed: cutlass.cutlass_dsl.Int32,
Initializes the StaticPersistentTileScheduler with the given parameters.
- Parameters:
params (PersistentTileSchedulerParams) – Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl.
num_persistent_clusters (Int32) – Number of persistent clusters that can be launched.
current_work_linear_idx (Int32) – Current cluster index.
cta_id_in_cluster (cute.Coord) – ID of the CTA within its cluster.
num_tiles_executed (Int32) – Counter for executed tiles.
- create(
- params: PersistentTileSchedulerParams,
- block_idx: Tuple[cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer],
- grid_dim: Tuple[cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer, cutlass.cutlass_dsl.Integer],
- *,
- loc=None,
- ip=None,
Initialize the static persistent tile scheduler.
- Parameters:
params (PersistentTileSchedulerParams) – Parameters for the persistent tile scheduler.
block_idx (Tuple[Integer, Integer, Integer]) – The 3d block index in the format (bidx, bidy, bidz).
grid_dim (Tuple[Integer, Integer, Integer]) – The 3d grid dimensions for kernel launch.
- Returns:
A StaticPersistentTileScheduler object.
- Return type:
StaticPersistentTileScheduler
- static get_grid_shape(
- params: PersistentTileSchedulerParams,
- max_active_clusters: cutlass.cutlass_dsl.Int32,
- *,
- loc=None,
- ip=None,
Calculates the grid shape to be launched on GPU using problem shape, threadblock shape, and active cluster size.
- Parameters:
params (PersistentTileSchedulerParams) – Parameters for grid shape calculation.
max_active_clusters (Int32) – Maximum active clusters allowed.
- Returns:
The calculated 3d grid shape.
- Return type:
Tuple[Integer, Integer, Integer]
- _get_current_work_for_linear_idx(
- current_work_linear_idx: cutlass.cutlass_dsl.Int32,
- *,
- loc=None,
- ip=None,
Compute current tile coord given current_work_linear_idx and cta_id_in_cluster.
- Parameters:
current_work_linear_idx (Int32) – The linear index of the current work.
- Returns:
An object containing information about the current tile coordinates and validity status.
- Return type:
WorkTileInfo
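The mapping from a linear work index to a tile coordinate can be sketched in plain Python. This is an illustration only: the function name `tile_for_linear_idx` is hypothetical, and the ordering (m varies fastest, then n, then l) is an assumption about the cluster layout.

```python
def tile_for_linear_idx(linear_idx, num_clusters_mnl, cluster_shape_mn, cta_id_in_cluster):
    """Hypothetical model: return ((tile_m, tile_n, l), is_valid_tile)."""
    n_m, n_n, n_l = num_clusters_mnl
    # A linear index past the last cluster yields an invalid tile.
    is_valid = linear_idx < n_m * n_n * n_l
    # Assumption: m is the fastest-varying mode, then n, then l.
    cluster_m = linear_idx % n_m
    cluster_n = (linear_idx // n_m) % n_n
    cluster_l = linear_idx // (n_m * n_n)
    # Scale the cluster coordinate to CTA tiles and add this CTA's offset.
    tile_idx = (
        cluster_m * cluster_shape_mn[0] + cta_id_in_cluster[0],
        cluster_n * cluster_shape_mn[1] + cta_id_in_cluster[1],
        cluster_l,
    )
    return tile_idx, is_valid


print(tile_for_linear_idx(5, (4, 6, 2), (2, 1), (1, 0)))  # ((3, 1, 0), True)
```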
- get_current_work(
- *,
- loc=None,
- ip=None,
- initial_work_tile_info(
- *,
- loc=None,
- ip=None,
- advance_to_next_work(
- *,
- advance_count: int = 1,
- loc=None,
- ip=None,
- property num_tiles_executed: cutlass.cutlass_dsl.Int32
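The static persistent pattern behind initial_work_tile_info, advance_to_next_work, and is_valid_tile can be modelled on the host in plain Python. The sketch assumes each persistent cluster starts at its own cluster index and strides by num_persistent_clusters times advance_count; the function name `tiles_for_cluster` is hypothetical and no DSL types are used.

```python
def tiles_for_cluster(cluster_id, num_persistent_clusters, total_cluster_tiles,
                      advance_count=1):
    """Hypothetical model: linear work indices one persistent cluster visits."""
    visited = []
    # Initial work: assumed to be the cluster's own index.
    idx = cluster_id
    # The loop condition plays the role of is_valid_tile.
    while idx < total_cluster_tiles:
        visited.append(idx)
        # Plays the role of advance_to_next_work(advance_count=...).
        idx += num_persistent_clusters * advance_count
    return visited


print(tiles_for_cluster(1, 4, 10))  # [1, 5, 9]
```

Under these assumptions the clusters partition the work without overlap: with 4 persistent clusters and 10 tiles, clusters 0 to 3 together visit every index from 0 to 9 exactly once.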
- class cutlass.utils.TensorMapUpdateMode(value)
Bases:
Enum
Enum class defining tensor map update modes.
Modes:
GMEM: Update the tensormap in global memory.
SMEM: Load the tensormap from global memory to shared memory, update it in shared memory, then store it back to global memory.
- GMEM = 1
- SMEM = 2
- class cutlass.utils.TensorMapManager(
- tensormap_update_mode: TensorMapUpdateMode,
- bytes_per_tensormap: int,
Bases:
object
Manages TensorMap operations, including initialization and updates. Provides utilities to convert tensormap pointers across different memory spaces.
- tensormap_update_mode: TensorMapUpdateMode
- bytes_per_tensormap: int
- get_tensormap_ptr(
- ptr: cutlass.cute.typing.Pointer,
- address_space=cutlass._mlir.dialects.cute.AddressSpace.gmem,
- init_tensormap_from_atom(
- copy_atom: CopyAtom,
- dst_ptr: cutlass.cute.typing.Pointer,
- warp_id: int,
- fence_tensormap_initialization() → None
- fence_tensormap_update(
- tensormap_ptr: cutlass.cute.typing.Pointer,
- update_tensormap(
- tensor_gmem: Tuple[cutlass.cute.typing.Tensor, ...],
- tma_copy_atom: Tuple[CopyAtom, ...],
- tensormap_gmem_ptr: Tuple[cutlass.cute.typing.Pointer, ...],
- warp_id: int,
- tensormap_smem_ptr: Tuple[cutlass.cute.typing.Pointer, ...],
- __init__(
- tensormap_update_mode: TensorMapUpdateMode,
- bytes_per_tensormap: int,
- class cutlass.utils.GroupSearchResult(
- group_idx: cutlass.cutlass_dsl.Int32,
- cta_tile_idx_m: cutlass.cutlass_dsl.Int32,
- cta_tile_idx_n: cutlass.cutlass_dsl.Int32,
- problem_shape_m: cutlass.cutlass_dsl.Int32,
- problem_shape_n: cutlass.cutlass_dsl.Int32,
- problem_shape_k: cutlass.cutlass_dsl.Int32,
- cta_tile_count_k: cutlass.cutlass_dsl.Int32,
Bases:
object
The result of the group search for grouped gemm.
- Parameters:
group_idx (Int32) – The result group index
cta_tile_idx_m (Int32) – CTA tile index along M dimension after rasterization
cta_tile_idx_n (Int32) – CTA tile index along N dimension after rasterization
problem_shape_m (Int32) – The M dimension of the gemm problem
problem_shape_n (Int32) – The N dimension of the gemm problem
problem_shape_k (Int32) – The K dimension of the gemm problem
cta_tile_count_k (Int32) – Number of tiles along K dimension
- __init__(
- group_idx: cutlass.cutlass_dsl.Int32,
- cta_tile_idx_m: cutlass.cutlass_dsl.Int32,
- cta_tile_idx_n: cutlass.cutlass_dsl.Int32,
- problem_shape_m: cutlass.cutlass_dsl.Int32,
- problem_shape_n: cutlass.cutlass_dsl.Int32,
- problem_shape_k: cutlass.cutlass_dsl.Int32,
- cta_tile_count_k: cutlass.cutlass_dsl.Int32,
- class cutlass.utils.GroupedGemmGroupSearchState(
- start_group_idx: cutlass.cutlass_dsl.Int32,
- tile_count_prev_group: cutlass.cutlass_dsl.Int32,
- tile_count_searched: cutlass.cutlass_dsl.Int32,
Bases:
object
The state of group index search for grouped gemm.
The state will be initialized once and updated in every round of group index search.
- Parameters:
start_group_idx (Int32) – The group idx to start the search with
tile_count_prev_group (Int32) – Number of tiles before the matched group
tile_count_searched (Int32) – Number of tiles we have searched. When the matched group is found, it records the number of tiles including the matched group
- __init__(
- start_group_idx: cutlass.cutlass_dsl.Int32,
- tile_count_prev_group: cutlass.cutlass_dsl.Int32,
- tile_count_searched: cutlass.cutlass_dsl.Int32,
- cutlass.utils.create_initial_search_state() → GroupedGemmGroupSearchState
Create an initial search state for grouped gemm.
- Returns:
A new search state with initial values
- Return type:
GroupedGemmGroupSearchState
- class cutlass.utils.GroupedGemmTileSchedulerHelper(
- group_count: int,
- tile_sched_params: PersistentTileSchedulerParams,
- cluster_tile_shape_mnk: tuple[int, int, int],
- search_state: GroupedGemmGroupSearchState,
Bases:
object
A helper that translates the raw block index (x, y, z) from the tile scheduler into the real CTA tile index for grouped gemm.
- Parameters:
group_count (int) – Number of groups in current grouped gemm problem
tile_sched_params (PersistentTileSchedulerParams) – Parameter used to create the tile scheduler this helper works with
cluster_tile_shape_mnk (tuple[int, int, int]) – The shape of cluster tile as (m, n, k)
search_state (GroupedGemmGroupSearchState) – The initial search state
- __init__(
- group_count: int,
- tile_sched_params: PersistentTileSchedulerParams,
- cluster_tile_shape_mnk: tuple[int, int, int],
- search_state: GroupedGemmGroupSearchState,
- delinearize_z(
- cta_tile_coord: tuple,
- problem_shape_mnkl: cutlass.cute.typing.Tensor,
Delinearize the linear z index and return GroupSearchResult.
This function should be used by warps that need to know the CTA tile index on M and N dimensions.
- Parameters:
cta_tile_coord (tuple of Int32) – The raw CTA coordinate from tile scheduler
problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for each group
- Returns:
The search result containing group index and tile coordinates
- Return type:
GroupSearchResult
- search_cluster_tile_count_k(
- cta_tile_coord: tuple,
- problem_shape_mnkl: cutlass.cute.typing.Tensor,
Search the matched group for given linear index and compute the number of tiles along K dimension for the matched group.
This function should be used by warps that are only interested in the number of tiles along K dimension.
- Parameters:
cta_tile_coord (tuple of Int32) – The raw CTA coordinate from tile scheduler
problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for all groups
- Returns:
A tuple containing cluster count along K dimension and the group index
- Return type:
Tuple[Int32, Int32]
- _prefix_sum(
- value_per_thread: cutlass.cutlass_dsl.Int32,
Perform prefix sum within a full warp.
- Parameters:
value_per_thread (Int32) – The value for this thread to contribute to the prefix sum
- Returns:
The prefix sum result for this thread
- Return type:
Int32
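A warp-wide prefix sum is commonly built from log2(32) = 5 shuffle-up steps. The plain-Python model below simulates all 32 lanes at once. It assumes an inclusive scan (each lane's result includes its own value) and a shuffle-up style exchange. Both are assumptions about the helper, and the function name `warp_inclusive_prefix_sum` is hypothetical.

```python
WARP_SIZE = 32


def warp_inclusive_prefix_sum(values):
    """Hypothetical lane-by-lane model of a shuffle-up based inclusive scan."""
    if len(values) != WARP_SIZE:
        raise ValueError("expected one value per lane of a full warp")
    sums = list(values)
    offset = 1
    while offset < WARP_SIZE:
        # Each lane reads the running sum of the lane `offset` positions below it.
        # Lanes without such a neighbour add nothing in this step.
        shifted = [sums[lane - offset] if lane >= offset else 0
                   for lane in range(WARP_SIZE)]
        sums = [own + received for own, received in zip(sums, shifted)]
        offset <<= 1
    return sums


print(warp_inclusive_prefix_sum([1] * WARP_SIZE)[:4])  # [1, 2, 3, 4]
```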
- _get_problem_for_group(
- problem_shape_mnkl: cutlass.cute.typing.Tensor,
- group_idx: cutlass.cutlass_dsl.Int32,
Load the gemm problem (m, n, k, l) for the specified group from global memory into registers.
- Parameters:
problem_shape_mnkl (cute.Tensor) – Tensor in global memory with layout (group_count, 4):(4, 1)
group_idx (Int32) – The index of the group to load
- Returns:
The problem shape tensor for the specified group
- Return type:
cute.Tensor
- _get_cluster_tile_count_mn(
- problem_shape: cutlass.cute.typing.Tensor,
Compute the total cluster tile count over the M and N dimensions.
- Parameters:
problem_shape (cute.Tensor) – Tensor containing problem shape (m, n, k, l)
- Returns:
The total cluster tile count for M and N dimensions
- Return type:
Int32
- _compute_cta_tile_coord(
- cluster_tile_idx: cutlass.cutlass_dsl.Int32,
- cta_tile_coord_in_cluster: tuple,
- cluster_tile_count_m: cutlass.cutlass_dsl.Int32,
- cluster_tile_count_n: cutlass.cutlass_dsl.Int32,
Compute CTA tile indices along M and N dimensions based on the linear index within a group.
It uses the AlongM mode to decompose the linear index onto M and N dimensions.
- Parameters:
cluster_tile_idx (Int32) – The linear index within a group
cta_tile_coord_in_cluster (tuple of Int32) – CTA indices along M and N dimensions within a cluster
cluster_tile_count_m (Int32) – The number of clusters along M dimension of the matched group
cluster_tile_count_n (Int32) – The number of clusters along N dimension of the matched group
- Returns:
A tuple containing CTA tile indices along M and N dimensions
- Return type:
tuple of (Int32, Int32)
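The decomposition can be sketched in plain Python. The sketch reads "AlongM" as "the M index varies fastest", which is an assumption. The function name and the extra `cluster_shape_mn` argument are hypothetical: the helper presumably derives the cluster shape from its own state.

```python
def compute_cta_tile_coord(cluster_tile_idx, cta_tile_coord_in_cluster,
                           cluster_tile_count_m, cluster_tile_count_n,
                           cluster_shape_mn):
    """Hypothetical model: CTA tile indices (m, n) from a linear cluster index."""
    if not 0 <= cluster_tile_idx < cluster_tile_count_m * cluster_tile_count_n:
        raise ValueError("cluster_tile_idx is outside the group")
    # Assumption (AlongM): consecutive linear indices walk along M first.
    cluster_idx_m = cluster_tile_idx % cluster_tile_count_m
    cluster_idx_n = cluster_tile_idx // cluster_tile_count_m
    # Convert cluster coordinates to CTA coordinates and add the in-cluster offset.
    cta_tile_idx_m = cluster_idx_m * cluster_shape_mn[0] + cta_tile_coord_in_cluster[0]
    cta_tile_idx_n = cluster_idx_n * cluster_shape_mn[1] + cta_tile_coord_in_cluster[1]
    return cta_tile_idx_m, cta_tile_idx_n


# Cluster 5 of a 4 x 3 cluster grid, 2 x 1 CTAs per cluster, second CTA along M.
print(compute_cta_tile_coord(5, (1, 0), 4, 3, (2, 1)))  # (3, 1)
```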
- _group_search(
- linear_idx: cutlass.cutlass_dsl.Int32,
- problem_shape_mnkl: cutlass.cute.typing.Tensor,
- init_group_idx: cutlass.cutlass_dsl.Int32,
- init_tile_count_searched: cutlass.cutlass_dsl.Int32,
Search which group the linear index belongs to.
- Parameters:
linear_idx (Int32) – The linear index to be decomposed
problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for all groups
init_group_idx (Int32) – The group idx to start the search with
init_tile_count_searched (Int32) – The number of tiles we have searched
- Returns:
The updated search state
- Return type:
GroupedGemmGroupSearchState
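The idea of the group search can be shown with a sequential plain-Python model. It is a sketch only: the `SearchState` tuple and `group_search` function are hypothetical stand-ins, the per-group tile count is assumed to be ceil(M / tile_m) * ceil(N / tile_n), and a simple loop is used here regardless of how the helper performs the search on the GPU.

```python
from typing import NamedTuple


class SearchState(NamedTuple):
    """Hypothetical stand-in for the three documented state fields."""
    start_group_idx: int
    tile_count_prev_group: int
    tile_count_searched: int


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def group_search(linear_idx, problem_shapes_mnkl, cluster_tile_shape_mnk,
                 init_group_idx=0, init_tile_count_searched=0):
    """Find the group whose range of cluster tiles contains linear_idx."""
    tile_m, tile_n, _ = cluster_tile_shape_mnk
    searched = init_tile_count_searched
    for group_idx in range(init_group_idx, len(problem_shapes_mnkl)):
        m, n, _, _ = problem_shapes_mnkl[group_idx]
        # Assumption: cluster tiles per group, with partial tiles rounded up.
        count = ceil_div(m, tile_m) * ceil_div(n, tile_n)
        if linear_idx < searched + count:
            # `searched` counts tiles before the match; add the match itself.
            return SearchState(group_idx, searched, searched + count)
        searched += count
    raise ValueError("linear_idx lies past the last group")


shapes = [(256, 128, 64, 1), (128, 128, 64, 1), (512, 256, 64, 1)]  # 2, 1, 8 tiles
print(group_search(3, shapes, (128, 128, 64)))
# SearchState(start_group_idx=2, tile_count_prev_group=3, tile_count_searched=11)
```

In this model, a later search for a larger linear index can resume from the previous result by passing `start_group_idx` and `tile_count_prev_group` as the initial values, and the index within the matched group is `linear_idx - tile_count_prev_group`.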
- _group_search_and_load_problem_shape(
- linear_idx: cutlass.cutlass_dsl.Int32,
- problem_shape_mnkl: cutlass.cute.typing.Tensor,
- start_group_idx: cutlass.cutlass_dsl.Int32,
- tile_count_searched: cutlass.cutlass_dsl.Int32,
Perform group search and load problem shape for the matched group.
- Parameters:
linear_idx (Int32) – The linear index to be decomposed
problem_shape_mnkl (cute.Tensor) – Tensor containing gemm problem size (M, N, K, L) for all groups
start_group_idx (Int32) – The group idx to start the search with
tile_count_searched (Int32) – The number of tiles we have searched
- Returns:
A tuple containing the final group index and the problem shape tensor
- Return type:
Tuple[Int32, cute.Tensor]
- class cutlass.utils.HardwareInfo(device_id: int = 0)
Bases:
object
device_id: ID of the CUDA device whose hardware information is queried.
- __init__(device_id: int = 0)
- get_max_active_clusters(cluster_size: int) → int
- get_l2_cache_size_in_bytes() → int
- get_device_multiprocessor_count() → int
- _checkCudaErrors(result) → None
- _cudaGetErrorEnum(error) → str
- _cuda_driver_version_ge(major: int, minor: int) → bool
- _cuda_driver_version_lt(major: int, minor: int) → bool
- _empty_kernel()
- _host_function()
- _get_device_function() → None