softmax.h

Functions

void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_results, float scale_factor, cudaStream_t stream)

Compute the softmax of the scaled input, softmax(scale_factor * input), along the last dimension of input.

Parameters
  • input[in] Input tensor for softmax.

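  • softmax_results[out] Output tensor that receives the softmax probabilities; same shape as input.

  • scale_factor[in] Scale factor by which input is multiplied before the softmax is applied.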

  • stream[in] CUDA stream used for the operation.

void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results, NVTETensor output_grads, float scale_factor, cudaStream_t stream)

Compute the backward pass (the gradient with respect to the input) of the scaled softmax activation.

  • incoming_grads is the input tensor containing the gradients received from the following layer.

  • softmax_results is the output tensor of the corresponding forward softmax operation.

  • output_grads is the output tensor containing the computed gradients.

Parameters
  • incoming_grads[in] Input gradient tensor for backward.

  • softmax_results[in] Output tensor of softmax forward.

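  • output_grads[out] Output tensor that receives the gradient with respect to the softmax input.

  • scale_factor[in] Scale factor that was applied to the input in the corresponding forward call; the computed gradients are multiplied by it.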

  • stream[in] CUDA stream used for the operation.
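The backward pass only needs softmax_results, not the original input, because the softmax Jacobian can be written entirely in terms of the output. For y = softmax(s * x) and incoming gradient dy, the chain rule gives dx[c] = s * y[c] * (dy[c] - sum_k dy[k] * y[k]). The sketch below applies that formula row by row on the CPU, assuming the same row-major `rows x cols` view as a plain `float` array; `ref_scaled_softmax_backward` is a hypothetical name, not part of this API.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical CPU reference (not part of the API).
 * dy = incoming_grads, y = softmax_results, dx = output_grads.
 * For y = softmax(scale * x) computed row-wise:
 *   dx[c] = scale * y[c] * (dy[c] - sum_k dy[k] * y[k]) */
void ref_scaled_softmax_backward(const float *dy, const float *y, float *dx,
                                 size_t rows, size_t cols, float scale) {
    assert(dy != NULL && y != NULL && dx != NULL);
    for (size_t r = 0; r < rows; ++r) {
        const float *g = dy + r * cols;
        const float *p = y + r * cols;
        float *o = dx + r * cols;
        /* Row-wise dot product sum_k dy[k] * y[k]. */
        float dot = 0.0f;
        for (size_t c = 0; c < cols; ++c) dot += g[c] * p[c];
        for (size_t c = 0; c < cols; ++c) o[c] = scale * p[c] * (g[c] - dot);
    }
}
```

Because each row of y sums to 1, each row of dx sums to 0, and a gradient that is constant across a row produces a zero dx.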

void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask, NVTETensor softmax_results, float scale_factor, cudaStream_t stream)

Compute the softmax of the scaled input, softmax(scale_factor * input), along the last dimension of input, excluding the elements selected by mask.

Parameters
  • input[in] Input tensor for softmax.

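  • mask[in] Mask selecting which elements of input are excluded from the softmax.

  • softmax_results[out] Output tensor that receives the softmax probabilities; same shape as input.

  • scale_factor[in] Scale factor by which input is multiplied before the softmax is applied.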

  • stream[in] CUDA stream used for the operation.

void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results, NVTETensor output_grads, float scale_factor, cudaStream_t stream)

Compute the backward pass (the gradient with respect to the input) of the scaled masked softmax activation.

  • incoming_grads is the input tensor containing the gradients received from the following layer.

  • softmax_results is the output tensor of the corresponding forward softmax operation.

  • output_grads is the output tensor containing the computed gradients.

Parameters
  • incoming_grads[in] Input gradient tensor for backward.

  • softmax_results[in] Output tensor of softmax forward.

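  • output_grads[out] Output tensor that receives the gradient with respect to the softmax input.

  • scale_factor[in] Scale factor that was applied to the input in the corresponding forward call; the computed gradients are multiplied by it.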

  • stream[in] CUDA stream used for the operation.
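This backward function takes no mask argument. A short derivation shows why softmax_results alone can suffice, assuming the forward pass gave every masked position a probability of 0 (or a value that underflows to 0). Here $s$ is scale_factor, $g$ is incoming_grads, and $M$ is the set of masked indices in a row.

```latex
Let $y = \operatorname{softmax}(s\,x)$ over one row, with $y_j = 0$ for all $j \in M$. Then
\[
\frac{\partial L}{\partial x_i}
  = s\, y_i \Bigl( g_i - \sum_{k} g_k\, y_k \Bigr)
  = s\, y_i \Bigl( g_i - \sum_{k \notin M} g_k\, y_k \Bigr),
\]
since every term with $k \in M$ has $y_k = 0$. For $i \in M$ the leading factor
$y_i = 0$ makes $\partial L / \partial x_i = 0$, so the mask is already encoded in $y$.
```

Under that assumption the masked backward reduces to the same formula as the unmasked one.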

void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input, NVTETensor softmax_results, float scale_factor, cudaStream_t stream)

Compute the scaled softmax of the input with a 2D upper triangular mask applied, so that each position attends only to itself and earlier positions (causal masking).

Parameters
  • input[in] Input tensor for softmax.

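  • softmax_results[out] Output tensor that receives the softmax probabilities; same shape as input.

  • scale_factor[in] Scale factor by which input is multiplied before the softmax is applied.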

  • stream[in] CUDA stream used for the operation.

void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results, NVTETensor output_grads, float scale_factor, cudaStream_t stream)

Compute the backward pass (the gradient with respect to the input) of the scaled softmax activation with a 2D upper triangular mask.

  • incoming_grads is the input tensor containing the gradients received from the following layer.

  • softmax_results is the output tensor of the corresponding forward softmax operation.

  • output_grads is the output tensor containing the computed gradients.

Parameters
  • incoming_grads[in] Input gradient tensor for backward.

  • softmax_results[in] Output tensor of softmax forward.

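  • output_grads[out] Output tensor that receives the gradient with respect to the softmax input.

  • scale_factor[in] Scale factor that was applied to the input in the corresponding forward call; the computed gradients are multiplied by it.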

  • stream[in] CUDA stream used for the operation.
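The causal backward applies the usual softmax gradient formula to the kept prefix of each row and writes zero gradient to the masked columns. The sketch below assumes the same layout as a causal forward reference (a batch of row-major `seq_len x seq_len` matrices, row i keeping columns 0..i); `ref_causal_softmax_backward` is a hypothetical name, not part of this API.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical CPU reference (not part of the API).
 * dy = incoming_grads, y = softmax_results, dx = output_grads.
 * Row i uses columns 0..i only:
 *   dx[i][j] = scale * y[i][j] * (dy[i][j] - sum_{k<=i} dy[i][k] * y[i][k])
 * Masked columns j > i receive a zero gradient. */
void ref_causal_softmax_backward(const float *dy, const float *y, float *dx,
                                 size_t batches, size_t seq_len, float scale) {
    assert(dy != NULL && y != NULL && dx != NULL);
    for (size_t b = 0; b < batches; ++b) {
        for (size_t i = 0; i < seq_len; ++i) {
            size_t off = (b * seq_len + i) * seq_len;
            /* Dot product over the kept prefix of row i. */
            float dot = 0.0f;
            for (size_t j = 0; j <= i; ++j) dot += dy[off + j] * y[off + j];
            for (size_t j = 0; j < seq_len; ++j)
                dx[off + j] = (j <= i)
                    ? scale * y[off + j] * (dy[off + j] - dot)
                    : 0.0f;
        }
    }
}
```

Row 0 always gets a zero gradient: with a single kept element its probability is fixed at 1, so it does not depend on the input.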