fused_rope.h
Functions
-
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, const NVTETensor freqs, const NVTETensor start_positions, NVTETensor output, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream)
Apply rotary positional embedding to the input tensor.
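The per-channel rotation applied by the forward pass can be sketched on the CPU for one head vector at one position. This is a hypothetical reference, not the library's kernel. The name `rope_forward_ref` is invented, and the caller is assumed to pass precomputed `cos(freqs)` and `sin(freqs)` for the first `d2` channels of that position. The sketch also assumes the usual RoPE pairings: channel `i` pairs with `i + d2/2` when `interleaved` is false, and channel `2k` pairs with `2k + 1` when it is true. Channels `d2` to `d - 1` are copied through unchanged.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical CPU reference for one head vector at one position.
 * cos_f[i] and sin_f[i] are cos(freqs[pos][i]) and sin(freqs[pos][i]),
 * precomputed by the caller for i in [0, d2). x and y must not alias,
 * because each output reads its partner input element. */
static void rope_forward_ref(const float *x, const float *cos_f,
                             const float *sin_f, float *y, int d, int d2,
                             bool interleaved) {
  assert(d2 % 2 == 0 && d2 <= d);
  for (int i = 0; i < d2; ++i) {
    float rot;
    if (!interleaved) {
      /* Rotate-half pairing: channel i with channel i + d2/2. */
      rot = (i < d2 / 2) ? -x[i + d2 / 2] : x[i - d2 / 2];
    } else {
      /* Interleaved pairing: channel 2k with channel 2k + 1. */
      rot = (i % 2 == 0) ? -x[i + 1] : x[i - 1];
    }
    y[i] = x[i] * cos_f[i] + rot * sin_f[i];
  }
  for (int i = d2; i < d; ++i) {
    y[i] = x[i]; /* channels beyond d2 are not rotated */
  }
}
```

With every angle equal to π/2 (cos 0, sin 1), `d = 6` and `d2 = 4`, the input `{1, 2, 3, 4, 5, 6}` maps to `{-3, -4, 1, 2, 5, 6}` in the non-interleaved layout and to `{-2, 1, -4, 3, 5, 6}` in the interleaved one.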
- Parameters:
input – [in] Input tensor for fused RoPE.
cu_seqlens – [in] The cumulative sum of sequence lengths tensor (required for the thd format; an empty tensor for other formats).
freqs – [in] The freqs tensor.
start_positions – [in] The beginning offsets for applying RoPE.
output – [out] Output tensor.
qkv_format – [in] QKV layout format of input (sbhd, bshd, or thd).
interleaved – [in] Whether to use interleaved rotary position embedding.
cp_size – [in] Context parallel world size.
cp_rank – [in] Context parallel rank.
s – [in] Length of the s dimension of input.
b – [in] Length of the b dimension of input.
h – [in] Length of the h dimension of input.
d – [in] Length of the d dimension of input.
d2 – [in] Length of the d dimension of freqs.
stride_s_or_t – [in] Stride of the s dimension (sbhd/bshd) or t dimension (thd) of input.
stride_b – [in] Stride of the b dimension of input (0 for thd).
stride_h – [in] Stride of the h dimension of input.
stride_d – [in] Stride of the d dimension of input.
stream – [in] CUDA stream used for the operation.
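The four stride arguments describe how input is laid out in memory, so presumably a non-contiguous view can be passed as long as its strides are supplied. For a contiguous tensor, the strides follow directly from the shape. The sketch below derives them. The `Layout` enum, the `RopeStrides` struct and the helper names are hypothetical. Strides are assumed to be counted in elements. The `thd_locate` helper assumes that `cu_seqlens` holds `num_seqs + 1` prefix sums starting at 0. It ignores context parallelism (`cp_size == 1`) and `start_positions`.

```c
#include <assert.h>

/* Hypothetical layout tags; the real API takes an NVTE_QKV_Format. */
typedef enum { LAYOUT_SBHD, LAYOUT_BSHD, LAYOUT_THD } Layout;

typedef struct {
  int stride_s_or_t, stride_b, stride_h, stride_d;
} RopeStrides;

/* Strides (in elements) of a contiguous tensor in the given layout.
 * For THD, s and b are ignored: the leading dimension is the packed
 * token dimension t and there is no batch dimension. */
static RopeStrides contiguous_strides(Layout layout, int s, int b, int h,
                                      int d) {
  assert(h > 0 && d > 0);
  RopeStrides st;
  st.stride_d = 1;
  st.stride_h = d;
  switch (layout) {
    case LAYOUT_SBHD: /* shape [s, b, h, d] */
      st.stride_s_or_t = b * h * d;
      st.stride_b = h * d;
      break;
    case LAYOUT_BSHD: /* shape [b, s, h, d] */
      st.stride_s_or_t = h * d;
      st.stride_b = s * h * d;
      break;
    default: /* LAYOUT_THD, shape [t, h, d] */
      st.stride_s_or_t = h * d;
      st.stride_b = 0;
      break;
  }
  return st;
}

/* Element offset of index (s or t, b, h, d) under the given strides. */
static long element_offset(RopeStrides st, int si, int bi, int hi, int di) {
  return (long)si * st.stride_s_or_t + (long)bi * st.stride_b +
         (long)hi * st.stride_h + (long)di * st.stride_d;
}

/* THD only: return the sequence that packed token t belongs to and write
 * its position within that sequence to *pos. Returns -1 if t is out of
 * range (and leaves *pos untouched). */
static int thd_locate(const int *cu_seqlens, int num_seqs, int t, int *pos) {
  for (int i = 0; i < num_seqs; ++i) {
    if (t >= cu_seqlens[i] && t < cu_seqlens[i + 1]) {
      *pos = t - cu_seqlens[i];
      return i;
    }
  }
  return -1;
}
```

For example, a contiguous sbhd tensor with `s = 8`, `b = 2`, `h = 4`, `d = 16` has strides `(128, 64, 16, 1)`. The same heads packed as thd have strides `(64, 0, 16, 1)`, which matches the "0 for thd" rule for `stride_b`.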
-
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, const NVTETensor freqs, NVTETensor input_grads, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream)
Compute the backward pass of the fused RoPE.
- Parameters:
output_grads – [in] Incoming gradient tensor, i.e. the gradient with respect to the forward output.
cu_seqlens – [in] The cumulative sum of sequence lengths tensor (required for the thd format; an empty tensor for other formats).
freqs – [in] The freqs tensor.
input_grads – [out] Gradient with respect to the forward input, computed by this function.
qkv_format – [in] QKV layout format of output_grads (sbhd, bshd, or thd).
interleaved – [in] Whether to use interleaved rotary position embedding.
cp_size – [in] Context parallel world size.
cp_rank – [in] Context parallel rank.
s – [in] Length of the s dimension of output_grads.
b – [in] Length of the b dimension of output_grads.
h – [in] Length of the h dimension of output_grads.
d – [in] Length of the d dimension of output_grads.
d2 – [in] Length of the d dimension of freqs.
stride_s_or_t – [in] Stride of the s dimension (sbhd/bshd) or t dimension (thd) of output_grads.
stride_b – [in] Stride of the b dimension of output_grads (0 for thd).
stride_h – [in] Stride of the h dimension of output_grads.
stride_d – [in] Stride of the d dimension of output_grads.
stream – [in] CUDA stream used for the operation.
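Because RoPE is linear in its input, the backward pass is the transpose of the forward map. If the forward is y = cos ⊙ x + sin ⊙ R(x), where R is the fixed channel-pairing map, the input gradient is dx = cos ⊙ dy + Rᵀ(sin ⊙ dy). The CPU sketch below spells this out for one head vector at one position. It is hypothetical: the name `rope_backward_ref` is invented, and the caller is assumed to pass precomputed cos/sin of `freqs`. The sketch assumes the rotate-half or interleaved pairings, with channels from `d2` onward passing through unchanged. Note that the sine factor comes from the partner channel. When both channels of a pair share the same angle, this reduces to rotating dy by the negated angle.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical CPU reference for the backward of one head vector.
 * Computes dx = cos .* dy + R^T(sin .* dy), the exact transpose of
 * y = cos .* x + sin .* R(x). dy and dx must not alias. */
static void rope_backward_ref(const float *dy, const float *cos_f,
                              const float *sin_f, float *dx, int d, int d2,
                              bool interleaved) {
  assert(d2 % 2 == 0 && d2 <= d);
  for (int i = 0; i < d2; ++i) {
    float rot;
    if (!interleaved) {
      /* R^T sends the gradient of channel i + d2/2 back to channel i
       * (sign +), and that of channel i - d2/2 back with sign -. */
      rot = (i < d2 / 2) ? sin_f[i + d2 / 2] * dy[i + d2 / 2]
                         : -sin_f[i - d2 / 2] * dy[i - d2 / 2];
    } else {
      /* Same idea for the interleaved pairing (2k, 2k + 1). */
      rot = (i % 2 == 0) ? sin_f[i + 1] * dy[i + 1]
                         : -sin_f[i - 1] * dy[i - 1];
    }
    dx[i] = dy[i] * cos_f[i] + rot;
  }
  for (int i = d2; i < d; ++i) {
    dx[i] = dy[i]; /* pass-through channels have identity gradient */
  }
}
```

Consider all angles equal to π/2, with `d = 5`, `d2 = 4` and `dy = {1, 2, 3, 4, 5}`. In the non-interleaved layout, the sketch gives `dx = {3, 4, -1, -2, 5}`. This is exactly the input that a π/2 forward rotation maps to `dy`.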