core.ssm.ops.ssd_state_passing

Module Contents

Functions

- `_state_passing_fwd_kernel`
- `_state_passing_fwd`

API

core.ssm.ops.ssd_state_passing._state_passing_fwd_kernel(
states_ptr,
out_ptr,
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
cu_chunk_seqlens_ptr,
dim: triton.language.constexpr,
nchunks,
seqlen,
chunk_size: triton.language.constexpr,
stride_states_chunk: triton.language.int64,
stride_states_head: triton.language.int64,
stride_states_dim: triton.language.constexpr,
stride_out_chunk: triton.language.int64,
stride_out_head: triton.language.int64,
stride_out_dim: triton.language.constexpr,
stride_dA_cs_head: triton.language.int64,
stride_dA_cs_chunk: triton.language.int64,
stride_dA_cs_csize: triton.language.constexpr,
stride_initstates_batch: triton.language.int64,
stride_initstates_head: triton.language.int64,
stride_initstates_dim: triton.language.constexpr,
stride_seq_idx_chunk: triton.language.constexpr,
HAS_INITSTATES: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr,
)
core.ssm.ops.ssd_state_passing._state_passing_fwd(
states,
dA_cumsum,
cu_chunk_seqlens,
seq_idx,
initial_states=None,
out_dtype=None,
)