gemm.h

Functions for matrix multiplication.

Typedefs

typedef void *NVTEMatmulConfig

Configuration for matrix multiplication.

Enums

enum NVTEMatmulConfigAttribute

Type of option for matrix multiplication.

Values:

enumerator kNVTEMatmulConfigBiasTensor

Bias tensor

If provided, the bias tensor is applied in the GEMM epilogue.

enumerator kNVTEMatmulConfigDBiasTensor

Bias gradient tensor

If provided, the bias gradient tensor will be filled in the GEMM epilogue.

enumerator kNVTEMatmulConfigWithGELUEpilogue

Whether to compute GELU in GEMM epilogue.

enumerator kNVTEMatmulConfigWithDGELUEpilogue

Whether to compute GELU backward in GEMM epilogue.

enumerator kNVTEMatmulConfigEpilogueAuxTensor

Auxiliary tensor for GEMM epilogue.

For GELU, this will be filled with the GELU input. For GELU backward, this is expected to already be filled with the GELU input.

enumerator kNVTEMatmulConfigUseSplitAccumulator

Whether to use split accumulator for FP8 GEMM.

enumerator kNVTEMatmulConfigSMCount

Number of streaming multiprocessors to use in GEMM kernel.

enumerator kNVTEMatmulConfigNumAttributes

Number of matrix multiplication configuration options (not a valid option itself).
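The epilogue options above can be modeled on the CPU to make the data flow concrete. This is a sketch under stated assumptions, not the library's kernel. GELU uses the tanh approximation. Tensors are row-major M x N, and the bias has length N and is broadcast over rows. When both GELU backward and the bias gradient are requested, the bias gradient is taken from the GELU-backward output. The function names gelu_epilogue and dgelu_dbias_epilogue are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Tanh approximation of GELU (an assumption; this header does not say
// which GELU form the epilogue uses).
double gelu(double x) {
  const double k = 0.7978845608028654;  // sqrt(2 / pi)
  const double u = k * (x + 0.044715 * x * x * x);
  return 0.5 * x * (1.0 + std::tanh(u));
}

// Derivative of the tanh-approximated GELU above.
double gelu_grad(double x) {
  const double k = 0.7978845608028654;
  const double t = std::tanh(k * (x + 0.044715 * x * x * x));
  const double du = k * (1.0 + 3.0 * 0.044715 * x * x);
  return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * du;
}

// Forward epilogue (hypothetical): out = GELU(pre), and the aux tensor
// receives the GELU input so that the backward pass can reuse it.
void gelu_epilogue(const std::vector<double>& pre, std::vector<double>& out,
                   std::vector<double>& aux) {
  aux = pre;
  out.resize(pre.size());
  for (std::size_t i = 0; i < pre.size(); ++i) out[i] = gelu(pre[i]);
}

// Backward epilogue (hypothetical) on a row-major M x N gradient:
// dpre = dout * GELU'(aux), and dbias[j] = sum over rows i of dpre[i][j].
std::vector<double> dgelu_dbias_epilogue(const std::vector<double>& dout,
                                         const std::vector<double>& aux,
                                         std::size_t M, std::size_t N,
                                         std::vector<double>& dbias) {
  assert(dout.size() == M * N && aux.size() == M * N);
  std::vector<double> dpre(M * N);
  dbias.assign(N, 0.0);
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      dpre[i * N + j] = dout[i * N + j] * gelu_grad(aux[i * N + j]);
      dbias[j] += dpre[i * N + j];  // reduce over the broadcast dimension
    }
  }
  return dpre;
}
```

The aux tensor is what links the two passes: the forward epilogue writes the GELU input, and the backward epilogue only reads it.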

Functions

NVTEMatmulConfig nvte_create_matmul_config()

Create a matrix multiplication configuration.

void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, void *buf, size_t size_in_bytes, size_t *size_written)

Query an option in a matrix multiplication configuration.

Parameters:
  • config[in] Matrix multiplication configuration.

  • attr[in] Option type.

  • buf[out] Memory address to write option value. Ignored if NULL.

  • size_in_bytes[in] Size of buf.

  • size_written[out] Number of bytes that have been written to buf. If buf is NULL, then the number of bytes that would have been written.
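Here is a minimal sketch of the size-query convention above. It uses a hypothetical stand-in for a single int-valued option. FakeMatmulConfig and fake_get_attribute are illustrative names, not Transformer Engine API. Treating a too-small buffer as a caller error is an assumption of the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Hypothetical configuration holding one int-valued option (e.g. an SM count).
struct FakeMatmulConfig {
  int sm_count = 0;
};

// Follows the documented contract: size_written always receives the size
// of the option; nothing is copied when buf is NULL.
void fake_get_attribute(const FakeMatmulConfig& config, void* buf,
                        std::size_t size_in_bytes, std::size_t* size_written) {
  const std::size_t needed = sizeof(config.sm_count);
  if (size_written != nullptr) *size_written = needed;
  if (buf == nullptr) return;  // size query only
  assert(size_in_bytes >= needed);  // assumption: too small is a caller error
  std::memcpy(buf, &config.sm_count, needed);
}

// Two-step pattern: ask for the size first, then read the value.
int read_sm_count(const FakeMatmulConfig& config) {
  std::size_t size = 0;
  fake_get_attribute(config, nullptr, 0, &size);
  assert(size == sizeof(int));
  int value = 0;
  fake_get_attribute(config, &value, sizeof(value), &size);
  return value;
}
```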

void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, const void *buf, size_t size_in_bytes)

Set an option in a matrix multiplication configuration.

Parameters:
  • config[in] Matrix multiplication configuration.

  • attr[in] Option type.

  • buf[in] Memory address to read option value from.

  • size_in_bytes[in] Size of buf.

void nvte_destroy_matmul_config(NVTEMatmulConfig config)

Destroy a matrix multiplication configuration.

void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream)

Compute the product of two matrices, potentially fused with other operations (deprecated).

This has been deprecated in favor of nvte_cublas_gemm_v2.

Computes:

  • D = AB if both bias and pre_gelu_out are empty tensors

  • D = AB + bias if pre_gelu_out is empty and bias is not empty

  • D = GELU(AB + bias) if both bias and pre_gelu_out are not empty tensors
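The three cases can be written as a CPU reference. The following assumptions apply to this sketch only:
- A (M x K), B (K x N) and D (M x N) are row-major with no transposes.
- The bias has length N and is broadcast over rows.
- An empty std::vector stands in for an empty tensor.
- GELU uses the tanh approximation.

accumulate, grad and workspace are not modeled, and reference_gemm_bias_gelu is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for D = AB, AB + bias, or GELU(AB + bias).
void reference_gemm_bias_gelu(const std::vector<float>& A,
                              const std::vector<float>& B,
                              std::vector<float>& D,
                              const std::vector<float>& bias,
                              std::vector<float>& pre_gelu_out,
                              std::size_t M, std::size_t N, std::size_t K) {
  assert(A.size() == M * K && B.size() == K * N);
  const bool use_bias = !bias.empty();
  const bool use_gelu = !pre_gelu_out.empty();
  // Only the combinations listed above are supported in this sketch.
  assert(!use_gelu || (use_bias && pre_gelu_out.size() == M * N));
  D.assign(M * N, 0.0f);
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
      if (use_bias) acc += bias[j];
      if (use_gelu) {
        pre_gelu_out[i * N + j] = acc;  // value before the activation
        const float u = 0.7978845608f * (acc + 0.044715f * acc * acc * acc);
        acc = 0.5f * acc * (1.0f + std::tanh(u));  // tanh-approximated GELU
      }
      D[i * N + j] = acc;
    }
  }
}
```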

Parameters:
  • A[in] The A matrix.

  • B[in] The B matrix.

  • D[inout] Output matrix.

  • bias[in] Bias tensor.

  • pre_gelu_out[inout] Output matrix before GELU activation.

  • transa[in] Whether A matrix is transposed.

  • transb[in] Whether B matrix is transposed.

  • grad[in] Whether this operation is part of the gradient computation.

  • workspace[out] Workspace tensor.

  • accumulate[in] Whether to accumulate the result into the D matrix.

  • use_split_accumulator[in] Whether to use split accumulator in the FP8 GEMM.

  • math_sm_count[in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)

  • stream[in] CUDA stream used for the operation.

void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream)

Compute the product of two matrices, potentially fused with other operations.

Computes:

  • D = alpha * op(A) * op(B) + beta * C, where op(X) is X or its transpose, as selected by transa and transb

Parameters:
  • transa[in] Whether to transpose A matrix.

  • transb[in] Whether to transpose B matrix.

  • alpha[in] Scaling factor applied to matmul output.

  • A[in] A matrix.

  • B[in] B matrix.

  • beta[in] Scaling factor applied to C matrix.

  • C[in] C matrix.

  • D[out] Output matrix.

  • workspace[in] Workspace tensor.

  • config[in] Additional configuration.

  • stream[in] CUDA stream used for the operation.
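This CPU reference for the formula above shows how transa and transb select op(A) and op(B). The storage layout is an assumption made for the sketch: row-major, with op(A) of shape M x K and op(B) of shape K x N. cuBLAS itself works in column-major order, so this is not a statement about the library's layout convention. workspace and config are not modeled, and reference_gemm_v2 is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for D = alpha * op(A) * op(B) + beta * C.
std::vector<float> reference_gemm_v2(int transa, int transb, float alpha,
                                     const std::vector<float>& A,
                                     const std::vector<float>& B, float beta,
                                     const std::vector<float>& C,
                                     std::size_t M, std::size_t N, std::size_t K) {
  assert(A.size() == M * K && B.size() == K * N && C.size() == M * N);
  std::vector<float> D(M * N);
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) {
        // A is stored as M x K, or as K x M when transa is nonzero.
        const float a = transa ? A[k * M + i] : A[i * K + k];
        // B is stored as K x N, or as N x K when transb is nonzero.
        const float b = transb ? B[j * K + k] : B[k * N + j];
        acc += a * b;
      }
      D[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
  }
  return D;
}
```

Setting beta to 0 discards C, and setting it to 1 accumulates into it.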

void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, float alpha, float beta, bool use_split_accumulator, int math_sm_count, cudaStream_t stream)

Compute the product of two matrices, potentially fused with other operations, with scaling factors for the GEMM result and the accumulation input (deprecated).

This has been deprecated in favor of nvte_cublas_gemm_v2.

Computes:

  • D = alpha*AB if both bias and pre_gelu_out are empty tensors

  • D = alpha*AB + bias if pre_gelu_out is empty and bias is not empty

  • D = GELU(alpha*AB + bias) if both bias and pre_gelu_out are not empty tensors

Parameters:
  • A[in] The A matrix.

  • B[in] The B matrix.

  • D[inout] Output matrix.

  • bias[in] Bias tensor.

  • pre_gelu_out[inout] Output matrix before GELU activation.

  • transa[in] Whether A matrix is transposed.

  • transb[in] Whether B matrix is transposed.

  • grad[in] Whether this operation is part of the gradient computation.

  • workspace[out] Workspace tensor.

  • alpha[in] Scaling factor applied to the result of the GEMM.

  • beta[in] Scaling factor applied to original value of D when accumulating into it. beta=0 means no accumulation.

  • use_split_accumulator[in] Whether to use split accumulator in the FP8 GEMM.

  • math_sm_count[in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)

  • stream[in] CUDA stream used for the operation.

void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const NVTETensor counter, cudaStream_t stream)

Compute the product of two matrices with chunking and atomic counters.

Computes:

  • D = AB if both bias and pre_gelu_out are empty tensors

  • D = AB + bias if pre_gelu_out is empty and bias is not empty

  • D = GELU(AB + bias) if both bias and pre_gelu_out are not empty tensors

Warning

cuBLAS atomic GEMM uses a beta API and has not been tested for all use cases.

Parameters:
  • A[in] The A matrix.

  • B[in] The B matrix.

  • D[inout] Output matrix.

  • bias[in] Bias tensor.

  • pre_gelu_out[inout] Output matrix before GELU activation.

  • transa[in] Whether A matrix is transposed.

  • transb[in] Whether B matrix is transposed.

  • grad[in] Whether this operation is part of the gradient computation.

  • workspace[out] Workspace tensor.

  • accumulate[in] Whether to accumulate the result into the D matrix.

  • use_split_accumulator[in] Whether to use split accumulator in the FP8 GEMM.

  • math_sm_count[in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)

  • m_split[in] Number of chunks/splits along m-dimension for Atomic GEMM.

  • n_split[in] Number of chunks/splits along n-dimension for Atomic GEMM.

  • gemm_producer[in] Whether the atomic GEMM acts as the producer (true) or the consumer (false).

  • counter[inout] Per-chunk counters; counter[chunk_i] = 0 indicates that chunk_i has been produced.

  • stream[in] CUDA stream used for the operation.
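The counter convention can be illustrated with a CPU producer/consumer pair, where threads stand in for the producer GEMM and a consumer. Each counter starts at a nonzero value, which is an assumption: the header only defines what 0 means. produce_and_consume is a hypothetical name, and this is not how the GPU implementation works internally.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical illustration of chunk signaling: counter[i] == 0 means
// chunk i has been produced and is safe to consume.
std::vector<int> produce_and_consume(std::size_t num_chunks) {
  std::vector<std::atomic<int>> counter(num_chunks);
  for (auto& c : counter) c.store(1);  // nonzero = not yet produced
  std::vector<int> data(num_chunks, 0);
  std::vector<int> consumed(num_chunks, 0);

  std::thread producer([&] {
    for (std::size_t i = 0; i < num_chunks; ++i) {
      data[i] = static_cast<int>(i) * 10;              // "compute" chunk i
      counter[i].store(0, std::memory_order_release);  // publish chunk i
    }
  });
  std::thread consumer([&] {
    for (std::size_t i = 0; i < num_chunks; ++i) {
      while (counter[i].load(std::memory_order_acquire) != 0) {
        std::this_thread::yield();  // wait until chunk i is ready
      }
      consumed[i] = data[i] + 1;  // safe: the acquire pairs with the release
    }
  });
  producer.join();
  consumer.join();
  return consumed;
}
```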

void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream)

Compute multiple matrix multiplications, each potentially fused with other operations, on multiple streams.

Computes:

  • D = AB if both bias and pre_gelu_out are empty tensors

  • D = AB + bias if pre_gelu_out is empty and bias is not empty

  • D = GELU(AB + bias) if both bias and pre_gelu_out are not empty tensors

Parameters:
  • A[in] The list of A matrices.

  • B[in] The list of B matrices.

  • D[inout] List of output matrices.

  • bias[in] List of bias tensors.

  • pre_gelu_out[inout] List of output matrices before GELU activation.

  • num_gemms[in] Number of GEMMs to compute.

  • transa[in] Whether the A matrices are transposed.

  • transb[in] Whether the B matrices are transposed.

  • grad[in] Whether this operation is part of the gradient computation.

  • workspace[out] List of workspace tensors.

  • accumulate[in] Whether to accumulate the results into the D matrices.

  • use_split_accumulator[in] Whether to use split accumulator in the FP8 GEMM.

  • math_sm_count[in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)

  • stream[in] CUDA stream to wait on.
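The header does not say how GEMMs are assigned to streams. As a rough sketch only, the version below makes three substitutions. It assumes a round-robin assignment over a hypothetical num_streams. CPU threads stand in for CUDA streams. join() stands in for the caller's stream waiting on the others. Gemm, run_one and run_multi are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical GEMM problem: row-major A (M x K) and B (K x N); D gets A * B.
struct Gemm {
  std::vector<float> A, B;
  std::size_t M, N, K;
  std::vector<float> D;
};

void run_one(Gemm& g) {
  g.D.assign(g.M * g.N, 0.0f);
  for (std::size_t i = 0; i < g.M; ++i)
    for (std::size_t k = 0; k < g.K; ++k)
      for (std::size_t j = 0; j < g.N; ++j)
        g.D[i * g.N + j] += g.A[i * g.K + k] * g.B[k * g.N + j];
}

// Assumed round-robin policy: "stream" s handles GEMMs s, s + num_streams, ...
void run_multi(std::vector<Gemm>& gemms, std::size_t num_streams) {
  assert(num_streams > 0);
  std::vector<std::thread> streams;
  for (std::size_t s = 0; s < num_streams; ++s) {
    streams.emplace_back([&gemms, s, num_streams] {
      for (std::size_t i = s; i < gemms.size(); i += num_streams) run_one(gemms[i]);
    });
  }
  for (auto& t : streams) t.join();  // the caller waits for every stream
}
```

Because the GEMMs are independent, each worker writes only to its own D and no locking is needed.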

namespace transformer_engine

Namespace containing C++ API of Transformer Engine.

Functions

void nvte_cublas_handle_init()

With TE/JAX and CUDA graphs, cuBLAS must be initialized outside the capturing region. This helper calls cublasCreate(), which allocates memory for the handle. It is called during the initialize phase of the related XLA custom calls.
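The idea of creating the handle once, before any capture begins, can be sketched with std::call_once. FakeHandle, create_handle and get_handle are hypothetical. In real code, create_handle is where cublasCreate() would be called.

```cpp
#include <cassert>
#include <mutex>

struct FakeHandle {
  int id;
};

int g_creations = 0;  // counts how many times the "allocation" happened

// Hypothetical stand-in for the allocating call (cublasCreate() in real code).
FakeHandle* create_handle() {
  ++g_creations;
  static FakeHandle handle{42};
  return &handle;
}

// Creates the handle on the first call only; later calls just return it.
FakeHandle* get_handle() {
  static std::once_flag flag;
  static FakeHandle* handle = nullptr;
  std::call_once(flag, [] { handle = create_handle(); });
  return handle;
}
```

Calling get_handle() during setup means that later calls, for example inside a captured region, only return the existing handle and do not allocate.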

struct MatmulConfigWrapper
#include <gemm.h>

C++ wrapper for NVTEMatmulConfig.

Public Functions

inline MatmulConfigWrapper()
MatmulConfigWrapper(const MatmulConfigWrapper&) = delete
MatmulConfigWrapper &operator=(const MatmulConfigWrapper&) = delete
inline MatmulConfigWrapper(MatmulConfigWrapper &&other)
inline MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other)
inline ~MatmulConfigWrapper()
inline operator NVTEMatmulConfig() const noexcept

Get the underlying NVTEMatmulConfig.

Returns:

NVTEMatmulConfig held by this MatmulConfigWrapper.

inline void set_bias_tensor(NVTETensor bias_tensor)

Set bias tensor.

inline void set_dbias_tensor(NVTETensor dbias_tensor)

Set bias gradient tensor.

inline void set_with_gelu_epilogue(bool with_gelu_epilogue)

Set whether to compute GELU in GEMM epilogue.

inline void set_with_dgelu_epilogue(bool with_dgelu_epilogue)

Set whether to compute GELU backward in GEMM epilogue.

inline void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor)

Set auxiliary tensor for GEMM epilogue.

inline void set_use_split_accumulator(bool use_split_accumulator)

Set whether to use split accumulator for FP8 GEMM.

inline void set_sm_count(int sm_count)

Set number of streaming multiprocessors to use in GEMM kernel.

Private Members

NVTEMatmulConfig config_ = nullptr

Wrapped NVTEMatmulConfig.
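This self-contained sketch shows the ownership pattern implied by the member list above:
- Copying is deleted.
- Moving transfers the handle and leaves the source empty.
- The destructor releases the handle.

The C-style functions and types here (fake_create, fake_destroy, fake_set_sm_count, FakeConfigHandle) are hypothetical stand-ins for the nvte_*_matmul_config functions. Having the typed setter forward to an untyped call is an assumption about how such setters are typically written.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <utility>

// Hypothetical opaque handle and C-style API.
typedef void* FakeConfigHandle;
struct FakeConfigImpl {
  int sm_count = 0;
};
int g_live_configs = 0;

FakeConfigHandle fake_create() {
  ++g_live_configs;
  return new FakeConfigImpl();
}
void fake_destroy(FakeConfigHandle c) {
  if (c != nullptr) {
    --g_live_configs;
    delete static_cast<FakeConfigImpl*>(c);
  }
}
void fake_set_sm_count(FakeConfigHandle c, const void* buf, std::size_t size) {
  assert(size == sizeof(int));
  std::memcpy(&static_cast<FakeConfigImpl*>(c)->sm_count, buf, size);
}

// Move-only RAII wrapper in the style of MatmulConfigWrapper.
class ConfigWrapper {
 public:
  ConfigWrapper() : config_(fake_create()) {}
  ConfigWrapper(const ConfigWrapper&) = delete;
  ConfigWrapper& operator=(const ConfigWrapper&) = delete;
  ConfigWrapper(ConfigWrapper&& other) noexcept : config_(other.config_) {
    other.config_ = nullptr;  // the moved-from wrapper owns nothing
  }
  ConfigWrapper& operator=(ConfigWrapper&& other) noexcept {
    if (this != &other) {
      fake_destroy(config_);
      config_ = other.config_;
      other.config_ = nullptr;
    }
    return *this;
  }
  ~ConfigWrapper() { fake_destroy(config_); }
  operator FakeConfigHandle() const noexcept { return config_; }
  // Typed setter forwarding to the untyped C-style call.
  void set_sm_count(int sm_count) {
    fake_set_sm_count(config_, &sm_count, sizeof(sm_count));
  }

 private:
  FakeConfigHandle config_ = nullptr;
};
```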