fused_rope.h
Functions
-
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, const NVTETensor freqs, const NVTETensor start_positions, NVTETensor output, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream)
Apply rotary positional embedding to the input tensor.
- Parameters:
input – [in] Input tensor to which RoPE is applied.
cu_seqlens – [in] Cumulative sum of sequence lengths. Required for the thd format; pass an empty tensor for other formats.
freqs – [in] The freqs tensor.
start_positions – [in] Starting position offsets for applying the RoPE embeddings.
output – [out] Output tensor.
qkv_format – [in] QKV format of input (sbhd, bshd, or thd).
interleaved – [in] Whether to use interleaved rotary position embedding.
cp_size – [in] Context parallel world size.
cp_rank – [in] Context parallel rank.
s – [in] Length of the s dimension of input.
b – [in] Length of the b dimension of input.
h – [in] Length of the h dimension of input.
d – [in] Length of the d dimension of input.
d2 – [in] Length of the d dimension of freqs.
stride_s_or_t – [in] Stride of the s dimension (sbhd/bshd) or the t dimension (thd) of input.
stride_b – [in] Stride of the b dimension of input (0 for thd).
stride_h – [in] Stride of the h dimension of input.
stride_d – [in] Stride of the d dimension of input.
stream – [in] CUDA stream used for the operation.
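The forward rotation can be illustrated with a small CPU reference. This is a sketch under stated assumptions, not the Transformer Engine kernel: the name rope_forward_ref is hypothetical; freqs is assumed to be a contiguous [positions, d2] float array of angles whose cosine and sine are taken per channel; only the first d2 of the d channels are assumed to be rotated; start_positions is assumed to hold one offset per batch entry; output is assumed to share input's strides; and neither the thd format nor context parallelism (cp_size > 1) is handled.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical CPU reference for the forward RoPE on sbhd/bshd input.
// Assumptions (not stated in fused_rope.h): freqs is a contiguous [positions, d2]
// array of angles; channels d2..d-1 are copied unchanged; start_positions is null
// or holds one offset per batch entry; output uses input's strides and must not
// alias it; cp_size == 1.
void rope_forward_ref(const float* input, const float* freqs, const int* start_positions,
                      float* output, bool interleaved, int s, int b, int h, int d, int d2,
                      long stride_s, long stride_b, long stride_h, long stride_d) {
  assert(d2 % 2 == 0 && d2 <= d);
  for (int si = 0; si < s; ++si) {
    for (int bi = 0; bi < b; ++bi) {
      // Position of this token, shifted by the optional per-batch start offset.
      const int pos = si + (start_positions ? start_positions[bi] : 0);
      const float* angle = freqs + static_cast<long>(pos) * d2;
      for (int hi = 0; hi < h; ++hi) {
        const long base = si * stride_s + bi * stride_b + hi * stride_h;
        const float* x = input + base;
        float* y = output + base;
        for (int di = 0; di < d; ++di) {
          const float v = x[di * stride_d];
          if (di >= d2) {  // channels beyond d2 pass through unchanged
            y[di * stride_d] = v;
            continue;
          }
          // Rotation partner: pairs (2i, 2i+1) if interleaved, else (i, i + d2/2).
          float partner;
          if (interleaved)
            partner = (di % 2 == 0) ? -x[(di + 1) * stride_d] : x[(di - 1) * stride_d];
          else
            partner = (di < d2 / 2) ? -x[(di + d2 / 2) * stride_d]
                                    : x[(di - d2 / 2) * stride_d];
          y[di * stride_d] = v * std::cos(angle[di]) + partner * std::sin(angle[di]);
        }
      }
    }
  }
}
```

For a contiguous sbhd tensor the strides are stride_s_or_t = b*h*d, stride_b = h*d, stride_h = d, stride_d = 1; for a contiguous bshd tensor they are stride_s_or_t = h*d, stride_b = s*h*d, stride_h = d, stride_d = 1. Taking strides as arguments lets one loop read either layout, or a non-contiguous view, without a copy.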
-
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, const NVTETensor freqs, NVTETensor input_grads, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream)
Compute the backward pass of the fused RoPE.
- Parameters:
output_grads – [in] Incoming gradient with respect to the forward output.
cu_seqlens – [in] Cumulative sum of sequence lengths. Required for the thd format; pass an empty tensor for other formats.
freqs – [in] The freqs tensor.
input_grads – [out] Computed gradient with respect to the forward input.
qkv_format – [in] QKV format of output_grads (sbhd, bshd, or thd).
interleaved – [in] Whether to use interleaved rotary position embedding.
cp_size – [in] Context parallel world size.
cp_rank – [in] Context parallel rank.
s – [in] Length of the s dimension of output_grads.
b – [in] Length of the b dimension of output_grads.
h – [in] Length of the h dimension of output_grads.
d – [in] Length of the d dimension of output_grads.
d2 – [in] Length of the d dimension of freqs.
stride_s_or_t – [in] Stride of the s dimension (sbhd/bshd) or the t dimension (thd) of output_grads.
stride_b – [in] Stride of the b dimension of output_grads (0 for thd).
stride_h – [in] Stride of the h dimension of output_grads.
stride_d – [in] Stride of the d dimension of output_grads.
stream – [in] CUDA stream used for the operation.
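For the thd format, cu_seqlens marks where each packed sequence starts. The sketch below combines that lookup with the backward step. rope_backward_thd_ref and its num_seqs argument are hypothetical, and the sketch assumes cp_size == 1, that a token's position is its offset from the start of its sequence, that freqs is a contiguous [positions, d2] angle array, and that both channels of a rotated pair share one angle. Under those assumptions each forward pair rotation is orthogonal, so the gradient is the same rotation with the angle negated.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical CPU reference for the backward RoPE on a thd tensor.
// Assumptions (not stated in fused_rope.h): cp_size == 1; a token's position is
// t - cu_seqlens[seq]; freqs is a contiguous [positions, d2] array of angles;
// both channels of a rotated pair share one angle; channels d2..d-1 pass through.
// cu_seqlens has num_seqs + 1 entries.
void rope_backward_thd_ref(const float* grad_out, const int* cu_seqlens, int num_seqs,
                           const float* freqs, float* grad_in, bool interleaved,
                           int h, int d, int d2, long stride_t, long stride_h,
                           long stride_d) {
  assert(d2 % 2 == 0 && d2 <= d);
  for (int seq = 0; seq < num_seqs; ++seq) {
    for (int t = cu_seqlens[seq]; t < cu_seqlens[seq + 1]; ++t) {
      // Position restarts at 0 for every packed sequence.
      const float* angle = freqs + static_cast<long>(t - cu_seqlens[seq]) * d2;
      for (int hi = 0; hi < h; ++hi) {
        const float* g = grad_out + t * stride_t + hi * stride_h;
        float* gx = grad_in + t * stride_t + hi * stride_h;
        for (int di = d2; di < d; ++di) gx[di * stride_d] = g[di * stride_d];
        for (int i = 0; i < d2 / 2; ++i) {
          // Pair (a, p): (2i, 2i+1) if interleaved, else (i, i + d2/2).
          const int a = interleaved ? 2 * i : i;
          const int p = interleaved ? 2 * i + 1 : i + d2 / 2;
          const float c = std::cos(angle[a]), sn = std::sin(angle[a]);
          const float ga = g[a * stride_d], gp = g[p * stride_d];
          // Forward was ya = xa*c - xp*sn, yp = xp*c + xa*sn; apply its transpose.
          gx[a * stride_d] = ga * c + gp * sn;
          gx[p * stride_d] = -ga * sn + gp * c;
        }
      }
    }
  }
}
```

Because every token in thd is addressed through stride_s_or_t alone, there is no batch stride, which matches the documented stride_b of 0 for thd.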
-
void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, const NVTETensor k_freqs, const NVTETensor start_positions, NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, const int qkv_split_arg_list_2, cudaStream_t stream)
Apply rotary positional embedding to the combined QKV input tensor.
- Parameters:
qkv_input – [in] Combined QKV input tensor for fused rope.
q_freqs – [in] The freqs tensor for Q.
k_freqs – [in] The freqs tensor for K.
start_positions – [in] The beginning offsets for applying RoPE embeddings.
q_out – [out] Output tensor for Q.
k_out – [out] Output tensor for K.
v_out – [out] Output tensor for V.
qkv_format – [in] QKV format.
interleaved – [in] Whether to use interleaved rotary position embedding.
cp_size – [in] Context parallel world size.
cp_rank – [in] Context parallel rank.
s – [in] Length of the s dimension of input.
b – [in] Length of the b dimension of input.
h – [in] Length of the h dimension of input.
d – [in] Length of the d dimension of input.
d2 – [in] Length of the d dimension of freqs.
qkv_split_arg_list_0 – [in] The hidden size for Q.
qkv_split_arg_list_1 – [in] The hidden size for K.
qkv_split_arg_list_2 – [in] The hidden size for V.
stream – [in] CUDA stream used for the operation.
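A rough sketch of what the fused QKV variant computes, under a layout assumption the header does not spell out: each (s, b, h) row of qkv_input is taken to be [Q | K | V] with lengths qkv_split_arg_list_0, qkv_split_arg_list_1 and qkv_split_arg_list_2; Q and K are assumed to hold whole d-sized head vectors that are rotated with q_freqs and k_freqs respectively; and V is assumed to be copied unchanged. The function names, the contiguous sbhd layout, and the [positions, d2] angle layout of the freqs arrays are all hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Rotate one head vector x (length d) into y. Assumptions: only the first d2
// channels are rotated; pairs are (2i, 2i+1) if interleaved, else (i, i + d2/2);
// both channels of a pair share the angle stored at the pair's first index.
static void rotate_head(const float* x, const float* angle, float* y, bool interleaved,
                        int d, int d2) {
  for (int di = d2; di < d; ++di) y[di] = x[di];
  for (int i = 0; i < d2 / 2; ++i) {
    const int a = interleaved ? 2 * i : i;
    const int p = interleaved ? 2 * i + 1 : i + d2 / 2;
    const float c = std::cos(angle[a]), sn = std::sin(angle[a]);
    y[a] = x[a] * c - x[p] * sn;
    y[p] = x[p] * c + x[a] * sn;
  }
}

// Hypothetical reference for the fused QKV forward on contiguous sbhd data.
// Assumed layout: each (s, b, h) row of qkv is [Q | K | V] with lengths split0,
// split1, split2; start_positions is null or has one offset per batch entry;
// cp_size == 1.
void qkv_rope_forward_ref(const float* qkv, const float* q_freqs, const float* k_freqs,
                          const int* start_positions, float* q_out, float* k_out,
                          float* v_out, bool interleaved, int s, int b, int h, int d,
                          int d2, int split0, int split1, int split2) {
  assert(d2 % 2 == 0 && d2 <= d && split0 % d == 0 && split1 % d == 0);
  const int row_len = split0 + split1 + split2;
  for (int si = 0; si < s; ++si) {
    for (int bi = 0; bi < b; ++bi) {
      const long pos = si + (start_positions ? start_positions[bi] : 0);
      for (int hi = 0; hi < h; ++hi) {
        const long row = (static_cast<long>(si) * b + bi) * h + hi;
        const float* x = qkv + row * row_len;
        for (int j = 0; j < split0; j += d)  // rotate every Q head vector
          rotate_head(x + j, q_freqs + pos * d2, q_out + row * split0 + j,
                      interleaved, d, d2);
        for (int j = 0; j < split1; j += d)  // rotate every K head vector
          rotate_head(x + split0 + j, k_freqs + pos * d2, k_out + row * split1 + j,
                      interleaved, d, d2);
        for (int j = 0; j < split2; ++j)  // V is copied without rotation
          v_out[row * split2 + j] = x[split0 + split1 + j];
      }
    }
  }
}
```

In this sketch the split and the rotation happen in one pass over each row, so no intermediate Q, K and V tensors are materialized before rotating.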
-
void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out, const NVTETensor v_grad_out, const NVTETensor q_freqs, const NVTETensor k_freqs, NVTETensor qkv_grad_input, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, const int qkv_split_arg_list_2, cudaStream_t stream)
Compute the backward pass of the fused QKV RoPE.
- Parameters:
q_grad_out – [in] Incoming gradient tensor for Q.
k_grad_out – [in] Incoming gradient tensor for K.
v_grad_out – [in] Incoming gradient tensor for V.
q_freqs – [in] The freqs tensor for Q.
k_freqs – [in] The freqs tensor for K.
qkv_grad_input – [out] Computed gradient with respect to the combined QKV input.
qkv_format – [in] QKV format.
interleaved – [in] Whether to use interleaved rotary position embedding.
cp_size – [in] Context parallel world size.
cp_rank – [in] Context parallel rank.
s – [in] Length of the s dimension of the combined QKV input.
b – [in] Length of the b dimension of the combined QKV input.
h – [in] Length of the h dimension of the combined QKV input.
d – [in] Length of the d dimension of the combined QKV input.
d2 – [in] Length of the d dimension of freqs.
qkv_split_arg_list_0 – [in] The hidden size for Q.
qkv_split_arg_list_1 – [in] The hidden size for K.
qkv_split_arg_list_2 – [in] The hidden size for V.
stream – [in] CUDA stream used for the operation.
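The backward direction can be sketched by scattering the three incoming gradients into one [dQ | dK | dV] row, applying the transposed rotation (angle negated) to the Q and K parts and copying the V part. This is an illustration under assumptions, not the library's kernel: the function names, the contiguous sbhd layout, the [positions, d2] angle layout of q_freqs/k_freqs, and the [Q | K | V] row split by the three qkv_split_arg_list values are hypothetical. Because nvte_fused_qkv_rope_backward takes no start_positions argument, the sketch also assumes positions start at 0.

```cpp
#include <cassert>
#include <cmath>

// Gradient of one head rotation: applies the transposed rotation (angle -> -angle).
// Assumptions: pairs are (2i, 2i+1) if interleaved, else (i, i + d2/2); both
// channels of a pair share one angle; channels d2..d-1 pass through.
static void rotate_head_grad(const float* g, const float* angle, float* gx,
                             bool interleaved, int d, int d2) {
  for (int di = d2; di < d; ++di) gx[di] = g[di];
  for (int i = 0; i < d2 / 2; ++i) {
    const int a = interleaved ? 2 * i : i;
    const int p = interleaved ? 2 * i + 1 : i + d2 / 2;
    const float c = std::cos(angle[a]), sn = std::sin(angle[a]);
    gx[a] = g[a] * c + g[p] * sn;
    gx[p] = -g[a] * sn + g[p] * c;
  }
}

// Hypothetical reference for the fused QKV backward on contiguous sbhd data.
// Assumed layout: each (s, b, h) row of qkv_grad is [dQ | dK | dV] with lengths
// split0, split1, split2; cp_size == 1.
void qkv_rope_backward_ref(const float* q_grad, const float* k_grad, const float* v_grad,
                           const float* q_freqs, const float* k_freqs, float* qkv_grad,
                           bool interleaved, int s, int b, int h, int d, int d2,
                           int split0, int split1, int split2) {
  assert(d2 % 2 == 0 && d2 <= d && split0 % d == 0 && split1 % d == 0);
  const int row_len = split0 + split1 + split2;
  for (int si = 0; si < s; ++si) {
    const long pos = si;  // no start_positions argument: assume positions start at 0
    for (int bi = 0; bi < b; ++bi) {
      for (int hi = 0; hi < h; ++hi) {
        const long row = (static_cast<long>(si) * b + bi) * h + hi;
        float* gx = qkv_grad + row * row_len;
        for (int j = 0; j < split0; j += d)  // dQ: transposed rotation with q_freqs
          rotate_head_grad(q_grad + row * split0 + j, q_freqs + pos * d2, gx + j,
                           interleaved, d, d2);
        for (int j = 0; j < split1; j += d)  // dK: transposed rotation with k_freqs
          rotate_head_grad(k_grad + row * split1 + j, k_freqs + pos * d2,
                           gx + split0 + j, interleaved, d, d2);
        for (int j = 0; j < split2; ++j)  // dV: copied unchanged
          gx[split0 + split1 + j] = v_grad[row * split2 + j];
      }
    }
  }
}
```

Since each pair rotation is orthogonal under these assumptions, running this backward on the outputs of a matching forward with the same freqs recovers the original QKV row.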