fused_rope.h
Functions
-
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream)
Apply rotary positional embedding to the input tensor.
- Parameters:
input – [in] Input tensor for fused rope.
freqs – [in] Rotary frequency tensor applied to input.
output – [out] Output tensor.
s – [in] Length of the s dimension of input.
b – [in] Length of the b dimension of input.
h – [in] Length of the h dimension of input.
d – [in] Length of the d dimension of input.
d2 – [in] Length of the d dimension of freqs.
stride_s – [in] Stride of the s dimension of input.
stride_b – [in] Stride of the b dimension of input.
stride_h – [in] Stride of the h dimension of input.
stride_d – [in] Stride of the d dimension of input.
o_stride_s – [in] Stride of the s dimension of output.
o_stride_b – [in] Stride of the b dimension of output.
o_stride_h – [in] Stride of the h dimension of output.
o_stride_d – [in] Stride of the d dimension of output.
stream – [in] CUDA stream used for the operation.
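The length and stride parameters above can be illustrated with a small CPU reference. This is a minimal sketch under stated assumptions, not the library's kernel: it assumes element (not byte) strides, a dense [s, d2] array of angles for freqs, the common "rotate half" channel pairing, and pass-through of channels beyond d2. The names Strides4, contiguous_sbhd_strides and rope_forward_reference are hypothetical helpers introduced for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Element strides of an [s, b, h, d] tensor (hypothetical helper type).
struct Strides4 { int s, b, h, d; };

// Strides for a densely packed [s, b, h, d] layout (assumed layout).
inline Strides4 contiguous_sbhd_strides(int b, int h, int d) {
  return Strides4{b * h * d, h * d, d, 1};
}

// CPU reference sketch. Assumptions: freqs is a dense [s, d2] array of angles,
// channels are paired in "rotate half" fashion, channels >= d2 are copied.
inline void rope_forward_reference(const std::vector<float>& input,
                                   const std::vector<float>& freqs,
                                   std::vector<float>& output,
                                   int s, int b, int h, int d, int d2,
                                   Strides4 in, Strides4 out) {
  assert(d2 % 2 == 0 && d2 <= d);
  const int half = d2 / 2;
  for (int si = 0; si < s; ++si)
    for (int bi = 0; bi < b; ++bi)
      for (int hi = 0; hi < h; ++hi) {
        const int in_base = si * in.s + bi * in.b + hi * in.h;
        const int out_base = si * out.s + bi * out.b + hi * out.h;
        for (int di = 0; di < d; ++di) {
          const float x = input[in_base + di * in.d];
          if (di >= d2) {  // not rotated: copy through
            output[out_base + di * out.d] = x;
            continue;
          }
          const float angle = freqs[si * d2 + di];
          // rotate_half: the first half pairs with -x[di + half],
          // the second half pairs with +x[di - half].
          const float rotated = (di < half)
              ? -input[in_base + (di + half) * in.d]
              :  input[in_base + (di - half) * in.d];
          output[out_base + di * out.d] =
              x * std::cos(angle) + rotated * std::sin(angle);
        }
      }
}
```

Passing separate input and output strides, as the signature does, lets the two tensors use different memory layouts; the sketch mirrors that by taking one Strides4 for each.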
-
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, NVTETensor input_grads, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream)
Compute the backward pass of the fused rotary positional embedding.
- Parameters:
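output_grads – [in] Incoming gradient tensor (gradient with respect to the forward output).
freqs – [in] Rotary frequency tensor used in the forward pass.
input_grads – [out] Gradient with respect to the forward input, computed by this function.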
s – [in] Length of the s dimension of output_grads.
b – [in] Length of the b dimension of output_grads.
h – [in] Length of the h dimension of output_grads.
d – [in] Length of the d dimension of output_grads.
d2 – [in] Length of the d dimension of freqs.
stride_s – [in] Stride of the s dimension of output_grads.
stride_b – [in] Stride of the b dimension of output_grads.
stride_h – [in] Stride of the h dimension of output_grads.
stride_d – [in] Stride of the d dimension of output_grads.
o_stride_s – [in] Stride of the s dimension of input_grads.
o_stride_b – [in] Stride of the b dimension of input_grads.
o_stride_h – [in] Stride of the h dimension of input_grads.
o_stride_d – [in] Stride of the d dimension of input_grads.
stream – [in] CUDA stream used for the operation.
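Because the rotation is linear in its input, the backward pass amounts to applying the transpose of the forward map to output_grads. The sketch below shows this for a single row of length d, under the same assumptions as a typical "rotate half" formulation (angles given per channel, channels beyond d2 passed through). rope_row_forward and rope_row_backward are hypothetical names and this is not the library's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Forward for one [d] row: y = x * cos(angle) + rotate_half(x) * sin(angle).
inline void rope_row_forward(const float* x, const float* angles, float* y,
                             int d, int d2, int stride_d, int o_stride_d) {
  assert(d2 % 2 == 0 && d2 <= d);
  const int half = d2 / 2;
  for (int i = 0; i < d; ++i) {
    const float xi = x[i * stride_d];
    if (i >= d2) { y[i * o_stride_d] = xi; continue; }
    const float rot = (i < half) ? -x[(i + half) * stride_d]
                                 :  x[(i - half) * stride_d];
    y[i * o_stride_d] = xi * std::cos(angles[i]) + rot * std::sin(angles[i]);
  }
}

// Backward for one row: apply the transpose of the forward's linear map.
inline void rope_row_backward(const float* dy, const float* angles, float* dx,
                              int d, int d2, int stride_d, int o_stride_d) {
  assert(d2 % 2 == 0 && d2 <= d);
  const int half = d2 / 2;
  for (int i = 0; i < d; ++i) {
    const float gi = dy[i * stride_d];
    if (i >= d2) { dx[i * o_stride_d] = gi; continue; }
    // The partner channel contributes with its own angle and with the
    // sign opposite to the one used in the forward rotation.
    const float partner = (i < half)
        ?  dy[(i + half) * stride_d] * std::sin(angles[i + half])
        : -dy[(i - half) * stride_d] * std::sin(angles[i - half]);
    dx[i * o_stride_d] = gi * std::cos(angles[i]) + partner;
  }
}
```

A quick way to sanity-check such a pair is the adjoint identity: for any x and dy, the dot product of dy with forward(x) should equal the dot product of backward(dy) with x, up to rounding.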
-
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, const NVTETensor freqs, NVTETensor output, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d, cudaStream_t stream)
Apply rotary positional embedding to an input tensor in thd format (sequences packed along a single token dimension t).
- Parameters:
input – [in] Input tensor for fused rope.
cu_seqlens – [in] Tensor holding the cumulative sum of sequence lengths.
freqs – [in] Rotary frequency tensor applied to input.
output – [out] Output tensor.
cp_size – [in] Context parallel world size.
cp_rank – [in] Context parallel rank.
max_s – [in] Maximum sequence length.
b – [in] Batch size.
h – [in] Length of the h dimension of input.
d – [in] Length of the d dimension of input.
d2 – [in] Length of the d dimension of freqs.
stride_t – [in] Stride of the t dimension of input.
stride_h – [in] Stride of the h dimension of input.
stride_d – [in] Stride of the d dimension of input.
o_stride_t – [in] Stride of the t dimension of output.
o_stride_h – [in] Stride of the h dimension of output.
o_stride_d – [in] Stride of the d dimension of output.
stream – [in] CUDA stream used for the operation.
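The host-side quantities for a thd call can be derived from the per-sequence lengths. The sketch below assumes the common convention that cu_seqlens has b + 1 entries and starts at 0, and that the packed [t, h, d] tensor is dense with element strides; neither is confirmed by this reference. make_cu_seqlens, ThdArgs and describe_packed_thd are hypothetical helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Cumulative sequence lengths: b + 1 entries, starting at 0 (assumed convention).
inline std::vector<int> make_cu_seqlens(const std::vector<int>& seqlens) {
  std::vector<int> cu(seqlens.size() + 1, 0);
  for (std::size_t i = 0; i < seqlens.size(); ++i) {
    assert(seqlens[i] >= 0);
    cu[i + 1] = cu[i] + seqlens[i];
  }
  return cu;
}

// Scalar arguments describing a dense packed [t, h, d] tensor (hypothetical type).
struct ThdArgs { int max_s, b, stride_t, stride_h, stride_d; };

inline ThdArgs describe_packed_thd(const std::vector<int>& seqlens, int h, int d) {
  int max_s = 0;
  for (int len : seqlens) max_s = std::max(max_s, len);
  // Dense layout: one token spans h * d elements, one head spans d elements.
  return ThdArgs{max_s, static_cast<int>(seqlens.size()), h * d, d, 1};
}
```

With sequence lengths {3, 5, 2}, this yields cu_seqlens = {0, 3, 8, 10}, b = 3 and max_s = 5; the last entry of cu_seqlens is the total token count t.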
-
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, const NVTETensor freqs, NVTETensor input_grads, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d, cudaStream_t stream)
Compute the backward pass of the fused rotary positional embedding in thd format.
- Parameters:
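output_grads – [in] Incoming gradient tensor (gradient with respect to the forward output).
cu_seqlens – [in] Tensor holding the cumulative sum of sequence lengths.
freqs – [in] Rotary frequency tensor used in the forward pass.
input_grads – [out] Gradient with respect to the forward input, computed by this function.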
cp_size – [in] Context parallel world size.
cp_rank – [in] Context parallel rank.
max_s – [in] Maximum sequence length.
b – [in] Batch size.
h – [in] Length of the h dimension of output_grads.
d – [in] Length of the d dimension of output_grads.
d2 – [in] Length of the d dimension of freqs.
stride_t – [in] Stride of the t dimension of output_grads.
stride_h – [in] Stride of the h dimension of output_grads.
stride_d – [in] Stride of the d dimension of output_grads.
o_stride_t – [in] Stride of the t dimension of input_grads.
o_stride_h – [in] Stride of the h dimension of input_grads.
o_stride_d – [in] Stride of the d dimension of input_grads.
stream – [in] CUDA stream used for the operation.
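In the packed thd layout, each token index t along the first dimension belongs to exactly one sequence and has a position within it. The sketch below recovers both from cu_seqlens with a binary search. It assumes cu_seqlens starts at 0 and has b + 1 entries, and it ignores context parallelism (that is, it corresponds to cp_size = 1). TokenPos and locate_token are hypothetical names, not part of the library.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Sequence index and position within that sequence (hypothetical type).
struct TokenPos { int seq; int pos; };

// Locate packed token index t using cu_seqlens (assumed: b + 1 entries, first is 0).
inline TokenPos locate_token(const std::vector<int>& cu_seqlens, int t) {
  assert(cu_seqlens.size() >= 2 && t >= 0 && t < cu_seqlens.back());
  // upper_bound returns the first entry greater than t; the owning
  // sequence begins one entry earlier.
  auto it = std::upper_bound(cu_seqlens.begin(), cu_seqlens.end(), t);
  const int seq = static_cast<int>(it - cu_seqlens.begin()) - 1;
  return TokenPos{seq, t - cu_seqlens[seq]};
}
```

For cu_seqlens = {0, 3, 8, 10}, token 3 is the first token of sequence 1 and token 9 is at position 1 of sequence 2. Empty sequences are skipped naturally because their start and end offsets coincide.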