fused_rope.h

Functions

void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream)

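Apply rotary positional embedding (RoPE) to an input tensor indexed as [s, b, h, d]. The stride arguments describe how these dimensions are laid out in memory.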

Parameters:
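  • input[in] Input tensor to which RoPE is applied.

  • freqs[in] Tensor of rotation angles (in radians), with d2 values per position along s.

  • output[out] Output tensor, with the same s, b, h and d dimensions as input.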

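  • s[in] Sequence length (length of the s dimension of input).

  • b[in] Batch size (length of the b dimension of input).

  • h[in] Number of heads (length of the h dimension of input).

  • d[in] Head dimension (length of the d dimension of input).

  • d2[in] Length of the d dimension of freqs (the number of leading elements along d that are rotated).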

  • stride_s[in] Stride of the s dimension of input, in elements.

  • stride_b[in] Stride of the b dimension of input, in elements.

  • stride_h[in] Stride of the h dimension of input, in elements.

  • stride_d[in] Stride of the d dimension of input, in elements.

  • o_stride_s[in] Stride of the s dimension of output, in elements.

  • o_stride_b[in] Stride of the b dimension of output, in elements.

  • o_stride_h[in] Stride of the h dimension of output, in elements.

  • o_stride_d[in] Stride of the d dimension of output, in elements.

  • stream[in] CUDA stream used for the operation.
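As a minimal CPU sketch of what the forward computes, the hypothetical function `rope_forward_ref` below rotates each head of a tensor indexed as [s, b, h, d]. It assumes (the header does not say so) the non-interleaved "rotate-half" pairing, in which element i is paired with element i + d2/2; a contiguous `freqs` array of shape [s, d2] holding angles in radians; and that elements d2 through d - 1 are copied unchanged. It illustrates the math only and is not the CUDA kernel.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical CPU reference for forward RoPE on a tensor indexed as [s, b, h, d].
// Assumptions (not taken from the header):
//   * rotate-half pairing: element i is paired with element i + d2/2,
//   * freqs is a contiguous [s, d2] array of angles in radians,
//   * elements d2..d-1 of each head are copied through unchanged.
// All strides are in elements.
void rope_forward_ref(const float* input, const float* freqs, float* output,
                      int s, int b, int h, int d, int d2,
                      int stride_s, int stride_b, int stride_h, int stride_d,
                      int o_stride_s, int o_stride_b, int o_stride_h, int o_stride_d) {
  assert(d2 % 2 == 0 && d2 <= d);
  const int half = d2 / 2;
  for (int is = 0; is < s; ++is)
    for (int ib = 0; ib < b; ++ib)
      for (int ih = 0; ih < h; ++ih) {
        const float* x = input + is * stride_s + ib * stride_b + ih * stride_h;
        float* y = output + is * o_stride_s + ib * o_stride_b + ih * o_stride_h;
        const float* theta = freqs + is * d2;  // angles for this position
        for (int i = 0; i < d; ++i) {
          float v = x[i * stride_d];  // pass-through for i >= d2
          if (i < d2) {
            // rotate_half(x)[i] is -x[i + half] in the first half, x[i - half] in the second
            const float rot = (i < half) ? -x[(i + half) * stride_d] : x[(i - half) * stride_d];
            v = x[i * stride_d] * std::cos(theta[i]) + rot * std::sin(theta[i]);
          }
          y[i * o_stride_d] = v;
        }
      }
}
```

Because every dimension has its own stride, the same routine covers other memory layouts: for a contiguous bshd tensor of shape [B, S, H, D], pass s = S and b = B with stride_s = H*D, stride_b = S*H*D, stride_h = D and stride_d = 1.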

void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, NVTETensor input_grads, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream)

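Compute the backward pass of the fused RoPE: given the gradient with respect to the output of nvte_fused_rope_forward, compute the gradient with respect to its input.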

Parameters:
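  • output_grads[in] Gradient with respect to the output of the forward pass.

  • freqs[in] Tensor of rotation angles (in radians) used in the forward pass.

  • input_grads[out] Gradient with respect to the input of the forward pass, computed by this function.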

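  • s[in] Sequence length (length of the s dimension of output_grads).

  • b[in] Batch size (length of the b dimension of output_grads).

  • h[in] Number of heads (length of the h dimension of output_grads).

  • d[in] Head dimension (length of the d dimension of output_grads).

  • d2[in] Length of the d dimension of freqs (the number of leading elements along d that are rotated).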

  • stride_s[in] Stride of the s dimension of output_grads, in elements.

  • stride_b[in] Stride of the b dimension of output_grads, in elements.

  • stride_h[in] Stride of the h dimension of output_grads, in elements.

  • stride_d[in] Stride of the d dimension of output_grads, in elements.

  • o_stride_s[in] Stride of the s dimension of input_grads, in elements.

  • o_stride_b[in] Stride of the b dimension of input_grads, in elements.

  • o_stride_h[in] Stride of the h dimension of input_grads, in elements.

  • o_stride_d[in] Stride of the d dimension of input_grads, in elements.

  • stream[in] CUDA stream used for the operation.
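The gradient follows from differentiating the forward rotation. The hypothetical `rope_backward_ref` below is a minimal CPU sketch that assumes a non-interleaved rotate-half forward, y[i] = x[i]·cos θ_i + rotate_half(x)[i]·sin θ_i for i < d2 and y[i] = x[i] otherwise, with the angles stored as a contiguous [s, d2] array. Each pair (i, i + d2/2) is transformed by a 2x2 matrix, so the backward applies its transpose; note that it uses the sine of the partner element's angle.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical CPU reference for the RoPE backward on a tensor indexed as [s, b, h, d].
// Assumed forward (rotate-half pairing, angles t_i = freqs[pos * d2 + i]):
//   y[i] = x[i] cos(t_i) + rotate_half(x)[i] sin(t_i)   for i < d2
//   y[i] = x[i]                                         for i >= d2
// Differentiating gives
//   dx[i] = dy[i] cos(t_i) + dy[i + d2/2] sin(t_{i + d2/2})   for i < d2/2
//   dx[i] = dy[i] cos(t_i) - dy[i - d2/2] sin(t_{i - d2/2})   for d2/2 <= i < d2
void rope_backward_ref(const float* dy, const float* freqs, float* dx,
                       int s, int b, int h, int d, int d2,
                       int stride_s, int stride_b, int stride_h, int stride_d,
                       int o_stride_s, int o_stride_b, int o_stride_h, int o_stride_d) {
  assert(d2 % 2 == 0 && d2 <= d);
  const int half = d2 / 2;
  for (int is = 0; is < s; ++is)
    for (int ib = 0; ib < b; ++ib)
      for (int ih = 0; ih < h; ++ih) {
        const float* g = dy + is * stride_s + ib * stride_b + ih * stride_h;
        float* out = dx + is * o_stride_s + ib * o_stride_b + ih * o_stride_h;
        const float* t = freqs + is * d2;
        for (int i = 0; i < d; ++i) {
          float v = g[i * stride_d];  // identity gradient for i >= d2
          if (i < d2) {
            const float rot = (i < half) ? g[(i + half) * stride_d] * std::sin(t[i + half])
                                         : -g[(i - half) * stride_d] * std::sin(t[i - half]);
            v = g[i * stride_d] * std::cos(t[i]) + rot;
          }
          out[i * o_stride_d] = v;
        }
      }
}
```

When freqs[i] equals freqs[i + d2/2], as is common when the angle table is stored duplicated across both halves, the backward is simply a rotation by -θ.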

void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, const NVTETensor freqs, NVTETensor output, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d, cudaStream_t stream)

Apply rotary positional embedding (RoPE) to the input tensor in thd format, where t is the total number of tokens across all packed sequences in the batch.

Parameters:
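  • input[in] Input tensor to which RoPE is applied.

  • cu_seqlens[in] Cumulative sum of sequence lengths, with b + 1 entries starting at 0.

  • freqs[in] Tensor of rotation angles (in radians), with d2 values per sequence position.

  • output[out] Output tensor, with the same t, h and d dimensions as input.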

  • cp_size[in] Context parallel world size.

  • cp_rank[in] Context parallel rank.

  • max_s[in] Maximum sequence length.

  • b[in] Batch size.

  • h[in] Number of heads (length of the h dimension of input).

  • d[in] Head dimension (length of the d dimension of input).

  • d2[in] Length of the d dimension of freqs (the number of leading elements along d that are rotated).

  • stride_t[in] Stride of the t dimension of input, in elements.

  • stride_h[in] Stride of the h dimension of input, in elements.

  • stride_d[in] Stride of the d dimension of input, in elements.

  • o_stride_t[in] Stride of the t dimension of output, in elements.

  • o_stride_h[in] Stride of the h dimension of output, in elements.

  • o_stride_d[in] Stride of the d dimension of output, in elements.

  • stream[in] CUDA stream used for the operation.
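In thd format the batch is a packed run of variable-length sequences, so each token's position has to be recovered from cu_seqlens before the rotation is applied. The hypothetical `rope_thd_forward_ref` below sketches this on the CPU for the single-rank case (cp_size == 1, so the cp arguments are omitted). It assumes that cu_seqlens holds b + 1 token offsets into t, that positions restart at 0 for every sequence, that freqs is a contiguous [max_s, d2] angle array, and the rotate-half pairing on the first d2 elements.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical CPU reference for thd RoPE with cp_size == 1.
// Assumptions (not taken from the header):
//   * input is indexed as [t, h, d]; cu_seqlens has b + 1 entries,
//     cu_seqlens[0] == 0 and cu_seqlens[b] == t,
//   * a token's position is its offset within its own sequence, so every
//     sequence starts again at row 0 of freqs (a contiguous [max_s, d2] array),
//   * rotate-half pairing on the first d2 elements; d2..d-1 are copied.
void rope_thd_forward_ref(const float* input, const int* cu_seqlens, const float* freqs,
                          float* output, int max_s, int b, int h, int d, int d2,
                          int stride_t, int stride_h, int stride_d,
                          int o_stride_t, int o_stride_h, int o_stride_d) {
  assert(d2 % 2 == 0 && d2 <= d);
  const int half = d2 / 2;
  for (int seq = 0; seq < b; ++seq) {
    for (int tok = cu_seqlens[seq]; tok < cu_seqlens[seq + 1]; ++tok) {
      const int pos = tok - cu_seqlens[seq];  // position within this sequence
      assert(pos < max_s);
      const float* theta = freqs + pos * d2;
      for (int ih = 0; ih < h; ++ih) {
        const float* x = input + tok * stride_t + ih * stride_h;
        float* y = output + tok * o_stride_t + ih * o_stride_h;
        for (int i = 0; i < d; ++i) {
          float v = x[i * stride_d];  // pass-through for i >= d2
          if (i < d2) {
            const float rot = (i < half) ? -x[(i + half) * stride_d] : x[(i - half) * stride_d];
            v = x[i * stride_d] * std::cos(theta[i]) + rot * std::sin(theta[i]);
          }
          y[i * o_stride_d] = v;
        }
      }
    }
  }
}
```

With cu_seqlens = {0, 1, 3}, token 1 is the first token of the second sequence and therefore uses row 0 of freqs, exactly like token 0.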

void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, const NVTETensor freqs, NVTETensor input_grads, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d, cudaStream_t stream)

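Compute the backward pass of the fused RoPE in thd format: given the gradient with respect to the output of nvte_fused_rope_thd_forward, compute the gradient with respect to its input.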

Parameters:
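  • output_grads[in] Gradient with respect to the output of the forward pass.

  • cu_seqlens[in] Cumulative sum of sequence lengths, with b + 1 entries starting at 0.

  • freqs[in] Tensor of rotation angles (in radians) used in the forward pass.

  • input_grads[out] Gradient with respect to the input of the forward pass, computed by this function.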

  • cp_size[in] Context parallel world size.

  • cp_rank[in] Context parallel rank.

  • max_s[in] Maximum sequence length.

  • b[in] Batch size.

  • h[in] Number of heads (length of the h dimension of output_grads).

  • d[in] Head dimension (length of the d dimension of output_grads).

  • d2[in] Length of the d dimension of freqs (the number of leading elements along d that are rotated).

  • stride_t[in] Stride of the t dimension of output_grads, in elements.

  • stride_h[in] Stride of the h dimension of output_grads, in elements.

  • stride_d[in] Stride of the d dimension of output_grads, in elements.

  • o_stride_t[in] Stride of the t dimension of input_grads, in elements.

  • o_stride_h[in] Stride of the h dimension of input_grads, in elements.

  • o_stride_d[in] Stride of the d dimension of input_grads, in elements.

  • stream[in] CUDA stream used for the operation.
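Both thd entry points take cp_size and cp_rank because, under context parallelism, a rank's local tokens are generally not the first positions of each sequence. The hypothetical `cp_position` below is a minimal sketch of that position lookup, assuming the load-balanced split commonly used with context parallelism: each full sequence is cut into 2 * cp_size equal chunks, and rank r keeps chunks r and 2 * cp_size - 1 - r back to back. Whether the real kernel uses exactly this layout, and how it derives the local sequence length from cu_seqlens, should be checked against the implementation.

```cpp
#include <cassert>

// Hypothetical helper: the row of freqs used by local token s_id of one sequence.
// Assumed layout: the full sequence is split into 2 * cp_size equal chunks and
// rank cp_rank stores chunks cp_rank and 2 * cp_size - 1 - cp_rank back to back,
// so the local length is 2 * chunk.
int cp_position(int s_id, int local_len, int cp_size, int cp_rank) {
  if (cp_size == 1) return s_id;  // no context parallelism: position == local index
  assert(local_len % 2 == 0 && 0 <= s_id && s_id < local_len);
  const int chunk = local_len / 2;
  if (s_id < chunk) {
    // first local chunk is global chunk cp_rank
    return cp_rank * chunk + s_id;
  }
  // second local chunk is global chunk 2 * cp_size - 1 - cp_rank
  return (2 * cp_size - 1 - cp_rank) * chunk + (s_id - chunk);
}
```

For cp_size = 2 and a full sequence of length 8, rank 0 holds positions 0, 1, 6, 7 and rank 1 holds positions 2, 3, 4, 5; pairing an early chunk with a late one evens out causal-attention work across ranks. Under this assumption every local sequence length must be even.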