core.ssm.ops.ssd_chunk_state

Module Contents

Functions

API

core.ssm.ops.ssd_chunk_state._chunk_cumsum_fwd_kernel(
dt_ptr,
A_ptr,
dt_bias_ptr,
dt_out_ptr,
dA_cumsum_ptr,
cu_chunk_seqlens_ptr,
seqlen,
nheads: triton.language.constexpr,
chunk_size: triton.language.constexpr,
dt_min: triton.language.constexpr,
dt_max: triton.language.constexpr,
stride_dt_seqlen: triton.language.int64,
stride_dt_head: triton.language.constexpr,
stride_A_head: triton.language.constexpr,
stride_dt_bias_head: triton.language.constexpr,
stride_dt_out_head: triton.language.int64,
stride_dt_out_chunk: triton.language.int64,
stride_dt_out_csize: triton.language.constexpr,
stride_dA_cs_head: triton.language.int64,
stride_dA_cs_chunk: triton.language.int64,
stride_dA_cs_csize: triton.language.constexpr,
DT_SOFTPLUS: triton.language.constexpr,
HAS_DT_BIAS: triton.language.constexpr,
BLOCK_SIZE_H: triton.language.constexpr,
BLOCK_SIZE_CHUNK: triton.language.constexpr,
)
core.ssm.ops.ssd_chunk_state._chunk_state_fwd_kernel(
x_ptr,
b_ptr,
states_ptr,
dt_ptr,
dA_cumsum_ptr,
cu_chunk_seqlens_ptr,
hdim: triton.language.constexpr,
dstate: triton.language.constexpr,
chunk_size: triton.language.constexpr,
seqlen,
nheads_ngroups_ratio: triton.language.constexpr,
stride_x_seqlen: triton.language.int64,
stride_x_head: triton.language.int64,
stride_x_hdim: triton.language.constexpr,
stride_b_seqlen: triton.language.int64,
stride_b_head: triton.language.int64,
stride_b_dstate: triton.language.constexpr,
stride_states_chunk: triton.language.int64,
stride_states_head: triton.language.int64,
stride_states_hdim: triton.language.int64,
stride_states_dstate: triton.language.constexpr,
stride_dt_head: triton.language.int64,
stride_dt_chunk: triton.language.int64,
stride_dt_csize: triton.language.constexpr,
stride_dA_cs_head: triton.language.int64,
stride_dA_cs_chunk: triton.language.int64,
stride_dA_cs_csize: triton.language.constexpr,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
BLOCK_SIZE_K: triton.language.constexpr,
)
core.ssm.ops.ssd_chunk_state._chunk_cumsum_fwd(
dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=None,
dt_softplus=False,
dt_limit=(0.0, float('inf')),
)
core.ssm.ops.ssd_chunk_state._chunk_state_fwd(
B,
x,
dt,
dA_cumsum,
cu_chunk_seqlens,
states=None,
states_in_fp32=True,
)
core.ssm.ops.ssd_chunk_state._chunk_state_varlen_kernel(
x_ptr,
b_ptr,
dt_ptr,
dA_cumsum_ptr,
chunk_states_ptr,
cu_seqlens_ptr,
last_chunk_indices_ptr,
cu_chunk_seqlens_ptr,
states_ptr,
initstates_ptr,
hdim: triton.language.constexpr,
dstate: triton.language.constexpr,
chunk_size: triton.language.constexpr,
nheads_ngroups_ratio: triton.language.constexpr,
stride_x_seqlen: triton.language.int64,
stride_x_head: triton.language.int64,
stride_x_hdim: triton.language.constexpr,
stride_b_seqlen: triton.language.int64,
stride_b_head: triton.language.int64,
stride_b_dstate: triton.language.constexpr,
stride_dt_head: triton.language.int64,
stride_dt_chunk: triton.language.int64,
stride_dt_csize: triton.language.constexpr,
stride_dA_cs_head: triton.language.int64,
stride_dA_cs_chunk: triton.language.int64,
stride_dA_cs_csize: triton.language.constexpr,
stride_chunk_states_chunk: triton.language.int64,
stride_chunk_states_head: triton.language.int64,
stride_chunk_states_hdim: triton.language.int64,
stride_chunk_states_dstate: triton.language.constexpr,
stride_states_batch: triton.language.int64,
stride_states_head: triton.language.int64,
stride_states_hdim: triton.language.int64,
stride_states_dstate: triton.language.constexpr,
stride_init_states_batch: triton.language.int64,
stride_init_states_head: triton.language.int64,
stride_init_states_hdim: triton.language.int64,
stride_init_states_dstate: triton.language.constexpr,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
BLOCK_SIZE_K: triton.language.constexpr,
HAS_INITSTATES: triton.language.constexpr,
USE_LAST_CHUNK_INDICES: triton.language.constexpr,
)
core.ssm.ops.ssd_chunk_state.chunk_state_varlen(
B,
x,
dt,
dA_cumsum,
cu_seqlens,
chunk_states,
initial_states=None,
last_chunk_indices=None,
cu_chunk_seqlens=None,
)

Compute the final SSM state of each sequence from the per-chunk states.

The result is correct even when multiple sequences share a single chunk.