gemm.h
Functions for matrix multiplication.
Functions
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream)
Computes the matrix multiplication of two matrices, optionally fused with a bias addition and a GELU activation.
Computes:
D = AB, if both bias and pre_gelu_out are empty tensors;
D = AB + bias, if pre_gelu_out is empty and bias is not empty;
D = GELU(AB + bias), if neither bias nor pre_gelu_out is an empty tensor.
Parameters:
A – [in] The A matrix.
B – [in] The B matrix.
D – [inout] Output matrix.
bias – [in] Bias tensor.
pre_gelu_out – [inout] Output matrix before GELU activation.
transa – [in] Whether the A matrix is transposed.
transb – [in] Whether the B matrix is transposed.
grad – [in] Whether this operation is part of the gradient computation.
workspace – [out] Workspace tensor.
accumulate – [in] Whether to accumulate the result into the D matrix.
use_split_accumulator – [in] Whether to use a split accumulator in the FP8 GEMM.
math_sm_count – [in] Number of GPU SMs to use (default 0, which lets cuBLAS heuristics decide).
stream – [in] CUDA stream used for the operation.
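The transa, transb and accumulate flags can be illustrated with the generic BLAS meaning of such flags: op(X) is X or its transpose, and accumulation adds the product into the existing D. The sketch below is hypothetical (`ref_gemm_flags` and `op_at` are not Transformer Engine functions) and assumes row-major storage for readability. cuBLAS itself works in column-major order, so the concrete layout convention of nvte_cublas_gemm may differ from this sketch.

```c
#include <stdbool.h>
#include <stddef.h>

/* Element (r, c) of op(X), where X is stored row-major with `ld` columns
 * and op(X) = X^T when `trans` is true. */
static float op_at(const float *X, size_t ld, bool trans, size_t r, size_t c) {
    return trans ? X[c * ld + r] : X[r * ld + c];
}

/* Hypothetical reference: D (M x N) = op(A) op(B), or
 * D += op(A) op(B) when accumulate is true.
 * op(A) is M x K and op(B) is K x N. */
void ref_gemm_flags(const float *A, const float *B, float *D,
                    bool transa, bool transb, bool accumulate,
                    size_t M, size_t N, size_t K) {
    size_t lda = transa ? M : K; /* stored A is K x M if transposed, else M x K */
    size_t ldb = transb ? K : N; /* stored B is N x K if transposed, else K x N */
    for (size_t i = 0; i < M; ++i) {
        for (size_t j = 0; j < N; ++j) {
            float acc = 0.0f;
            for (size_t p = 0; p < K; ++p)
                acc += op_at(A, lda, transa, i, p) * op_at(B, ldb, transb, p, j);
            D[i * N + j] = accumulate ? D[i * N + j] + acc : acc;
        }
    }
}
```

Keeping the transpose as an index mapping rather than materializing a transposed copy mirrors how BLAS-style libraries treat these flags: the data stays where it is and only the access pattern changes.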