gemm.h
Functions for matrix multiplication.
Typedefs
-
typedef void *NVTEMatmulConfig
Configuration for matrix multiplication.
Enums
-
enum NVTEMatmulConfigAttribute
Type of option for matrix multiplication.
Values:
-
enumerator kNVTEMatmulConfigBiasTensor
Bias tensor
If provided, the bias tensor is applied in the GEMM epilogue.
-
enumerator kNVTEMatmulConfigDBiasTensor
Bias gradient tensor
If provided, the bias gradient tensor will be filled in the GEMM epilogue.
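For intuition, the bias gradient is the output gradient summed over the rows that the bias was broadcast across. The sketch below is a CPU reference, not the cuBLAS epilogue. It assumes a row-major m×n output gradient with a length-n bias broadcast along rows, and `reference_dbias` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference: dbias[j] = sum_i dD[i][j] for a row-major
// m x n output gradient dD, assuming the bias (length n) was broadcast
// across all m rows in the forward pass.
std::vector<float> reference_dbias(const std::vector<float>& dD,
                                   std::size_t m, std::size_t n) {
  assert(dD.size() == m * n);
  std::vector<float> dbias(n, 0.0f);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      dbias[j] += dD[i * n + j];  // reduce over the broadcast (row) dimension
    }
  }
  return dbias;
}
```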
-
enumerator kNVTEMatmulConfigWithGELUEpilogue
Whether to compute GELU in GEMM epilogue.
-
enumerator kNVTEMatmulConfigWithDGELUEpilogue
Whether to compute GELU backward in GEMM epilogue.
-
enumerator kNVTEMatmulConfigEpilogueAuxTensor
Auxiliary tensor for GEMM epilogue.
For GELU, this will be filled with the GELU input. For GELU backward, this is expected to already be filled with the GELU input.
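The auxiliary tensor links the two epilogues: the forward pass stores the GELU input, and the backward pass reuses it to compute dX = dY * GELU'(x). The CPU sketch below shows that data flow. It assumes the tanh approximation of GELU (the exact variant used by the cuBLAS epilogue is not stated in this reference), and all function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Tanh approximation of GELU (an assumption; see lead-in).
float gelu_tanh(float x) {
  const float k = 0.7978845608f;  // sqrt(2 / pi)
  float u = k * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + std::tanh(u));
}

// Analytic derivative of gelu_tanh with respect to x.
float dgelu_tanh(float x) {
  const float k = 0.7978845608f;
  float u = k * (x + 0.044715f * x * x * x);
  float t = std::tanh(u);
  float du = k * (1.0f + 3.0f * 0.044715f * x * x);
  return 0.5f * (1.0f + t) + 0.5f * x * (1.0f - t * t) * du;
}

// Forward epilogue: aux receives the GELU input, out receives GELU(input).
void gelu_epilogue(const std::vector<float>& pre, std::vector<float>& aux,
                   std::vector<float>& out) {
  aux = pre;
  out.resize(pre.size());
  for (std::size_t i = 0; i < pre.size(); ++i) out[i] = gelu_tanh(pre[i]);
}

// Backward epilogue: aux must already hold the GELU input from the forward pass.
std::vector<float> dgelu_epilogue(const std::vector<float>& grad_out,
                                  const std::vector<float>& aux) {
  assert(grad_out.size() == aux.size());
  std::vector<float> grad_in(aux.size());
  for (std::size_t i = 0; i < aux.size(); ++i)
    grad_in[i] = grad_out[i] * dgelu_tanh(aux[i]);
  return grad_in;
}
```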
-
enumerator kNVTEMatmulConfigUseSplitAccumulator
Whether to use split accumulator for FP8 GEMM.
-
enumerator kNVTEMatmulConfigSMCount
Number of streaming multiprocessors to use in GEMM kernel.
-
enumerator kNVTEMatmulConfigNumAttributes
Functions
-
NVTEMatmulConfig nvte_create_matmul_config()
Create a matrix multiplication configuration.
-
void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, void *buf, size_t size_in_bytes, size_t *size_written)
Query an option in matrix multiplication configuration.
- Parameters:
config – [in] Matrix multiplication configuration.
attr – [in] Option type.
buf – [out] Memory address to write option value. Ignored if NULL.
size_in_bytes – [in] Size of buf.
size_written – [out] Number of bytes that have been written to buf. If buf is NULL, then the number of bytes that would have been written.
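The `size_written` contract supports the usual two-call pattern: call once with `buf` set to NULL to learn the required size, then allocate a buffer and query again. The sketch below exercises that pattern against `mock_get_attribute`, a hypothetical stand-in that follows the documented contract for one int-valued option; it does not call Transformer Engine. The behavior when `size_in_bytes` is too small is an assumption here (the mock writes nothing).

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-in that mimics the documented contract of
// nvte_get_matmul_config_attribute for a single int-valued option.
void mock_get_attribute(int stored_value, void* buf, std::size_t size_in_bytes,
                        std::size_t* size_written) {
  const std::size_t needed = sizeof(int);
  if (buf == nullptr) {            // size query only
    *size_written = needed;
    return;
  }
  if (size_in_bytes < needed) {    // assumption: too-small buffer, nothing written
    *size_written = 0;
    return;
  }
  std::memcpy(buf, &stored_value, needed);
  *size_written = needed;
}

// Two-call pattern: query the size, allocate, then fetch the value.
int query_int_option(int stored_value) {
  std::size_t size = 0;
  mock_get_attribute(stored_value, nullptr, 0, &size);
  std::vector<unsigned char> storage(size);
  std::size_t written = 0;
  mock_get_attribute(stored_value, storage.data(), storage.size(), &written);
  assert(written == sizeof(int));
  int value = 0;
  std::memcpy(&value, storage.data(), sizeof(int));
  return value;
}
```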
-
void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, const void *buf, size_t size_in_bytes)
Set an option in matrix multiplication configuration.
- Parameters:
config – [in] Matrix multiplication configuration.
attr – [in] Option type.
buf – [in] Memory address to read option value.
size_in_bytes – [in] Size of buf.
-
void nvte_destroy_matmul_config(NVTEMatmulConfig config)
Destroy a matrix multiplication configuration.
-
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream)
Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
This has been deprecated in favor of nvte_cublas_gemm_v2.
Computes:
D = AB, if both bias and pre_gelu_out are empty tensors
D = AB + bias, if pre_gelu_out is empty and bias is not empty
D = GELU(AB + bias), if both bias and pre_gelu_out are not empty tensors
- Parameters:
A – [in] The A matrix.
B – [in] The B matrix.
D – [inout] Output matrix.
bias – [in] Bias tensor.
pre_gelu_out – [inout] Output matrix before GELU activation.
transa – [in] Whether A matrix is transposed.
transb – [in] Whether B matrix is transposed.
grad – [in] Whether this operation is part of the gradient computation.
workspace – [out] Workspace tensor.
accumulate – [in] Whether to accumulate the result into the D matrix.
use_split_accumulator – [in] Whether to use split accumulator in the FP8 GEMM.
math_sm_count – [in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)
stream – [in] CUDA stream used for the operation.
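A CPU reference of the three cases can make the roles of `bias` and `pre_gelu_out` concrete. It is only a sketch: it assumes row-major storage, no transposes, a bias broadcast across rows, and the tanh approximation of GELU. It models an empty tensor as a null pointer, and `reference_gemm_epilogue` is a hypothetical name, not part of the library.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the nvte_cublas_gemm cases.
// A is m x k, B is k x n, D is m x n (row-major); bias has length n.
// A null pointer stands in for an "empty" tensor.
void reference_gemm_epilogue(const std::vector<float>& A, const std::vector<float>& B,
                             std::vector<float>& D, const std::vector<float>* bias,
                             std::vector<float>* pre_gelu_out,
                             std::size_t m, std::size_t n, std::size_t k) {
  assert(A.size() == m * k && B.size() == k * n);
  assert(!bias || bias->size() == n);
  // The documented cases do not cover GELU without a bias, so reject it here.
  assert(!(pre_gelu_out && !bias));
  D.assign(m * n, 0.0f);
  if (pre_gelu_out) pre_gelu_out->assign(m * n, 0.0f);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      if (bias) acc += (*bias)[j];                  // D = AB + bias
      if (pre_gelu_out) {                           // D = GELU(AB + bias)
        (*pre_gelu_out)[i * n + j] = acc;           // keep the GELU input
        const float u = 0.7978845608f * (acc + 0.044715f * acc * acc * acc);
        acc = 0.5f * acc * (1.0f + std::tanh(u));
      }
      D[i * n + j] = acc;
    }
  }
}
```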
-
void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream)
Compute matrix multiplication of 2 matrices, potentially fused with other operations.
Computes:
D = alpha * op(A) * op(B) + beta * C
- Parameters:
transa – [in] Whether to transpose A matrix.
transb – [in] Whether to transpose B matrix.
alpha – [in] Scaling factor applied to matmul output.
A – [in] A matrix.
B – [in] B matrix.
beta – [in] Scaling factor applied to C matrix.
C – [in] C matrix.
D – [out] Output matrix.
workspace – [in] Workspace tensor.
config – [in] Additional configuration.
stream – [in] CUDA stream used for the operation.
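The v2 formula can be checked against a small CPU reference. This sketch assumes row-major storage and reads `transa`/`transb` as "use the transpose of the stored matrix"; the actual memory layout convention used by the library is not modeled here, and `reference_gemm_v2` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Element (r, c) of op(X), where X is stored row-major with `cols` columns.
static float op_at(const std::vector<float>& X, std::size_t cols, bool trans,
                   std::size_t r, std::size_t c) {
  return trans ? X[c * cols + r] : X[r * cols + c];
}

// Hypothetical reference: D = alpha * op(A) * op(B) + beta * C.
// op(A) is m x k, op(B) is k x n, C and D are m x n (row-major).
std::vector<float> reference_gemm_v2(bool transa, bool transb, float alpha,
                                     const std::vector<float>& A,
                                     const std::vector<float>& B, float beta,
                                     const std::vector<float>& C,
                                     std::size_t m, std::size_t n, std::size_t k) {
  assert(A.size() == m * k && B.size() == k * n && C.size() == m * n);
  const std::size_t a_cols = transa ? m : k;  // stored A is k x m when transposed
  const std::size_t b_cols = transb ? k : n;  // stored B is n x k when transposed
  std::vector<float> D(m * n);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p)
        acc += op_at(A, a_cols, transa, i, p) * op_at(B, b_cols, transb, p, j);
      D[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
  return D;
}
```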
-
void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, float alpha, float beta, bool use_split_accumulator, int math_sm_count, cudaStream_t stream)
Compute matrix multiplication of 2 matrices, potentially fused with other operations, with scaling factors applied to the GEMM result and to the accumulation input (deprecated).
This has been deprecated in favor of nvte_cublas_gemm_v2.
Computes:
D = alpha*AB, if both bias and pre_gelu_out are empty tensors
D = alpha*AB + bias, if pre_gelu_out is empty and bias is not empty
D = GELU(alpha*AB + bias), if both bias and pre_gelu_out are not empty tensors
- Parameters:
A – [in] The A matrix.
B – [in] The B matrix.
D – [inout] Output matrix.
bias – [in] Bias tensor.
pre_gelu_out – [inout] Output matrix before GELU activation.
transa – [in] Whether A matrix is transposed.
transb – [in] Whether B matrix is transposed.
grad – [in] Whether this operation is part of the gradient computation.
workspace – [out] Workspace tensor.
alpha – [in] Scaling factor applied to the result of the GEMM
beta – [in] Scaling factor applied to original value of D when accumulating into it. beta=0 means no accumulation.
use_split_accumulator – [in] Whether to use split accumulator in the FP8 GEMM.
math_sm_count – [in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)
stream – [in] CUDA stream used for the operation.
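The `alpha`/`beta` parameters amount to D ← alpha·AB + beta·D. One detail worth illustrating is `beta = 0`: under the usual cuBLAS convention the old contents of D are then not read, so uninitialized or NaN values in D cannot leak into the result. The sketch below is a hypothetical CPU helper applied to an already computed product AB; it is not the library's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: D = alpha * AB + beta * D, given a precomputed product AB.
// With beta == 0 the old contents of D are never read (cuBLAS-style convention).
void scale_and_accumulate(const std::vector<float>& AB, std::vector<float>& D,
                          float alpha, float beta) {
  assert(AB.size() == D.size());
  for (std::size_t i = 0; i < D.size(); ++i) {
    D[i] = (beta == 0.0f) ? alpha * AB[i] : alpha * AB[i] + beta * D[i];
  }
}
```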
-
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const NVTETensor counter, cudaStream_t stream)
Compute matrix multiplication of 2 matrices with chunking and atomic counters.
Computes:
D = AB, if both bias and pre_gelu_out are empty tensors
D = AB + bias, if pre_gelu_out is empty and bias is not empty
D = GELU(AB + bias), if both bias and pre_gelu_out are not empty tensors
Warning
cuBLAS atomic GEMM uses a beta API and is not tested for all use cases.
- Parameters:
A – [in] The A matrix.
B – [in] The B matrix.
D – [inout] Output matrix.
bias – [in] Bias tensor.
pre_gelu_out – [inout] Output matrix before GELU activation.
transa – [in] Whether A matrix is transposed.
transb – [in] Whether B matrix is transposed.
grad – [in] Whether this operation is part of the gradient computation.
workspace – [out] Workspace tensor.
accumulate – [in] Whether to accumulate the result into the D matrix.
use_split_accumulator – [in] Whether to use split accumulator in the FP8 GEMM.
math_sm_count – [in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)
m_split – [in] Number of chunks/splits along m-dimension for Atomic GEMM.
n_split – [in] Number of chunks/splits along n-dimension for Atomic GEMM.
gemm_producer – [in] Whether Atomic GEMM is the producer or consumer.
counter – [inout] counter[chunk_i]=0 indicates chunk_i has been produced.
stream – [in] CUDA stream used for the operation.
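The `counter` contract (counter[chunk_i] = 0 means chunk_i has been produced) is a producer/consumer handshake. The sketch below models it on the CPU with `std::atomic` and two threads: a producer fills `m_split` row chunks of D and clears each chunk's counter, while a consumer waits on each counter before reading that chunk. The initial counter value of 1 and all names are assumptions for illustration; the real kernels synchronize on the GPU.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical model: D has m_split chunks of chunk_len elements each.
// counter[c] starts at 1 ("not produced") and is set to 0 once chunk c is written.
float produce_and_consume(std::size_t m_split, std::size_t chunk_len) {
  assert(m_split > 0);
  std::vector<float> D(m_split * chunk_len, 0.0f);
  std::vector<std::atomic<int>> counter(m_split);
  for (auto& c : counter) c.store(1);

  std::thread producer([&] {
    for (std::size_t c = 0; c < m_split; ++c) {
      for (std::size_t e = 0; e < chunk_len; ++e)
        D[c * chunk_len + e] = static_cast<float>(c + 1);  // stand-in for GEMM output
      counter[c].store(0, std::memory_order_release);       // publish chunk c
    }
  });

  float total = 0.0f;
  for (std::size_t c = 0; c < m_split; ++c) {               // consumer side
    while (counter[c].load(std::memory_order_acquire) != 0) std::this_thread::yield();
    for (std::size_t e = 0; e < chunk_len; ++e) total += D[c * chunk_len + e];
  }
  producer.join();
  return total;
}
```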
-
void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream)
Compute multiple matrix multiplications, potentially fused with other operations, on multiple streams.
Computes:
D = AB, if both bias and pre_gelu_out are empty tensors
D = AB + bias, if pre_gelu_out is empty and bias is not empty
D = GELU(AB + bias), if both bias and pre_gelu_out are not empty tensors
- Parameters:
A – [in] The list of A matrices.
B – [in] The list of B matrices.
D – [inout] List of output matrices.
bias – [in] List of bias tensors.
pre_gelu_out – [inout] List of output matrices before GELU activation.
num_gemms – [in] Number of GEMMs to compute.
transa – [in] Whether A matrix is transposed.
transb – [in] Whether B matrix is transposed.
grad – [in] Whether this operation is part of the gradient computation.
workspace – [out] List of workspace tensors.
accumulate – [in] Whether to accumulate the result into the D matrix.
use_split_accumulator – [in] Whether to use split accumulator in the FP8 GEMM.
math_sm_count – [in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)
stream – [in] CUDA stream to wait on.
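Semantically, each index i < num_gemms gets its own independent GEMM. The CPU sketch below mirrors that with one worker thread per GEMM, standing in for the separate CUDA streams, and joins all workers before returning. Bias and GELU are omitted, and `Matrix`, `matmul`, and `multi_gemm` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

struct Matrix {             // hypothetical row-major matrix
  std::size_t rows, cols;
  std::vector<float> data;
};

// Plain D = AB for one pair.
static Matrix matmul(const Matrix& A, const Matrix& B) {
  assert(A.cols == B.rows);
  Matrix D{A.rows, B.cols, std::vector<float>(A.rows * B.cols, 0.0f)};
  for (std::size_t i = 0; i < A.rows; ++i)
    for (std::size_t j = 0; j < B.cols; ++j)
      for (std::size_t p = 0; p < A.cols; ++p)
        D.data[i * D.cols + j] += A.data[i * A.cols + p] * B.data[p * B.cols + j];
  return D;
}

// Hypothetical model of the grouped call: D[i] = A[i] * B[i] for every i,
// each computed on its own worker (standing in for a separate stream).
std::vector<Matrix> multi_gemm(const std::vector<Matrix>& A, const std::vector<Matrix>& B) {
  assert(A.size() == B.size());
  std::vector<Matrix> D(A.size());
  std::vector<std::thread> workers;
  for (std::size_t i = 0; i < A.size(); ++i)
    workers.emplace_back([&, i] { D[i] = matmul(A[i], B[i]); });
  for (auto& w : workers) w.join();   // wait for every GEMM before returning
  return D;
}
```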
-
namespace transformer_engine
Namespace containing C++ API of Transformer Engine.
Functions
-
void nvte_cublas_handle_init()
TE/JAX CUDA Graph capture requires cuBLAS initialization to happen outside the capturing region. This helper calls cublasCreate(), which allocates memory for the handle, and is called during the initialization phase of the related XLA custom calls.
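The underlying pattern is to create the handle eagerly, exactly once, before any capture begins. A minimal sketch of that pattern using `std::call_once` is below. `FakeHandle` and `create_handle_count` are hypothetical stand-ins for the cuBLAS handle and for calls to `cublasCreate()`, and this is not necessarily how `nvte_cublas_handle_init` is implemented.

```cpp
#include <cassert>
#include <mutex>

struct FakeHandle { int id; };        // hypothetical stand-in for cublasHandle_t

static int create_handle_count = 0;   // counts simulated cublasCreate() calls

// Returns the process-wide handle, creating it on the first call only.
FakeHandle& get_handle() {
  static std::once_flag flag;
  static FakeHandle handle{0};
  std::call_once(flag, [] {
    ++create_handle_count;            // the allocation would happen here
    handle.id = 42;
  });
  return handle;
}

// Call during an initialization phase, before any graph capture begins,
// so the capture region never triggers the allocation.
void handle_init() { (void)get_handle(); }
```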
-
struct MatmulConfigWrapper
- #include <gemm.h>
C++ wrapper for NVTEMatmulConfig.
Public Functions
-
inline MatmulConfigWrapper()
-
MatmulConfigWrapper(const MatmulConfigWrapper&) = delete
-
MatmulConfigWrapper &operator=(const MatmulConfigWrapper&) = delete
-
inline MatmulConfigWrapper(MatmulConfigWrapper &&other)
-
inline MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other)
-
inline ~MatmulConfigWrapper()
-
inline operator NVTEMatmulConfig() const noexcept
Get the underlying NVTEMatmulConfig.
- Returns:
NVTEMatmulConfig held by this MatmulConfigWrapper.
-
inline void set_bias_tensor(NVTETensor bias_tensor)
Set bias tensor.
-
inline void set_dbias_tensor(NVTETensor dbias_tensor)
Set bias gradient tensor.
-
inline void set_with_gelu_epilogue(bool with_gelu_epilogue)
Set whether to compute GELU in GEMM epilogue.
-
inline void set_with_dgelu_epilogue(bool with_dgelu_epilogue)
Set whether to compute GELU backward in GEMM epilogue.
-
inline void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor)
Set auxiliary tensor for GEMM epilogue.
-
inline void set_use_split_accumulator(bool use_split_accumulator)
Set whether to use split accumulator for FP8 GEMM.
-
inline void set_sm_count(int sm_count)
Set number of streaming multiprocessors to use in GEMM kernel.
Private Members
-
NVTEMatmulConfig config_ = nullptr
Wrapped NVTEMatmulConfig.
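The public interface above (deleted copy operations, move constructor and move assignment, destructor, implicit conversion) matches the standard move-only RAII pattern. A self-contained sketch of that pattern follows. `create_config`, `destroy_config`, and `live_configs` are hypothetical stand-ins for `nvte_create_matmul_config` and `nvte_destroy_matmul_config`, and the member bodies are assumptions, not the header's actual code.

```cpp
#include <cassert>
#include <utility>

using Config = void*;              // stands in for NVTEMatmulConfig
static int live_configs = 0;       // tracks create/destroy balance

Config create_config() { ++live_configs; return new int(0); }
void destroy_config(Config c) {
  if (c != nullptr) { --live_configs; delete static_cast<int*>(c); }
}

class ConfigWrapper {
 public:
  ConfigWrapper() : config_(create_config()) {}
  ConfigWrapper(const ConfigWrapper&) = delete;
  ConfigWrapper& operator=(const ConfigWrapper&) = delete;
  // Move: take ownership of the handle and leave the source empty.
  ConfigWrapper(ConfigWrapper&& other) noexcept : config_(other.config_) {
    other.config_ = nullptr;
  }
  ConfigWrapper& operator=(ConfigWrapper&& other) noexcept {
    if (this != &other) {
      destroy_config(config_);     // release the handle we currently own
      config_ = other.config_;
      other.config_ = nullptr;
    }
    return *this;
  }
  ~ConfigWrapper() { destroy_config(config_); }
  operator Config() const noexcept { return config_; }

 private:
  Config config_ = nullptr;
};
```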