core.ssm.ops.ssd_chunk_scan

Module Contents

Functions

Data

API

core.ssm.ops.ssd_chunk_scan.TRITON_22

None

core.ssm.ops.ssd_chunk_scan._chunk_scan_fwd_kernel(
cb_ptr,
x_ptr,
z_ptr,
out_ptr,
dt_ptr,
dA_cumsum_ptr,
seq_idx_ptr,
C_ptr,
states_ptr,
D_ptr,
initstates_ptr,
cu_chunk_seqlens_ptr,
chunk_size: triton.language.constexpr,
hdim: triton.language.constexpr,
dstate: triton.language.constexpr,
seqlen,
nheads_ngroups_ratio: triton.language.constexpr,
stride_cb_chunk: triton.language.int64,
stride_cb_head: triton.language.int64,
stride_cb_csize_m: triton.language.int64,
stride_cb_csize_k: triton.language.constexpr,
stride_x_seqlen: triton.language.int64,
stride_x_head: triton.language.int64,
stride_x_hdim: triton.language.constexpr,
stride_z_seqlen: triton.language.int64,
stride_z_head: triton.language.int64,
stride_z_hdim: triton.language.constexpr,
stride_out_seqlen: triton.language.int64,
stride_out_head: triton.language.int64,
stride_out_hdim: triton.language.constexpr,
stride_dt_chunk: triton.language.int64,
stride_dt_head: triton.language.int64,
stride_dt_csize: triton.language.constexpr,
stride_dA_cs_chunk: triton.language.int64,
stride_dA_cs_head: triton.language.int64,
stride_dA_cs_csize: triton.language.constexpr,
stride_seq_idx_chunk: triton.language.constexpr,
stride_C_seqlen: triton.language.int64,
stride_C_head: triton.language.int64,
stride_C_dstate: triton.language.constexpr,
stride_states_chunk: triton.language.int64,
stride_states_head: triton.language.int64,
stride_states_hdim: triton.language.int64,
stride_states_dstate: triton.language.constexpr,
stride_init_states_batch: triton.language.int64,
stride_init_states_head: triton.language.int64,
stride_init_states_hdim: triton.language.int64,
stride_init_states_dstate: triton.language.constexpr,
stride_D_head: triton.language.constexpr,
IS_CAUSAL: triton.language.constexpr,
HAS_D: triton.language.constexpr,
D_HAS_HDIM: triton.language.constexpr,
HAS_Z: triton.language.constexpr,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
BLOCK_SIZE_K: triton.language.constexpr,
BLOCK_SIZE_DSTATE: triton.language.constexpr,
IS_TRITON_22: triton.language.constexpr,
HAS_INITSTATES: triton.language.constexpr,
)

core.ssm.ops.ssd_chunk_scan._chunk_scan_fwd(
cb,
x,
dt,
dA_cumsum,
C,
states,
cu_chunk_seqlens,
out,
seq_idx,
D=None,
z=None,
initial_states=None,
)