fused_rope.h

Functions

void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, const NVTETensor freqs, const NVTETensor start_positions, NVTETensor output, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream)

Apply rotary positional embedding to the input tensor.

Parameters:
  • input[in] Input tensor to which RoPE is applied.

  • cu_seqlens[in] Tensor holding the cumulative sum of sequence lengths. Required for the thd format; pass an empty tensor for other formats.

  • freqs[in] Tensor of RoPE frequencies; the length of its d dimension is d2.

  • start_positions[in] Starting position offsets, one per sequence, from which RoPE positions are counted.

  • output[out] Output tensor.

  • qkv_format[in] QKV layout format of input, e.g. sbhd, bshd, or thd.

  • interleaved[in] Whether to use interleaved rotary position embedding.

  • cp_size[in] Context parallel world size.

  • cp_rank[in] Context parallel rank.

  • s[in] Length of the s dimension of input.

  • b[in] Length of the b dimension of input.

  • h[in] Length of the h dimension of input.

  • d[in] Length of the d dimension of input.

  • d2[in] Length of the d dimension of freqs.

  • stride_s_or_t[in] Stride of the s dimension (sbhd/bshd) or the t dimension (thd) of input.

  • stride_b[in] Stride of the b dimension of input (0 for thd).

  • stride_h[in] Stride of the h dimension of input.

  • stride_d[in] Stride of the d dimension of input.

  • stream[in] CUDA stream used for the operation.

void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, const NVTETensor freqs, NVTETensor input_grads, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream)

Compute the backward pass of the fused RoPE.

Parameters:
  • output_grads[in] Incoming gradient with respect to the RoPE output.

  • cu_seqlens[in] Tensor holding the cumulative sum of sequence lengths. Required for the thd format; pass an empty tensor for other formats.

  • freqs[in] Tensor of RoPE frequencies; the length of its d dimension is d2.

  • input_grads[out] Gradient with respect to the RoPE input, computed by this function.

  • qkv_format[in] QKV layout format of output_grads, e.g. sbhd, bshd, or thd.

  • interleaved[in] Whether to use interleaved rotary position embedding.

  • cp_size[in] Context parallel world size.

  • cp_rank[in] Context parallel rank.

  • s[in] Length of the s dimension of output_grads.

  • b[in] Length of the b dimension of output_grads.

  • h[in] Length of the h dimension of output_grads.

  • d[in] Length of the d dimension of output_grads.

  • d2[in] Length of the d dimension of freqs.

  • stride_s_or_t[in] Stride of the s dimension (sbhd/bshd) or the t dimension (thd) of output_grads.

  • stride_b[in] Stride of the b dimension of output_grads (0 for thd).

  • stride_h[in] Stride of the h dimension of output_grads.

  • stride_d[in] Stride of the d dimension of output_grads.

  • stream[in] CUDA stream used for the operation.

void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, const NVTETensor k_freqs, const NVTETensor start_positions, NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, const int qkv_split_arg_list_2, cudaStream_t stream)

Apply rotary positional embedding to the combined QKV input tensor.

Parameters:
  • qkv_input[in] Combined QKV input tensor.

  • q_freqs[in] Tensor of RoPE frequencies for Q.

  • k_freqs[in] Tensor of RoPE frequencies for K.

  • start_positions[in] Starting position offsets, one per sequence, from which RoPE positions are counted.

  • q_out[out] Output tensor for Q.

  • k_out[out] Output tensor for K.

  • v_out[out] Output tensor for V.

  • qkv_format[in] QKV layout format of qkv_input.

  • interleaved[in] Whether to use interleaved rotary position embedding.

  • cp_size[in] Context parallel world size.

  • cp_rank[in] Context parallel rank.

  • s[in] Length of the s dimension of qkv_input.

  • b[in] Length of the b dimension of qkv_input.

  • h[in] Length of the h dimension of qkv_input.

  • d[in] Length of the d dimension of qkv_input.

  • d2[in] Length of the d dimension of freqs.

  • qkv_split_arg_list_0[in] The hidden size for Q.

  • qkv_split_arg_list_1[in] The hidden size for K.

  • qkv_split_arg_list_2[in] The hidden size for V.

  • stream[in] CUDA stream used for the operation.
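How the three hidden sizes relate to d is not spelled out above. The sketch below assumes (this is an assumption, not documented here) that each (s, b, h) row of a contiguous sbhd qkv_input stores Q, K and V back to back along its last dimension, so that qkv_split_arg_list_0 + qkv_split_arg_list_1 + qkv_split_arg_list_2 == d. The helper names and the QkvSlices struct are hypothetical.

```c
#include <assert.h>

/* Hypothetical helpers (not part of Transformer Engine). ASSUMPTION: the last
 * dimension of each contiguous sbhd row holds Q, then K, then V, with widths
 * equal to the three qkv_split_arg_list values, which sum to d. */
typedef struct {
  int q_begin, k_begin, v_begin, end; /* element offsets into qkv_input */
} QkvSlices;

/* Offset of row (si, bi, hi) in a contiguous [s, b, h, d] tensor. */
int sbhd_row_offset(int si, int bi, int hi, int b, int h, int d) {
  return ((si * b + bi) * h + hi) * d;
}

QkvSlices qkv_row_slices(int row_offset, int d, int split_q, int split_k,
                         int split_v) {
  assert(split_q > 0 && split_k > 0 && split_v >= 0);
  assert(split_q + split_k + split_v == d);
  QkvSlices r;
  r.q_begin = row_offset;          /* Q occupies [q_begin, k_begin) */
  r.k_begin = row_offset + split_q; /* K occupies [k_begin, v_begin) */
  r.v_begin = r.k_begin + split_k;  /* V occupies [v_begin, end) */
  r.end = r.v_begin + split_v;
  return r;
}
```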

void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out, const NVTETensor v_grad_out, const NVTETensor q_freqs, const NVTETensor k_freqs, NVTETensor qkv_grad_input, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, const int qkv_split_arg_list_2, cudaStream_t stream)

Compute the backward pass of the fused QKV RoPE.

Parameters:
  • q_grad_out[in] Incoming gradient tensor for Q.

  • k_grad_out[in] Incoming gradient tensor for K.

  • v_grad_out[in] Incoming gradient tensor for V.

  • q_freqs[in] Tensor of RoPE frequencies for Q.

  • k_freqs[in] Tensor of RoPE frequencies for K.

  • qkv_grad_input[out] Gradient with respect to the combined QKV input, computed by this function.

  • qkv_format[in] QKV layout format of qkv_grad_input.

  • interleaved[in] Whether to use interleaved rotary position embedding.

  • cp_size[in] Context parallel world size.

  • cp_rank[in] Context parallel rank.

  • s[in] Length of the s dimension of qkv_grad_input.

  • b[in] Length of the b dimension of qkv_grad_input.

  • h[in] Length of the h dimension of qkv_grad_input.

  • d[in] Length of the d dimension of qkv_grad_input.

  • d2[in] Length of the d dimension of freqs.

  • qkv_split_arg_list_0[in] The hidden size for Q.

  • qkv_split_arg_list_1[in] The hidden size for K.

  • qkv_split_arg_list_2[in] The hidden size for V.

  • stream[in] CUDA stream used for the operation.