core.fusions.fused_mla_yarn_rope_apply#

Module Contents#

Classes#

ApplyMLARotaryEmbQ

Autograd function for applying YARN RoPE to MLA’s query.

ApplyMLARotaryEmbKV

Autograd function for applying YARN RoPE to MLA’s key and value.

Functions#

_get_thd_token_idx

rotary_fwd_q_kernel

Triton kernel of the forward pass for applying YARN RoPE to MLA’s query. This kernel modifies the input tensor Q in place.

rotary_bwd_q_kernel

Triton kernel of the backward pass for applying YARN RoPE to MLA’s query. This kernel modifies the input tensor DO in place.

fused_apply_mla_rope_for_q

Fused function for applying YARN RoPE to MLA’s query. This function modifies the input tensor t in place: along the last dimension of t, RoPE is applied to the last emb_dim elements, while the first qk_head_dim elements are left unchanged. It is an experimental feature and may change in future versions. It supports both sbhd and thd input formats.

rotary_fwd_kv_kernel

Triton kernel of the forward pass for applying YARN RoPE to MLA’s key and value. It splits the input tensor KV into key and value, and concatenates the RoPE-processed position embedding to the key.

rotary_bwd_kv_kernel

Triton kernel of the backward pass for applying YARN RoPE to MLA’s key and value.

fused_apply_mla_rope_for_kv

Fused function for applying YARN RoPE to MLA’s key and value. It splits the input tensor kv into key and value, and concatenates the RoPE-processed position embedding to the key.

API#

core.fusions.fused_mla_yarn_rope_apply._get_thd_token_idx(cu_seqlens, pid_m, seq_num, cp_rank, cp_size)#
core.fusions.fused_mla_yarn_rope_apply.rotary_fwd_q_kernel(
Q,
COS,
SIN,
qk_head_dim,
emb_dim: triton.language.constexpr,
head_num: triton.language.constexpr,
batch_size,
seq_num,
cu_seqlens_q,
stride_x_seq,
stride_x_nheads,
cp_rank,
cp_size,
BLOCK_H: triton.language.constexpr,
)#

Triton kernel of the forward pass for applying YARN RoPE to MLA’s query. This kernel modifies the input tensor Q in place.

Input:
Q: [seq_len, batch_size, head_num, qk_head_dim + emb_dim] or [total_seq_len, head_num, qk_head_dim + emb_dim]
COS/SIN: [max_seq_len, emb_dim]

batch_size: batch size for sbhd format, not used for thd format
seq_num: number of sequences for thd format, not used for sbhd format
cu_seqlens_q: [seq_num + 1] accumulated sequence lengths for thd format
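The helper `_get_thd_token_idx` carries no docstring above. For thd-format inputs, kernels of this kind typically map a flat token index back to its owning sequence and in-sequence position using cu_seqlens. A plain-Python sketch of that lookup (the function name and exact behavior, including any cp_rank/cp_size offsetting, are assumptions, not the kernel's implementation):

```python
def thd_token_to_seq_pos(cu_seqlens, token_idx):
    """Map a flat thd token index to (sequence index, position within sequence).

    cu_seqlens holds accumulated sequence lengths, e.g. [0, 3, 7, 10]
    for three sequences of lengths 3, 4, and 3.
    """
    for seq in range(len(cu_seqlens) - 1):
        if cu_seqlens[seq] <= token_idx < cu_seqlens[seq + 1]:
            return seq, token_idx - cu_seqlens[seq]
    raise IndexError(f"token index {token_idx} out of range")

# Token 5 falls in the second sequence (index 1), at position 2 within it.
seq, pos = thd_token_to_seq_pos([0, 3, 7, 10], 5)
```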
core.fusions.fused_mla_yarn_rope_apply.rotary_bwd_q_kernel(
DO,
COS,
SIN,
qk_head_dim,
emb_dim: triton.language.constexpr,
head_num: triton.language.constexpr,
batch_size,
seq_num,
cu_seqlens_q,
stride_x_seq,
stride_x_nheads,
cp_rank,
cp_size,
BLOCK_H: triton.language.constexpr,
)#

Triton kernel of the backward pass for applying YARN RoPE to MLA’s query. This kernel modifies the input tensor DO in place.

Input:
DO: [seq_len, batch_size, head_num, qk_head_dim + emb_dim] or [total_seq_len, head_num, qk_head_dim + emb_dim]
COS/SIN: [max_seq_len, emb_dim]

batch_size, seq_num, and cu_seqlens_q are the same as in the forward pass
class core.fusions.fused_mla_yarn_rope_apply.ApplyMLARotaryEmbQ#

Bases: torch.autograd.Function

Autograd function for applying YARN RoPE to MLA’s query.

static forward(
ctx,
q,
cos,
sin,
qk_head_dim,
emb_dim,
cu_seqlens_q,
cp_rank,
cp_size,
rotary_interleaved=False,
)#

Forward function for ApplyMLARotaryEmbQ.

Parameters:
  • q – [seq_len, batch_size, head_num, qk_head_dim + emb_dim] or [total_seq_len, head_num, qk_head_dim + emb_dim]

  • cos/sin – [max_seq_len, 1, 1, emb_dim]

  • cu_seqlens_q – [seq_num + 1] accumulated sequence lengths for thd format

  • rotary_interleaved – whether to apply RoPE interleaved, only supports False for now

static backward(ctx, grad)#

Backward function for ApplyMLARotaryEmbQ.

Parameters:

grad – [seq_len, batch_size, head_num, qk_head_dim + emb_dim] or [total_seq_len, head_num, qk_head_dim + emb_dim]

core.fusions.fused_mla_yarn_rope_apply.fused_apply_mla_rope_for_q(
t: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
qk_head_dim: int,
emb_dim: int,
cu_seqlens_q: Optional[torch.Tensor] = None,
cp_rank: int = 0,
cp_size: int = 1,
rotary_interleaved: bool = False,
)#

Fused function for applying YARN RoPE to MLA’s query. This function modifies the input tensor t in place: along the last dimension of t, RoPE is applied to the last emb_dim elements, while the first qk_head_dim elements are left unchanged. It is an experimental feature and may change in future versions. It supports both sbhd and thd input formats.

In the notation below, seq_len is the per-batch sequence length for the sbhd format, total_seq_len is the total length of all sequences for the thd format, and max_seq_len is the maximum sequence length in the input tensor.

Parameters:
  • t – [seq_len, batch_size, head_num, qk_head_dim + emb_dim] or [total_seq_len, head_num, qk_head_dim + emb_dim]

  • cos/sin – [max_seq_len, 1, 1, emb_dim]

  • cu_seqlens_q – [seq_num + 1] accumulated sequence lengths for thd format

  • rotary_interleaved – whether to apply RoPE interleaved, only supports False for now

Returns:

the input tensor t, modified in place

Return type:

torch.Tensor
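Per token and head, the semantics described above can be mirrored in plain Python: the first qk_head_dim elements pass through untouched, and the trailing emb_dim elements are rotated. This sketch assumes the standard non-interleaved ("rotate-half") RoPE convention, consistent with rotary_interleaved=False; it is a reference for the math, not the fused kernel itself:

```python
def apply_rope_to_tail(vec, cos, sin, qk_head_dim, emb_dim):
    """Reference semantics for one token/head slice of t: keep the first
    qk_head_dim elements, apply non-interleaved RoPE to the last emb_dim
    elements. cos/sin are the emb_dim-long rows for this token's position."""
    assert len(vec) == qk_head_dim + emb_dim
    head, tail = vec[:qk_head_dim], vec[qk_head_dim:]
    half = emb_dim // 2
    x1, x2 = tail[:half], tail[half:]
    rotated = (
        [x1[i] * cos[i] - x2[i] * sin[i] for i in range(half)]
        + [x2[i] * cos[half + i] + x1[i] * sin[half + i] for i in range(half)]
    )
    return head + rotated

# With cos = 1 and sin = 0 the rotation is the identity.
v = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
out_id = apply_rope_to_tail(v, [1.0] * 4, [0.0] * 4, qk_head_dim=2, emb_dim=4)
# With cos = 0 and sin = 1 the tail halves swap with a sign flip.
out_rot = apply_rope_to_tail(v, [0.0] * 4, [1.0] * 4, qk_head_dim=2, emb_dim=4)
```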

core.fusions.fused_mla_yarn_rope_apply.rotary_fwd_kv_kernel(
KV,
K_POS_EMB,
O_KEY,
O_VALUE,
COS,
SIN,
emb_dim: triton.language.constexpr,
k_dim: triton.language.constexpr,
v_dim: triton.language.constexpr,
head_num: triton.language.constexpr,
batch_size,
seq_num,
cu_seqlens_kv,
stride_kv_seq,
stride_kv_nheads,
stride_emb_seq,
stride_k_seq,
stride_k_nheads,
stride_v_seq,
stride_v_nheads,
cp_rank,
cp_size,
BLOCK_H: triton.language.constexpr,
)#

Triton kernel of the forward pass for applying YARN RoPE to MLA’s key and value. It splits the input tensor KV into key and value, and concatenates the RoPE-processed position embedding to the key.

Input:
KV: [seq_len, batch_size, head_num, k_dim + v_dim] or [total_seq_len, head_num, k_dim + v_dim]
K_POS_EMB: [seq_len, batch_size, emb_dim] or [total_seq_len, emb_dim]
COS/SIN: [max_seq_len, emb_dim]

batch_size: batch size for sbhd format, not used for thd format
seq_num: number of sequences for thd format, not used for sbhd format
cu_seqlens_kv: [seq_num + 1] accumulated sequence lengths for thd format

Output:
O_KEY: [seq_len, batch_size, head_num, emb_dim + k_dim] or [total_seq_len, head_num, emb_dim + k_dim]
O_VALUE: [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]
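The split/concatenate bookkeeping the kernel performs can be sketched per token in plain Python. Note the concatenation order here (rotated position embedding placed before the k_dim slice) is an assumption read off the documented output shape [..., emb_dim + k_dim]; the actual kernel may order the two parts differently:

```python
def split_kv_and_attach_emb(kv_row, rotated_emb, k_dim, v_dim):
    """kv_row: one token/head slice of KV, length k_dim + v_dim.
    rotated_emb: the RoPE-processed position embedding, length emb_dim.
    Returns (key, value) with len(key) == emb_dim + k_dim and
    len(value) == v_dim."""
    assert len(kv_row) == k_dim + v_dim
    # Assumed order: embedding first, matching the [emb_dim + k_dim] shape.
    key = list(rotated_emb) + list(kv_row[:k_dim])
    value = list(kv_row[k_dim:])
    return key, value

key, value = split_kv_and_attach_emb([1, 2, 3, 4, 5], [9, 9], k_dim=3, v_dim=2)
```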

core.fusions.fused_mla_yarn_rope_apply.rotary_bwd_kv_kernel(
dK,
dV,
dKV,
dEMB,
COS,
SIN,
emb_dim: triton.language.constexpr,
k_dim: triton.language.constexpr,
v_dim: triton.language.constexpr,
head_num: triton.language.constexpr,
batch_size,
seq_num,
cu_seqlens_kv,
stride_dk_seq,
stride_dk_nheads,
stride_dv_seq,
stride_dv_nheads,
stride_dkv_seq,
stride_dkv_nheads,
stride_demb_seq,
cp_rank,
cp_size,
BLOCK_H: triton.language.constexpr,
)#

Triton kernel of the backward pass for applying YARN RoPE to MLA’s key and value.

Input:
dK: [seq_len, batch_size, head_num, emb_dim + k_dim] or [total_seq_len, head_num, emb_dim + k_dim]
dV: [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]
COS/SIN: [max_seq_len, emb_dim]

batch_size, seq_num, and cu_seqlens_kv are the same as in the forward pass

Output:
dKV: [seq_len, batch_size, head_num, k_dim + v_dim] or [total_seq_len, head_num, k_dim + v_dim]
dEMB: [seq_len, batch_size, emb_dim] or [total_seq_len, emb_dim]

class core.fusions.fused_mla_yarn_rope_apply.ApplyMLARotaryEmbKV#

Bases: torch.autograd.Function

Autograd function for applying YARN RoPE to MLA’s key and value.

static forward(
ctx,
kv,
k_pos_emb,
cos,
sin,
emb_dim,
k_dim,
v_dim,
cu_seqlens_kv,
cp_rank,
cp_size,
rotary_interleaved=False,
)#

Forward function for ApplyMLARotaryEmbKV.

Parameters:
  • kv – [seq_len, batch_size, head_num, k_dim + v_dim] or [total_seq_len, head_num, k_dim + v_dim]

  • k_pos_emb – [seq_len, batch_size, 1, emb_dim] or [total_seq_len, 1, emb_dim]

  • cos/sin – [max_seq_len, 1, 1, emb_dim]

  • cu_seqlens_kv – [seq_num + 1] accumulated sequence lengths for thd format

  • rotary_interleaved – whether to apply RoPE interleaved, only supports False for now

static backward(ctx, dk, dv)#

Backward function for ApplyMLARotaryEmbKV.

Parameters:
  • dk – [seq_len, batch_size, head_num, emb_dim + k_dim] or [total_seq_len, head_num, emb_dim + k_dim]

  • dv – [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]

core.fusions.fused_mla_yarn_rope_apply.fused_apply_mla_rope_for_kv(
kv: torch.Tensor,
k_pos_emb: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
emb_dim: int,
k_dim: int,
v_dim: int,
cu_seqlens_kv: Optional[torch.Tensor] = None,
cp_rank: int = 0,
cp_size: int = 1,
rotary_interleaved: bool = False,
)#

Fused function for applying YARN RoPE to MLA’s key and value. It splits the input tensor kv into key and value, and concatenates the RoPE-processed position embedding to the key.

In the notation below, seq_len is the per-batch sequence length for the sbhd format, total_seq_len is the total length of all sequences for the thd format, and max_seq_len is the maximum sequence length in the input tensor.

Parameters:
  • kv – [seq_len, batch_size, head_num, k_dim + v_dim] or [total_seq_len, head_num, k_dim + v_dim]

  • k_pos_emb – [seq_len, batch_size, 1, emb_dim] or [total_seq_len, 1, emb_dim]

  • cos/sin – [max_seq_len, 1, 1, emb_dim]

  • cu_seqlens_kv – [seq_num + 1] accumulated sequence lengths for thd format

  • rotary_interleaved – whether to apply RoPE interleaved, only supports False for now

Returns:

key: [seq_len, batch_size, head_num, emb_dim + k_dim] or [total_seq_len, head_num, emb_dim + k_dim]

value: [seq_len, batch_size, head_num, v_dim] or [total_seq_len, head_num, v_dim]

Return type:

a tuple of two torch.Tensors (key, value)
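The shape bookkeeping of the returned key and value tensors can be checked with a small helper. The concrete dimensions in the example are illustrative, not taken from the source:

```python
def mla_kv_output_shapes(kv_shape, emb_dim, k_dim, v_dim):
    """Given the shape of kv (sbhd or thd layout, last dim = k_dim + v_dim),
    return the shapes of the key and value outputs documented above."""
    *lead, last = kv_shape
    assert last == k_dim + v_dim, "last dim of kv must be k_dim + v_dim"
    key_shape = (*lead, emb_dim + k_dim)
    value_shape = (*lead, v_dim)
    return key_shape, value_shape

# sbhd example with hypothetical sizes: seq_len=128, batch=2, heads=16.
k_shape, v_shape = mla_kv_output_shapes(
    (128, 2, 16, 256), emb_dim=64, k_dim=128, v_dim=128
)
```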