fused_rope.h

Functions

void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, const NVTETensor freqs, const NVTETensor start_positions, NVTETensor output, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream)

Apply rotary positional embedding (RoPE) to the input tensor.
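To make the operation concrete, here is a minimal CPU sketch of the rotation for a single head vector at one position. It is not the library's kernel, and the function name rope_forward_ref is hypothetical. It assumes that one row of freqs holds the d2 rotation angles for that position, that only the first d2 features are rotated while the remaining d - d2 are copied through, and that the non-interleaved mode pairs feature i with feature i + d2/2 ("rotate half") whereas the interleaved mode pairs features 2k and 2k + 1. For a proper rotation, both angles of a pair should be equal.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical CPU reference, not the library kernel. Rotates one head vector
// x (length d) at a single position. Assumptions: `angles` is that position's
// row of freqs (length d2), only x[0..d2) is rotated, and x[d2..d) is copied.
std::vector<float> rope_forward_ref(const std::vector<float>& x,
                                    const std::vector<float>& angles,
                                    bool interleaved) {
  const int d = static_cast<int>(x.size());
  const int d2 = static_cast<int>(angles.size());
  assert(d2 % 2 == 0 && d2 <= d);
  std::vector<float> out(x);  // features d2..d-1 pass through unchanged
  for (int i = 0; i < d2; ++i) {
    // Partner feature of x[i], negated when x[i] is the first of its pair.
    float partner;
    if (interleaved) {
      // Adjacent pairs: (x[0], x[1]), (x[2], x[3]), ...
      partner = (i % 2 == 0) ? -x[i + 1] : x[i - 1];
    } else {
      // "Rotate half": pairs (x[i], x[i + d2/2]) for i < d2/2.
      partner = (i < d2 / 2) ? -x[i + d2 / 2] : x[i - d2 / 2];
    }
    out[i] = x[i] * std::cos(angles[i]) + partner * std::sin(angles[i]);
  }
  return out;
}
```

With both angles of the first pair set to π/2 (non-interleaved, d2 = 2), the input (1, 2, 3) becomes approximately (-2, 1, 3): the pair is rotated by 90° and the third feature passes through.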

Parameters:
  • input[in] Input tensor to which RoPE is applied.

  • cu_seqlens[in] Cumulative sum of the sequence lengths. Required for the thd format; pass an empty tensor for other formats.

  • freqs[in] Tensor of RoPE frequencies (per-position rotation angles), whose last dimension has length d2.

  • start_positions[in] The beginning offsets (starting positions) for applying RoPE.

  • output[out] Output tensor.

  • qkv_format[in] QKV format (layout) of input, e.g. sbhd, bshd, or thd.

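  • interleaved[in] Whether to use interleaved rotary position embedding (rotating adjacent feature pairs) rather than the non-interleaved form (rotating the first half of the rotated features against the second half).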

  • cp_size[in] Context parallel world size.

  • cp_rank[in] Context parallel rank.

  • s[in] Length of the s dimension of input.

  • b[in] Length of the b dimension of input.

  • h[in] Length of the h dimension of input.

  • d[in] Length of the d dimension of input.

  • d2[in] Length of the d dimension of freqs, i.e. the number of rotated features (at most d).

  • stride_s_or_t[in] Stride of the s dimension (sbhd/bshd) or the t dimension (thd) of input.

  • stride_b[in] Stride of the b dimension of input (0 for thd).

  • stride_h[in] Stride of the h dimension of input.

  • stride_d[in] Stride of the d dimension of input.

  • stream[in] CUDA stream used for the operation.
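The cu_seqlens, start_positions, cp_size, and cp_rank parameters all affect which row of freqs a token uses. The sketch below shows one plausible reading under stated assumptions; locate_token and freqs_row are hypothetical helpers, not part of the API. It assumes that cu_seqlens starts at 0 and has one more entry than there are sequences, that without context parallelism a sequence's start offset is simply added to each token's position, and that with context parallelism each sequence is split into 2 * cp_size equal chunks, with rank r holding chunks r and 2 * cp_size - 1 - r (a common load-balanced layout).

```cpp
#include <cassert>
#include <vector>

// Hypothetical helpers for illustration only; not part of the NVTE API.

struct TokenPos {
  int seq;  // index of the sequence the token belongs to
  int pos;  // position of the token within that sequence
};

// Locate token t of a packed thd tensor. Assumes cu_seqlens starts at 0, has
// one more entry than there are sequences, and there is no context parallelism.
TokenPos locate_token(const std::vector<int>& cu_seqlens, int t) {
  assert(cu_seqlens.size() >= 2 && cu_seqlens.front() == 0);
  assert(t >= 0 && t < cu_seqlens.back());
  int seq = 0;
  while (cu_seqlens[seq + 1] <= t) ++seq;  // linear scan is enough for a sketch
  return {seq, t - cu_seqlens[seq]};
}

// Row of freqs used for a token at local position local_pos in a sequence of
// local (per-rank) length local_len. Assumptions: with cp_size > 1 the full
// sequence is split into 2 * cp_size equal chunks and rank r holds chunks r
// and 2 * cp_size - 1 - r; with cp_size == 1, start_offset is simply added.
int freqs_row(int local_pos, int local_len, int cp_size, int cp_rank,
              int start_offset = 0) {
  if (cp_size <= 1) return local_pos + start_offset;
  assert(local_len % 2 == 0);
  const int chunk = local_len / 2;
  if (local_pos < chunk) return cp_rank * chunk + local_pos;  // chunk r
  return (2 * cp_size - 1 - cp_rank) * chunk + (local_pos - chunk);  // mirrored chunk
}
```

For example, with cu_seqlens = {0, 3, 8, 10}, token 4 is position 1 of sequence 1. With cp_size = 2 and an 8-token sequence, rank 0 would hold positions 0, 1, 6, 7 and rank 1 positions 2 to 5, so both ranks do a similar amount of causal-attention work.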

void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, const NVTETensor freqs, NVTETensor input_grads, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream)

Compute the backward pass of fused RoPE.
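Because the rotation is linear in the input, its backward is the transposed rotation: when both angles of a pair are equal, each pair of gradients is rotated by the negated angle. The following minimal CPU sketch uses the hypothetical name rope_backward_ref and is not the library's kernel. It assumes d2 angles per position, that only the first d2 features are rotated (so the remaining features pass the gradient through unchanged), and that the non-interleaved mode pairs feature i with i + d2/2 while the interleaved mode pairs features 2k and 2k + 1.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical CPU reference for the RoPE backward, not the library kernel.
// g is the gradient w.r.t. the rotated vector (length d); angles has length d2.
// Returns the gradient w.r.t. the input, i.e. the transpose of the rotation.
std::vector<float> rope_backward_ref(const std::vector<float>& g,
                                     const std::vector<float>& angles,
                                     bool interleaved) {
  const int d = static_cast<int>(g.size());
  const int d2 = static_cast<int>(angles.size());
  assert(d2 % 2 == 0 && d2 <= d);
  std::vector<float> grad_in(g);  // features d2..d-1: gradient passes through
  for (int i = 0; i < d2; ++i) {
    int j;       // index of the partner feature in the pair
    float sign;  // +1 for the first element of a pair, -1 for the second
    if (interleaved) {
      j = (i % 2 == 0) ? i + 1 : i - 1;
      sign = (i % 2 == 0) ? 1.0f : -1.0f;
    } else {
      j = (i < d2 / 2) ? i + d2 / 2 : i - d2 / 2;
      sign = (i < d2 / 2) ? 1.0f : -1.0f;
    }
    // The partner's contribution uses the partner's angle (transposed matrix).
    grad_in[i] = g[i] * std::cos(angles[i]) + sign * g[j] * std::sin(angles[j]);
  }
  return grad_in;
}
```

For instance, with both angles of the first pair set to π/2 (non-interleaved, d2 = 2), an incoming gradient (1, 0, 5) maps back to approximately (0, -1, 5).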

Parameters:
  • output_grads[in] Incoming gradient with respect to the output of the forward pass.

  • cu_seqlens[in] Cumulative sum of the sequence lengths. Required for the thd format; pass an empty tensor for other formats.

  • freqs[in] Tensor of RoPE frequencies (per-position rotation angles), whose last dimension has length d2.

  • input_grads[out] Computed gradient with respect to the input of the forward pass.

  • qkv_format[in] QKV format (layout) of output_grads, e.g. sbhd, bshd, or thd.

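  • interleaved[in] Whether to use interleaved rotary position embedding (rotating adjacent feature pairs) rather than the non-interleaved form (rotating the first half of the rotated features against the second half).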

  • cp_size[in] Context parallel world size.

  • cp_rank[in] Context parallel rank.

  • s[in] Length of the s dimension of output_grads.

  • b[in] Length of the b dimension of output_grads.

  • h[in] Length of the h dimension of output_grads.

  • d[in] Length of the d dimension of output_grads.

  • d2[in] Length of the d dimension of freqs, i.e. the number of rotated features (at most d).

  • stride_s_or_t[in] Stride of the s dimension (sbhd/bshd) or the t dimension (thd) of output_grads.

  • stride_b[in] Stride of the b dimension of output_grads (0 for thd).

  • stride_h[in] Stride of the h dimension of output_grads.

  • stride_d[in] Stride of the d dimension of output_grads.

  • stream[in] CUDA stream used for the operation.
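The stride arguments are element counts, one per logical dimension, which presumably lets callers pass non-contiguous views. For the densely packed case the strides follow directly from the shape, as in this sketch. Layout, RopeStrides, contiguous_strides, and element_offset are hypothetical names, not part of the API; the real functions take an NVTE_QKV_Format.

```cpp
#include <cassert>

// Hypothetical names for illustration; the real API takes an NVTE_QKV_Format.
enum class Layout { SBHD, BSHD, THD };

// Strides, in elements, in the order the header lists them.
struct RopeStrides {
  int s_or_t, b, h, d;
};

// Strides of a densely packed (contiguous) tensor with the given shape.
// For THD, the t dimension plays the role of s, and s and b do not matter.
RopeStrides contiguous_strides(Layout layout, int s, int b, int h, int d) {
  switch (layout) {
    case Layout::SBHD: return {b * h * d, h * d, d, 1};
    case Layout::BSHD: return {h * d, s * h * d, d, 1};
    case Layout::THD:  return {h * d, 0, d, 1};  // no b dimension: stride_b = 0
  }
  assert(false && "unknown layout");
  return {0, 0, 0, 0};
}

// Element offset of index (s_or_t, b, h, d) under the given strides.
long long element_offset(const RopeStrides& st, int s_or_t, int b, int h, int d) {
  return 1LL * s_or_t * st.s_or_t + 1LL * b * st.b + 1LL * h * st.h + 1LL * d * st.d;
}
```

With s = 4, b = 2, h = 3, d = 8, an sbhd tensor has strides (48, 24, 8, 1) and a bshd tensor (24, 96, 8, 1); for thd, stride_b is 0, matching the parameter description.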