core.ssm.ops.ssd_bmm
Module Contents
Functions
- _bmm_chunk_fwd_kernel
- _bmm_chunk_fwd: takes a and b of shape (seqlen, ngroups, k) and returns out of shape (nchunks, ngroups, chunk_size, chunk_size).
API
core.ssm.ops.ssd_bmm._bmm_chunk_fwd_kernel(
    a_ptr,
    b_ptr,
    out_ptr,
    cu_chunk_seqlens_ptr,
    seqlen,
    chunk_size: triton.language.constexpr,
    K: triton.language.constexpr,
    ngroups: triton.language.constexpr,
    stride_a_seqlen: triton.language.int64,
    stride_a_head: triton.language.int64,
    stride_ak: triton.language.constexpr,
    stride_b_seqlen: triton.language.int64,
    stride_b_head: triton.language.int64,
    stride_bk: triton.language.constexpr,
    stride_out_chunk: triton.language.int64,
    stride_out_head: triton.language.int64,
    stride_outm: triton.language.int64,
    stride_outn: triton.language.constexpr,
    IS_CAUSAL: triton.language.constexpr,
    dot_dtype: triton.language.constexpr,
    BLOCK_SIZE_M: triton.language.constexpr,
    BLOCK_SIZE_N: triton.language.constexpr,
    BLOCK_SIZE_K: triton.language.constexpr,
)
core.ssm.ops.ssd_bmm._bmm_chunk_fwd(
    a,
    b,
    chunk_size,
    cu_chunk_seqlens,
    causal=False,
    output_dtype=None,
)
Arguments:
- a: (seqlen, ngroups, k)
- b: (seqlen, ngroups, k)
- chunk_size: int
- cu_chunk_seqlens: (nchunks+1,)
- causal: if True, out[i, j] for i > j is arbitrary; only out[i, j] for i <= j is guaranteed to be correct.
Returns:
- out: (nchunks, ngroups, chunk_size, chunk_size)
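The shapes above can be illustrated with a small, dependency-free reference. This is a sketch under stated assumptions, not the module's implementation. It assumes that each output block is the product of a chunk of `a` with the transpose of the matching chunk of `b` (inferred from the function name and the documented shapes), that `cu_chunk_seqlens` holds cumulative row offsets marking where each chunk starts and ends, and that positions beyond a short chunk are zero-filled. The helper name `bmm_chunk_reference` is hypothetical, and only the `causal=False` case is modelled.

```python
def bmm_chunk_reference(a, b, chunk_size, cu_chunk_seqlens):
    """Pure-Python sketch: out[c][g] = A_c[:, g, :] @ B_c[:, g, :]^T per chunk.

    a, b: nested lists shaped (seqlen, ngroups, k).
    cu_chunk_seqlens: nchunks + 1 cumulative row offsets (assumed semantics).
    Returns nested lists shaped (nchunks, ngroups, chunk_size, chunk_size).
    """
    nchunks = len(cu_chunk_seqlens) - 1
    ngroups = len(a[0])
    k = len(a[0][0])
    # Zero-initialised output; zero padding for short chunks is an assumption.
    out = [[[[0.0] * chunk_size for _ in range(chunk_size)]
            for _ in range(ngroups)]
           for _ in range(nchunks)]
    for c in range(nchunks):
        start, end = cu_chunk_seqlens[c], cu_chunk_seqlens[c + 1]
        length = end - start  # may be shorter than chunk_size
        for g in range(ngroups):
            for i in range(length):
                for j in range(length):
                    # Dot product over the shared k dimension.
                    out[c][g][i][j] = sum(
                        a[start + i][g][kk] * b[start + j][g][kk]
                        for kk in range(k)
                    )
    return out


# seqlen=3, ngroups=1, k=2, split into chunks covering rows [0, 2) and [2, 3).
a = [[[1.0, 0.0]], [[0.0, 2.0]], [[3.0, 1.0]]]
b = [[[1.0, 1.0]], [[2.0, 0.0]], [[0.0, 4.0]]]
out = bmm_chunk_reference(a, b, chunk_size=2, cu_chunk_seqlens=[0, 2, 3])
print(out[0][0])  # [[1.0, 2.0], [2.0, 0.0]]
print(out[1][0])  # [[4.0, 0.0], [0.0, 0.0]]
```

The second chunk contains a single row, so only its top-left entry is populated. A reference like this is mainly useful for checking shapes and indexing on tiny inputs before comparing against the Triton kernel's output.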