core.ssm.ops.ssd_bmm

Module Contents

Functions

_bmm_chunk_fwd_kernel

_bmm_chunk_fwd

API

core.ssm.ops.ssd_bmm._bmm_chunk_fwd_kernel(
a_ptr,
b_ptr,
out_ptr,
cu_chunk_seqlens_ptr,
seqlen,
chunk_size: triton.language.constexpr,
K: triton.language.constexpr,
ngroups: triton.language.constexpr,
stride_a_seqlen: triton.language.int64,
stride_a_head: triton.language.int64,
stride_ak: triton.language.constexpr,
stride_b_seqlen: triton.language.int64,
stride_b_head: triton.language.int64,
stride_bk: triton.language.constexpr,
stride_out_chunk: triton.language.int64,
stride_out_head: triton.language.int64,
stride_outm: triton.language.int64,
stride_outn: triton.language.constexpr,
IS_CAUSAL: triton.language.constexpr,
dot_dtype: triton.language.constexpr,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
BLOCK_SIZE_K: triton.language.constexpr,
)
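The kernel does not receive tensor objects. It receives raw pointers plus an explicit element stride for each dimension. The sketch below shows the stride values a caller would pass, assuming that a, b, and out are contiguous (row-major) with the shapes documented for _bmm_chunk_fwd. row_major_strides is a hypothetical helper; with PyTorch tensors, Tensor.stride() reports these values directly. Because stride_ak, stride_bk, and stride_outn are typed triton.language.constexpr, Triton specializes the kernel on their values. For a contiguous last dimension, each of them is 1.

```python
def row_major_strides(shape):
    """Hypothetical helper: element strides of a contiguous (C-order) tensor."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)  # innermost dimension has stride 1
        step *= dim
    return tuple(reversed(strides))


# Hypothetical example sizes (not taken from the source).
seqlen, ngroups, k = 10, 2, 4
nchunks, chunk_size = 3, 4

# a and b have shape (seqlen, ngroups, k).
stride_a_seqlen, stride_a_head, stride_ak = row_major_strides((seqlen, ngroups, k))

# out has shape (nchunks, ngroups, chunk_size, chunk_size).
stride_out_chunk, stride_out_head, stride_outm, stride_outn = row_major_strides(
    (nchunks, ngroups, chunk_size, chunk_size)
)

# Element a[t, g, kk] sits at offset t*stride_a_seqlen + g*stride_a_head + kk*stride_ak.
print(stride_a_seqlen, stride_a_head, stride_ak)                       # 8 4 1
print(stride_out_chunk, stride_out_head, stride_outm, stride_outn)     # 32 16 4 1
```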
core.ssm.ops.ssd_bmm._bmm_chunk_fwd(
a,
b,
chunk_size,
cu_chunk_seqlens,
causal=False,
output_dtype=None,
)

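Arguments:

a: (seqlen, ngroups, k)
b: (seqlen, ngroups, k)
chunk_size: int
cu_chunk_seqlens: (nchunks+1,) cumulative chunk boundaries along seqlen; chunk c covers rows cu_chunk_seqlens[c] up to (but not including) cu_chunk_seqlens[c+1].
causal: if True, out[i, j] for i > j may be arbitrary; only out[i, j] for i <= j is guaranteed to be correct.
output_dtype: optional dtype for the returned tensor.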

Returns:

out: (nchunks, ngroups, chunk_size, chunk_size)
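The following plain-Python reference sketch makes the argument and return shapes concrete. It rests on three assumptions that this page does not state:

1. out[c, g, i, j] is the dot product over k of row i of a and row j of b within chunk c.
2. Chunk c covers rows cu_chunk_seqlens[c] up to cu_chunk_seqlens[c+1].
3. Positions past a short chunk's length are zero.

make_cu_chunk_seqlens, bmm_chunk_reference, and guaranteed_entries are hypothetical helpers and are not part of core.ssm.ops.ssd_bmm. make_cu_chunk_seqlens covers only the simple case of one sequence split into fixed-size chunks; real callers may place boundaries differently. Nested lists stand in for tensors, so the sketch runs without PyTorch or Triton.

```python
def make_cu_chunk_seqlens(seqlen, chunk_size):
    """Hypothetical helper: boundaries of consecutive chunks of at most chunk_size rows."""
    bounds = list(range(0, seqlen, chunk_size))
    bounds.append(seqlen)  # final end offset, so len(bounds) == nchunks + 1
    return bounds


def bmm_chunk_reference(a, b, chunk_size, cu_chunk_seqlens):
    """Assumed semantics of _bmm_chunk_fwd on nested lists (hypothetical).

    a, b: seqlen x ngroups x k nested lists.
    Returns nchunks x ngroups x chunk_size x chunk_size, where
    out[c][g][i][j] = sum_k a[s + i][g][k] * b[s + j][g][k] with s = cu_chunk_seqlens[c].
    Rows/columns beyond the chunk's actual length are left at zero (an assumption).
    """
    ngroups = len(a[0])
    nchunks = len(cu_chunk_seqlens) - 1
    out = [
        [[[0.0] * chunk_size for _ in range(chunk_size)] for _ in range(ngroups)]
        for _ in range(nchunks)
    ]
    for c in range(nchunks):
        start, end = cu_chunk_seqlens[c], cu_chunk_seqlens[c + 1]
        length = end - start  # can be shorter than chunk_size, e.g. the last chunk
        for g in range(ngroups):
            for i in range(length):
                for j in range(length):
                    out[c][g][i][j] = sum(
                        x * y for x, y in zip(a[start + i][g], b[start + j][g])
                    )
    return out


def guaranteed_entries(chunk_size, causal):
    """(i, j) pairs the docstring guarantees; with causal=True only i <= j."""
    return [
        (i, j)
        for i in range(chunk_size)
        for j in range(chunk_size)
        if not causal or i <= j
    ]


# seqlen=3, ngroups=1, k=2
a = [[[1.0, 0.0]], [[0.0, 1.0]], [[1.0, 1.0]]]
b = [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]
cu = make_cu_chunk_seqlens(3, 2)
out = bmm_chunk_reference(a, b, 2, cu)
print(cu)         # [0, 2, 3]
print(out[0][0])  # [[1.0, 3.0], [2.0, 4.0]]
print(out[1][0])  # [[11.0, 0.0], [0.0, 0.0]]
```

When checking the Triton output against a reference like this with causal=True, compare only the (i, j) pairs from guaranteed_entries. The docstring leaves the remaining entries arbitrary.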