core.fusions.fused_mla_yarn_rope_apply#
Module Contents#
Classes#

| Class | Description |
|---|---|
| `ApplyMLARotaryEmbQ` | Autograd function for applying YARN RoPE to MLA’s query. |
| `ApplyMLARotaryEmbKV` | Autograd function for applying YARN RoPE to MLA’s key and value. |
Functions#

| Function | Description |
|---|---|
| `rotary_fwd_q_kernel` | Triton kernel of the forward pass for applying YARN RoPE to MLA’s query. This kernel modifies the input tensor Q in place. |
| `rotary_bwd_q_kernel` | Triton kernel of the backward pass for applying YARN RoPE to MLA’s query. This kernel modifies the input tensor DO in place. |
| `fused_apply_mla_rope_for_q` | Fused function for applying YARN RoPE to MLA’s query. This function modifies the input tensor t in place: along the last dimension, RoPE is applied to the last emb_dim elements, while the first qk_head_dim elements are left unmodified. This is an experimental feature and may change in future versions. It supports both sbhd and thd input formats. |
| `rotary_fwd_kv_kernel` | Triton kernel of the forward pass for applying YARN RoPE to MLA’s key and value. It splits the input tensor KV into key and value, and concatenates the processed RoPE to the key. |
| `rotary_bwd_kv_kernel` | Triton kernel of the backward pass for applying YARN RoPE to MLA’s key and value. |
| `fused_apply_mla_rope_for_kv` | Fused function for applying YARN RoPE to MLA’s key and value. It splits the input tensor kv into key and value, and concatenates the processed RoPE to the key. |
API#
- core.fusions.fused_mla_yarn_rope_apply._get_thd_token_idx(cu_seqlens, pid_m, seq_num, cp_rank, cp_size)#
- core.fusions.fused_mla_yarn_rope_apply.rotary_fwd_q_kernel(
- Q,
- COS,
- SIN,
- qk_head_dim,
- emb_dim: triton.language.constexpr,
- head_num: triton.language.constexpr,
- batch_size,
- seq_num,
- cu_seqlens_q,
- stride_x_seq,
- stride_x_nheads,
- cp_rank,
- cp_size,
- BLOCK_H: triton.language.constexpr,
- )#
Triton kernel of the forward pass for applying YARN RoPE to MLA’s query. This kernel modifies the input tensor Q in place.

Input:
- Q: [seq_len, batch_size, head_num, qk_head_dim + emb_dim] or [total_seq_len, head_num, qk_head_dim + emb_dim]
- COS/SIN: [max_seq_len, emb_dim]
- batch_size: batch size for the sbhd format; not used for the thd format
- seq_num: number of sequences for the thd format; not used for the sbhd format
- cu_seqlens_q: [seq_num + 1] accumulated sequence lengths for the thd format
- core.fusions.fused_mla_yarn_rope_apply.rotary_bwd_q_kernel(
- DO,
- COS,
- SIN,
- qk_head_dim,
- emb_dim: triton.language.constexpr,
- head_num: triton.language.constexpr,
- batch_size,
- seq_num,
- cu_seqlens_q,
- stride_x_seq,
- stride_x_nheads,
- cp_rank,
- cp_size,
- BLOCK_H: triton.language.constexpr,
- )#
Triton kernel of the backward pass for applying YARN RoPE to MLA’s query. This kernel modifies the input tensor DO in place.

Input:
- DO: [seq_len, batch_size, head_num, qk_head_dim + emb_dim] or [total_seq_len, head_num, qk_head_dim + emb_dim]
- COS/SIN: [max_seq_len, emb_dim]
- batch_size, seq_num, and cu_seqlens_q are the same as in the forward pass
- class core.fusions.fused_mla_yarn_rope_apply.ApplyMLARotaryEmbQ#
Bases: torch.autograd.Function

Autograd function for applying YARN RoPE to MLA’s query.
- static forward(
- ctx,
- q,
- cos,
- sin,
- qk_head_dim,
- emb_dim,
- cu_seqlens_q,
- cp_rank,
- cp_size,
- rotary_interleaved=False,
- )#
Forward function for ApplyMLARotaryEmbQ.
- Parameters:
q – [seq_len, batch_size, head_num, qk_head_dim + emb_dim] or [total_seq_len, head_num, qk_head_dim + emb_dim]
cos/sin – [max_seq_len, 1, 1, emb_dim]
cu_seqlens_q – [seq_num + 1] accumulated sequence lengths for thd format
rotary_interleaved – whether to apply interleaved RoPE; only False is supported for now
- static backward(ctx, grad)#
Backward function for ApplyMLARotaryEmbQ.
- Parameters:
grad – [seq_len, batch_size, head_num, qk_head_dim + emb_dim] or [total_seq_len, head_num, qk_head_dim + emb_dim]
- core.fusions.fused_mla_yarn_rope_apply.fused_apply_mla_rope_for_q(
- t: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- qk_head_dim: int,
- emb_dim: int,
- cu_seqlens_q: Optional[torch.Tensor] = None,
- cp_rank: int = 0,
- cp_size: int = 1,
- rotary_interleaved: bool = False,
- )#
Fused function for applying YARN RoPE to MLA’s query. This function modifies the input tensor t in place: along the last dimension of t, RoPE is applied to the last emb_dim elements, while the first qk_head_dim elements are left unmodified. This is an experimental feature and may change in future versions. It supports both sbhd and thd input formats.
In the notation below, seq_len is the per-batch sequence length for the sbhd format, total_seq_len is the total length of all sequences for the thd format, and max_seq_len is the maximum sequence length in the input tensor.
- Parameters:
t – [seq_len, batch_size, head_num, qk_head_dim + emb_dim] or [total_seq_len, head_num, qk_head_dim + emb_dim]
cos/sin – [max_seq_len, 1, 1, emb_dim]
cu_seqlens_q – [seq_num + 1] accumulated sequence lengths for thd format
rotary_interleaved – whether to apply interleaved RoPE; only False is supported for now
- Returns:
t – the input tensor, modified in place
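The semantics above can be sketched with an unfused PyTorch reference. This is a sketch only: the function name is illustrative (not part of this module), it covers the sbhd path without context parallelism, and it assumes the non-interleaved rotate-half RoPE convention (consistent with rotary_interleaved supporting only False), returning a new tensor rather than writing into t in place.

```python
import torch

def unfused_mla_rope_q(t, cos, sin, qk_head_dim, emb_dim):
    """Unfused sketch of fused_apply_mla_rope_for_q (sbhd format).

    t:       [seq_len, batch_size, head_num, qk_head_dim + emb_dim]
    cos/sin: [max_seq_len, 1, 1, emb_dim]

    The first qk_head_dim channels pass through untouched; the last
    emb_dim channels get RoPE. The rotate-half pairing below is an
    assumed convention.
    """
    seq_len = t.shape[0]
    passthrough, x = t[..., :qk_head_dim], t[..., qk_head_dim:]
    c = cos[:seq_len].to(t.dtype)
    s = sin[:seq_len].to(t.dtype)
    half = emb_dim // 2
    # rotate_half: (x1, x2) -> (-x2, x1)
    rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
    return torch.cat((passthrough, x * c + rotated * s), dim=-1)
```

Unlike the fused kernel, this sketch allocates output tensors; the fused version exists precisely to avoid those extra reads and writes.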
- core.fusions.fused_mla_yarn_rope_apply.rotary_fwd_kv_kernel(
- KV,
- K_POS_EMB,
- O_KEY,
- O_VALUE,
- COS,
- SIN,
- emb_dim: triton.language.constexpr,
- k_dim: triton.language.constexpr,
- v_dim: triton.language.constexpr,
- head_num: triton.language.constexpr,
- batch_size,
- seq_num,
- cu_seqlens_kv,
- stride_kv_seq,
- stride_kv_nheads,
- stride_emb_seq,
- stride_k_seq,
- stride_k_nheads,
- stride_v_seq,
- stride_v_nheads,
- cp_rank,
- cp_size,
- BLOCK_H: triton.language.constexpr,
- )#
Triton kernel of the forward pass for applying YARN RoPE to MLA’s key and value. It splits the input tensor KV into key and value, and concatenates the processed RoPE to the key.

Input:
- KV: [seq_len, batch_size, head_num, k_dim + v_dim] or [total_seq_len, head_num, k_dim + v_dim]
- K_POS_EMB: [seq_len, batch_size, emb_dim] or [total_seq_len, emb_dim]
- COS/SIN: [max_seq_len, emb_dim]
- batch_size: batch size for the sbhd format; not used for the thd format
- seq_num: number of sequences for the thd format; not used for the sbhd format
- cu_seqlens_kv: [seq_num + 1] accumulated sequence lengths for the thd format

Output:
- O_KEY: [seq_len, batch_size, head_num, emb_dim + k_dim] or [total_seq_len, head_num, emb_dim + k_dim]
- O_VALUE: [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]
- core.fusions.fused_mla_yarn_rope_apply.rotary_bwd_kv_kernel(
- dK,
- dV,
- dKV,
- dEMB,
- COS,
- SIN,
- emb_dim: triton.language.constexpr,
- k_dim: triton.language.constexpr,
- v_dim: triton.language.constexpr,
- head_num: triton.language.constexpr,
- batch_size,
- seq_num,
- cu_seqlens_kv,
- stride_dk_seq,
- stride_dk_nheads,
- stride_dv_seq,
- stride_dv_nheads,
- stride_dkv_seq,
- stride_dkv_nheads,
- stride_demb_seq,
- cp_rank,
- cp_size,
- BLOCK_H: triton.language.constexpr,
- )#
Triton kernel of the backward pass for applying YARN RoPE to MLA’s key and value.

Input:
- dK: [seq_len, batch_size, head_num, emb_dim + k_dim] or [total_seq_len, head_num, emb_dim + k_dim]
- dV: [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]
- COS/SIN: [max_seq_len, emb_dim]
- batch_size, seq_num, and cu_seqlens_kv are the same as in the forward pass

Output:
- dKV: [seq_len, batch_size, head_num, k_dim + v_dim] or [total_seq_len, head_num, k_dim + v_dim]
- dEMB: [seq_len, batch_size, emb_dim] or [total_seq_len, emb_dim]
- class core.fusions.fused_mla_yarn_rope_apply.ApplyMLARotaryEmbKV#
Bases: torch.autograd.Function

Autograd function for applying YARN RoPE to MLA’s key and value.
- static forward(
- ctx,
- kv,
- k_pos_emb,
- cos,
- sin,
- emb_dim,
- k_dim,
- v_dim,
- cu_seqlens_kv,
- cp_rank,
- cp_size,
- rotary_interleaved=False,
- )#
Forward function for ApplyMLARotaryEmbKV.
- Parameters:
kv – [seq_len, batch_size, head_num, k_dim + v_dim] or [total_seq_len, head_num, k_dim + v_dim]
k_pos_emb – [seq_len, batch_size, 1, emb_dim] or [total_seq_len, 1, emb_dim]
cos/sin – [max_seq_len, 1, 1, emb_dim]
cu_seqlens_kv – [seq_num + 1] accumulated sequence lengths for thd format
rotary_interleaved – whether to apply interleaved RoPE; only False is supported for now
- static backward(ctx, dk, dv)#
Backward function for ApplyMLARotaryEmbKV.
- Parameters:
dk – [seq_len, batch_size, head_num, emb_dim + k_dim] or [total_seq_len, head_num, emb_dim + k_dim]
dv – [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]
- core.fusions.fused_mla_yarn_rope_apply.fused_apply_mla_rope_for_kv(
- kv: torch.Tensor,
- k_pos_emb: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- emb_dim: int,
- k_dim: int,
- v_dim: int,
- cu_seqlens_kv: Optional[torch.Tensor] = None,
- cp_rank: int = 0,
- cp_size: int = 1,
- rotary_interleaved: bool = False,
- )#
Fused function for applying YARN RoPE to MLA’s key and value. It splits the input tensor kv into key and value, and concatenates the processed RoPE to the key.
In the notation below, seq_len is the per-batch sequence length for the sbhd format, total_seq_len is the total length of all sequences for the thd format, and max_seq_len is the maximum sequence length in the input tensor.
- Parameters:
kv – [seq_len, batch_size, head_num, k_dim + v_dim] or [total_seq_len, head_num, k_dim + v_dim]
k_pos_emb – [seq_len, batch_size, 1, emb_dim] or [total_seq_len, 1, emb_dim]
cos/sin – [max_seq_len, 1, 1, emb_dim]
cu_seqlens_kv – [seq_num + 1] accumulated sequence lengths for thd format
rotary_interleaved – whether to apply interleaved RoPE; only False is supported for now
- Returns:
key – [seq_len, batch_size, head_num, emb_dim + k_dim] or [total_seq_len, head_num, emb_dim + k_dim]
value – [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]
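The split-and-concatenate behavior can likewise be sketched with an unfused PyTorch reference. Again a sketch only: the function name is illustrative, it covers the sbhd path without context parallelism, the rotate-half RoPE convention is assumed, and the concatenation order (key channels first, then the RoPE-processed positional embedding) is an assumption about the fused kernel's layout.

```python
import torch

def unfused_mla_rope_kv(kv, k_pos_emb, cos, sin, emb_dim, k_dim, v_dim):
    """Unfused sketch of fused_apply_mla_rope_for_kv (sbhd format).

    kv:        [seq_len, batch_size, head_num, k_dim + v_dim]
    k_pos_emb: [seq_len, batch_size, 1, emb_dim]
    cos/sin:   [max_seq_len, 1, 1, emb_dim]
    """
    seq_len = kv.shape[0]
    k, v = kv[..., :k_dim], kv[..., k_dim:]
    c = cos[:seq_len].to(kv.dtype)
    s = sin[:seq_len].to(kv.dtype)
    half = emb_dim // 2
    # rotate-half RoPE on the shared positional embedding (assumed convention)
    rotated = torch.cat((-k_pos_emb[..., half:], k_pos_emb[..., :half]), dim=-1)
    emb = k_pos_emb * c + rotated * s          # [seq, batch, 1, emb_dim]
    emb = emb.expand(-1, -1, k.shape[2], -1)   # broadcast over heads
    # concatenation order of k and emb is an assumption
    key = torch.cat((k, emb), dim=-1)          # [..., k_dim + emb_dim]
    return key, v
```

Note that the single-head k_pos_emb is broadcast across all heads before concatenation, which is why the fused kernel only needs one positional-embedding channel per position.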