gemm.h

Functions for matrix multiplication.

Functions

void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, bool accumulate, bool use_split_accumulator, cudaStream_t stream)

Computes the matrix multiplication of two matrices, potentially fused with other operations.

Computes:

  • D = AB if both bias and pre_gelu_out are empty tensors

  • D = AB + bias if pre_gelu_out is empty and bias is not empty

  • D = GELU(AB + bias) if neither bias nor pre_gelu_out is an empty tensor

Parameters:
  • A[in] The A matrix.

  • B[in] The B matrix.

  • D[inout] Output matrix.

  • bias[in] Bias tensor.

  • pre_gelu_out[inout] Output matrix before GELU activation.

  • transa[in] Whether the A matrix is transposed.

  • transb[in] Whether the B matrix is transposed.

  • grad[in] Whether this operation is part of the gradient computation.

  • workspace[out] Workspace tensor.

  • accumulate[in] Whether to accumulate the result into the D matrix.

  • use_split_accumulator[in] Whether to use a split accumulator in the FP8 GEMM.

  • stream[in] CUDA stream used for the operation.
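The effect of transa, transb and accumulate on the result can be sketched on the CPU as D = op(A) op(B), plus the previous contents of D when accumulate is true, where op(X) is the transpose of X when the matching flag is set. This is a hypothetical illustration only: the names reference_gemm_trans and op_elem are invented, and the row-major layout (op(A) is M x K, op(B) is K x N, so A is stored as K x M when transa is true and B as N x K when transb is true) is an assumption, not the library's documented storage convention. bias, pre_gelu_out, grad, workspace, use_split_accumulator and stream are not modeled.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Element (r, c) of op(X), where op(X) has shape rows x cols.
 * If trans is set, X itself is stored row-major as cols x rows. */
static float op_elem(const float *X, size_t rows, size_t cols,
                     bool trans, size_t r, size_t c) {
    return trans ? X[c * rows + r] : X[r * cols + c];
}

/* Hypothetical reference: D = op(A) op(B), or D += op(A) op(B)
 * when accumulate is true. Assumes row-major float storage. */
void reference_gemm_trans(const float *A, const float *B, float *D,
                          bool transa, bool transb, bool accumulate,
                          size_t M, size_t N, size_t K) {
    assert(A != NULL && B != NULL && D != NULL);
    for (size_t i = 0; i < M; ++i) {
        for (size_t j = 0; j < N; ++j) {
            float acc = 0.0f;
            for (size_t k = 0; k < K; ++k)
                acc += op_elem(A, M, K, transa, i, k) *
                       op_elem(B, K, N, transb, k, j);
            /* accumulate adds into D's existing contents instead of overwriting. */
            D[i * N + j] = accumulate ? D[i * N + j] + acc : acc;
        }
    }
}
```

In this sketch, passing a pre-transposed copy of A with transa set yields the same D as passing A directly with transa cleared, and calling it a second time with accumulate set doubles D.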