core.ssm.ops.ssd_chunk_scan
Module Contents
Functions
Data
API
- core.ssm.ops.ssd_chunk_scan.TRITON_22
  Value (as rendered): None
- core.ssm.ops.ssd_chunk_scan._chunk_scan_fwd_kernel(
      cb_ptr,
      x_ptr,
      z_ptr,
      out_ptr,
      dt_ptr,
      dA_cumsum_ptr,
      seq_idx_ptr,
      C_ptr,
      states_ptr,
      D_ptr,
      initstates_ptr,
      cu_chunk_seqlens_ptr,
      chunk_size: triton.language.constexpr,
      hdim: triton.language.constexpr,
      dstate: triton.language.constexpr,
      seqlen,
      nheads_ngroups_ratio: triton.language.constexpr,
      stride_cb_chunk: triton.language.int64,
      stride_cb_head: triton.language.int64,
      stride_cb_csize_m: triton.language.int64,
      stride_cb_csize_k: triton.language.constexpr,
      stride_x_seqlen: triton.language.int64,
      stride_x_head: triton.language.int64,
      stride_x_hdim: triton.language.constexpr,
      stride_z_seqlen: triton.language.int64,
      stride_z_head: triton.language.int64,
      stride_z_hdim: triton.language.constexpr,
      stride_out_seqlen: triton.language.int64,
      stride_out_head: triton.language.int64,
      stride_out_hdim: triton.language.constexpr,
      stride_dt_chunk: triton.language.int64,
      stride_dt_head: triton.language.int64,
      stride_dt_csize: triton.language.constexpr,
      stride_dA_cs_chunk: triton.language.int64,
      stride_dA_cs_head: triton.language.int64,
      stride_dA_cs_csize: triton.language.constexpr,
      stride_seq_idx_chunk: triton.language.constexpr,
      stride_C_seqlen: triton.language.int64,
      stride_C_head: triton.language.int64,
      stride_C_dstate: triton.language.constexpr,
      stride_states_chunk: triton.language.int64,
      stride_states_head: triton.language.int64,
      stride_states_hdim: triton.language.int64,
      stride_states_dstate: triton.language.constexpr,
      stride_init_states_batch: triton.language.int64,
      stride_init_states_head: triton.language.int64,
      stride_init_states_hdim: triton.language.int64,
      stride_init_states_dstate: triton.language.constexpr,
      stride_D_head: triton.language.constexpr,
      IS_CAUSAL: triton.language.constexpr,
      HAS_D: triton.language.constexpr,
      D_HAS_HDIM: triton.language.constexpr,
      HAS_Z: triton.language.constexpr,
      BLOCK_SIZE_M: triton.language.constexpr,
      BLOCK_SIZE_N: triton.language.constexpr,
      BLOCK_SIZE_K: triton.language.constexpr,
      BLOCK_SIZE_DSTATE: triton.language.constexpr,
      IS_TRITON_22: triton.language.constexpr,
      HAS_INITSTATES: triton.language.constexpr,
  )
- core.ssm.ops.ssd_chunk_scan._chunk_scan_fwd(
- cb,
- x,
- dt,
- dA_cumsum,
- C,
- states,
- cu_chunk_seqlens,
- out,
- seq_idx,
- D=None,
- z=None,
- initial_states=None,