rmsnorm.h

RMSNorm functions.

Functions

void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier)

Compute RMSNorm on the input.
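Written out per row, the computation is the usual RMSNorm definition. It is restated below for reference, assuming gamma is applied as a plain multiplicative scale. Here i indexes the N rows, j and k index the H columns, and r_i is the value returned in rsigma[i].

```latex
r_i = \left( \frac{1}{H} \sum_{k=1}^{H} x_{ik}^{2} + \varepsilon \right)^{-1/2},
\qquad
z_{ij} = \gamma_j \, x_{ij} \, r_i .
```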

Calling this function with workspace and barrier set to empty tensors does not perform the operation; instead, it sets the shape and data type of the workspace and barrier tensors to the required values.
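The two-call convention (query, allocate, then run) is sketched below as a self-contained C mock, not as real Transformer Engine code. `MockTensor`, `mock_rmsnorm_fwd`, `run_with_query`, and the sizes the mock reports are all hypothetical stand-ins for `NVTETensor` and `nvte_rmsnorm_fwd`; only the calling pattern is meant to carry over.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>

/* Hypothetical stand-in for NVTETensor: a host buffer plus its size.
 * numel == 0 plays the role of an "empty" tensor. */
typedef struct {
    void  *data;
    size_t numel;
    size_t elem_size;
} MockTensor;

/* Hypothetical stand-in mimicking the query convention described above.
 * If both workspace and barrier are empty, only their required sizes are
 * written and nothing is computed (returns 0). Otherwise the operation
 * would run (returns 1). The reported sizes are made up for illustration. */
static int mock_rmsnorm_fwd(size_t n, size_t h,
                            MockTensor *workspace, MockTensor *barrier) {
    if (workspace->numel == 0 && barrier->numel == 0) {
        workspace->numel = n * h;
        workspace->elem_size = sizeof(float);
        barrier->numel = n;
        barrier->elem_size = sizeof(int);
        return 0;
    }
    assert(workspace->data != NULL && barrier->data != NULL);
    /* ... the real function would launch its kernels here ... */
    return 1;
}

/* Query, allocate what was requested, then call again. Returns 1 if the
 * second call ran, 0 if allocation failed, -1 if the query unexpectedly ran. */
static int run_with_query(size_t n, size_t h) {
    assert(n > 0 && h > 0);
    MockTensor ws = {NULL, 0, 0}, bar = {NULL, 0, 0};
    if (mock_rmsnorm_fwd(n, h, &ws, &bar) != 0) return -1; /* pass 1: query only */
    ws.data = malloc(ws.numel * ws.elem_size);   /* device allocation in real code */
    bar.data = calloc(bar.numel, bar.elem_size); /* zero-filled here */
    int ran = 0;
    if (ws.data != NULL && bar.data != NULL)
        ran = mock_rmsnorm_fwd(n, h, &ws, &bar); /* pass 2: computes */
    free(ws.data);
    free(bar.data);
    return ran;
}
```

In real code, the second call would pass the same x, gamma, epsilon, z, rsigma, stream, and multiprocessorCount arguments as the first, with buffers allocated on the device to the reported shapes and types.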

Parameters
  • x[in] Input tensor of shape [N, H].

  • gamma[in] Gamma tensor of shape [H].

  • epsilon[in] Small constant added to the mean of the squared inputs, inside the square root, for numerical stability.

  • z[inout] Output tensor of shape [N, H].

  • rsigma[out] Reciprocal of the root mean square of the input calculated over the last dimension. Shape: [N].

  • stream[in] CUDA stream used for the operation.

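  • multiprocessorCount[in] Number of streaming multiprocessors (SMs) on the device, for example as reported by cudaDeviceGetAttribute with cudaDevAttrMultiProcessorCount.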

  • workspace[out] Workspace tensor.

  • barrier[out] Barrier tensor.

void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier)

Compute the backward pass of RMSNorm.
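The gradients follow from the chain rule applied to the forward definition, restated on the first line below, with dz denoting the incoming gradient. This is a standard derivation for that definition, not a description of the kernel's internals. Note that the only statistic needed from the forward pass is r_i, which is why rsigma is an input to the backward function.

```latex
\begin{aligned}
r_i &= \Big( \frac{1}{H} \sum_{k=1}^{H} x_{ik}^{2} + \varepsilon \Big)^{-1/2},
  \qquad z_{ij} = \gamma_j\, x_{ij}\, r_i,
  \qquad \frac{\partial r_i}{\partial x_{ij}} = -\frac{r_i^{3}\, x_{ij}}{H}, \\
\frac{\partial L}{\partial \gamma_j} &= \sum_{i=1}^{N} \mathrm{dz}_{ij}\, x_{ij}\, r_i, \\
\frac{\partial L}{\partial x_{ij}} &= r_i\, \gamma_j\, \mathrm{dz}_{ij}
  - \frac{r_i^{3}\, x_{ij}}{H} \sum_{k=1}^{H} \gamma_k\, \mathrm{dz}_{ik}\, x_{ik}.
\end{aligned}
```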

Calling this function with workspace, barrier, and dgamma_part set to empty tensors does not perform the operation; instead, it sets the shape and data type of these tensors to the required values.

Parameters
  • dz[in] Incoming gradient tensor of shape [N, H].

  • x[in] Forward input tensor of shape [N, H].

  • rsigma[in] Reciprocal of the root mean square of the input calculated over the last dimension. Shape: [N].

  • gamma[in] Gamma tensor of shape [H].

  • dx[out] Gradient with respect to the input x, of shape [N, H].

  • dgamma[out] Gradient with respect to gamma, of shape [H].

  • dgamma_part[out] Storage for the partial gamma gradients that are reduced into dgamma.

  • stream[in] CUDA stream used for the operation.

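  • multiprocessorCount[in] Number of streaming multiprocessors (SMs) on the device, for example as reported by cudaDeviceGetAttribute with cudaDevAttrMultiProcessorCount.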

  • workspace[out] Workspace tensor.

  • barrier[out] Barrier tensor.
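A plain C reference for the backward pass is sketched below, assuming row-major float32 buffers and no FP8 handling. It also illustrates the idea behind a partial-gradient buffer such as dgamma_part: rows are split into chunks, each chunk accumulates its own partial dgamma row, and a final pass sums the partials into dgamma. The function name `rmsnorm_bwd_ref`, the `rows_per_part` chunking, and the `[num_parts, H]` layout are hypothetical and need not match how the library actually sizes or fills dgamma_part.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical CPU reference for the backward pass.
 * dz, x, dx: [n, h] row-major; rsigma: [n]; gamma, dgamma: [h];
 * dgamma_part: [num_parts, h], num_parts = ceil(n / rows_per_part). */
void rmsnorm_bwd_ref(const float *dz, const float *x, const float *rsigma,
                     const float *gamma, float *dx, float *dgamma,
                     float *dgamma_part, size_t n, size_t h,
                     size_t rows_per_part) {
    assert(h > 0 && rows_per_part > 0);
    size_t num_parts = (n + rows_per_part - 1) / rows_per_part;
    for (size_t p = 0; p < num_parts * h; ++p) dgamma_part[p] = 0.0f;

    for (size_t i = 0; i < n; ++i) {
        const float *dzi = dz + i * h;
        const float *xi = x + i * h;
        float r = rsigma[i];
        float *part = dgamma_part + (i / rows_per_part) * h;
        /* dot = sum_k gamma_k * dz_ik * x_ik */
        float dot = 0.0f;
        for (size_t k = 0; k < h; ++k) {
            dot += dzi[k] * gamma[k] * xi[k];
            part[k] += dzi[k] * xi[k] * r; /* this chunk's dgamma contribution */
        }
        /* dx_ij = r * gamma_j * dz_ij - (r^3 * dot / h) * x_ij */
        float coef = r * r * r * dot / (float)h;
        for (size_t j = 0; j < h; ++j)
            dx[i * h + j] = r * dzi[j] * gamma[j] - coef * xi[j];
    }

    /* Final reduction: sum the partial rows into dgamma. */
    for (size_t j = 0; j < h; ++j) {
        float acc = 0.0f;
        for (size_t p = 0; p < num_parts; ++p) acc += dgamma_part[p * h + j];
        dgamma[j] = acc;
    }
}
```

Keeping one partial row per chunk lets independent workers accumulate without atomics, at the cost of a second reduction pass; that is a common motivation for such a buffer, though the library's exact scheme may differ.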