core.ssm.ops.ssd_chunk_state
Module Contents
Functions
chunk_state_varlen | Compute per-sequence final SSM state from chunk states.
API
core.ssm.ops.ssd_chunk_state._chunk_cumsum_fwd_kernel(
    dt_ptr,
    A_ptr,
    dt_bias_ptr,
    dt_out_ptr,
    dA_cumsum_ptr,
    cu_chunk_seqlens_ptr,
    seqlen,
    nheads: triton.language.constexpr,
    chunk_size: triton.language.constexpr,
    dt_min: triton.language.constexpr,
    dt_max: triton.language.constexpr,
    stride_dt_seqlen: triton.language.int64,
    stride_dt_head: triton.language.constexpr,
    stride_A_head: triton.language.constexpr,
    stride_dt_bias_head: triton.language.constexpr,
    stride_dt_out_head: triton.language.int64,
    stride_dt_out_chunk: triton.language.int64,
    stride_dt_out_csize: triton.language.constexpr,
    stride_dA_cs_head: triton.language.int64,
    stride_dA_cs_chunk: triton.language.int64,
    stride_dA_cs_csize: triton.language.constexpr,
    DT_SOFTPLUS: triton.language.constexpr,
    HAS_DT_BIAS: triton.language.constexpr,
    BLOCK_SIZE_H: triton.language.constexpr,
    BLOCK_SIZE_CHUNK: triton.language.constexpr,
)
core.ssm.ops.ssd_chunk_state._chunk_state_fwd_kernel(
    x_ptr,
    b_ptr,
    states_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    cu_chunk_seqlens_ptr,
    hdim: triton.language.constexpr,
    dstate: triton.language.constexpr,
    chunk_size: triton.language.constexpr,
    seqlen,
    nheads_ngroups_ratio: triton.language.constexpr,
    stride_x_seqlen: triton.language.int64,
    stride_x_head: triton.language.int64,
    stride_x_hdim: triton.language.constexpr,
    stride_b_seqlen: triton.language.int64,
    stride_b_head: triton.language.int64,
    stride_b_dstate: triton.language.constexpr,
    stride_states_chunk: triton.language.int64,
    stride_states_head: triton.language.int64,
    stride_states_hdim: triton.language.int64,
    stride_states_dstate: triton.language.constexpr,
    stride_dt_head: triton.language.int64,
    stride_dt_chunk: triton.language.int64,
    stride_dt_csize: triton.language.constexpr,
    stride_dA_cs_head: triton.language.int64,
    stride_dA_cs_chunk: triton.language.int64,
    stride_dA_cs_csize: triton.language.constexpr,
    BLOCK_SIZE_M: triton.language.constexpr,
    BLOCK_SIZE_N: triton.language.constexpr,
    BLOCK_SIZE_K: triton.language.constexpr,
)
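The kernel takes `x`, `B`, `dt` and `dA_cumsum` per chunk and writes one `states` entry per chunk and head. Assuming the standard Mamba-2 SSD formulation with a scalar decay `A` per head (an assumption; the module does not document it), the state a chunk contributes on its own is `S = sum over tokens l of exp(dA_cs[last] - dA_cs[l]) * dt[l] * outer(x[l], B[l])`. The sketch below uses hypothetical helper names and a single head, skips the head-to-group mapping implied by `nheads_ngroups_ratio`, and checks that closed form against the token-by-token recurrence started from a zero state.

```python
import math
import random


def chunk_state_closed_form(x, B, dt, A):
    """Hypothetical sketch: state one chunk contributes, single head.

    x: [L][P] (hdim P), B: [L][N] (dstate N), dt: [L], A: scalar decay rate.
    """
    L, P, N = len(x), len(x[0]), len(B[0])
    # Cumulative sum of dt * A inside the chunk (the role dA_cumsum plays).
    dA_cs, running = [], 0.0
    for l in range(L):
        running += dt[l] * A
        dA_cs.append(running)
    state = [[0.0] * N for _ in range(P)]
    for l in range(L):
        # Decay from token l to the chunk's last token, times the input scale dt[l].
        scale = math.exp(dA_cs[-1] - dA_cs[l]) * dt[l]
        for p in range(P):
            for n in range(N):
                state[p][n] += scale * x[l][p] * B[l][n]
    return state


def chunk_state_sequential(x, B, dt, A):
    """Same quantity via h <- exp(dt*A) * h + dt * outer(x, B), from h = 0."""
    P, N = len(x[0]), len(B[0])
    h = [[0.0] * N for _ in range(P)]
    for l in range(len(x)):
        decay = math.exp(dt[l] * A)
        for p in range(P):
            for n in range(N):
                h[p][n] = decay * h[p][n] + dt[l] * x[l][p] * B[l][n]
    return h


if __name__ == "__main__":
    random.seed(0)
    L, P, N, A = 5, 2, 3, -0.7
    x = [[random.uniform(-1, 1) for _ in range(P)] for _ in range(L)]
    B = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(L)]
    dt = [random.uniform(0.01, 0.5) for _ in range(L)]
    a = chunk_state_closed_form(x, B, dt, A)
    b = chunk_state_sequential(x, B, dt, A)
    # Largest elementwise difference; expected to be at rounding-error level.
    print(max(abs(a[p][n] - b[p][n]) for p in range(P) for n in range(N)))
```

The closed form is what makes a chunked kernel possible: every token's contribution is a weighted outer product, so the sum maps onto one matrix multiply per chunk instead of a serial loop.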
core.ssm.ops.ssd_chunk_state._chunk_cumsum_fwd(
    dt,
    A,
    chunk_size,
    cu_chunk_seqlens,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float('inf')),
)
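The parameters `dt_bias`, `dt_softplus` and `dt_limit` (and `HAS_DT_BIAS`, `DT_SOFTPLUS`, `dt_min`, `dt_max` in the kernel) suggest the usual Mamba-2 preprocessing: add a per-head bias, optionally apply softplus, clamp to `dt_limit`, then take a cumulative sum of `dt * A` that restarts at every chunk given by `cu_chunk_seqlens`. The pure-Python reference below is a sketch under those assumptions, not the module's implementation; `chunk_cumsum_reference`, the nested-list layouts and the zero padding of short chunks are all hypothetical choices.

```python
import math


def softplus(v: float) -> float:
    # For large inputs softplus(v) equals v to float precision; avoid exp overflow.
    return v if v > 20.0 else math.log1p(math.exp(v))


def chunk_cumsum_reference(dt, A, chunk_size, cu_chunk_seqlens,
                           dt_bias=None, dt_softplus=False,
                           dt_limit=(0.0, float("inf"))):
    """Hypothetical pure-Python reference, not the module's implementation.

    dt: [seqlen][nheads] nested list; A and dt_bias: [nheads];
    cu_chunk_seqlens: token offsets of the chunks (length nchunks + 1).
    Returns (dt_out, dA_cumsum), each indexed [head][chunk][position].
    """
    nheads = len(A)
    nchunks = len(cu_chunk_seqlens) - 1
    dt_min, dt_max = dt_limit
    dt_out = [[[0.0] * chunk_size for _ in range(nchunks)] for _ in range(nheads)]
    dA_cumsum = [[[0.0] * chunk_size for _ in range(nchunks)] for _ in range(nheads)]
    for c in range(nchunks):
        start, end = cu_chunk_seqlens[c], cu_chunk_seqlens[c + 1]
        if end - start > chunk_size:
            raise ValueError("a chunk is longer than chunk_size")
        for h in range(nheads):
            running = 0.0
            for i in range(chunk_size):
                if start + i < end:
                    v = dt[start + i][h]
                    if dt_bias is not None:
                        v += dt_bias[h]              # optional per-head bias
                    if dt_softplus:
                        v = softplus(v)              # optional positivity transform
                    v = min(max(v, dt_min), dt_max)  # clamp to dt_limit
                else:
                    v = 0.0                          # padding after a short chunk
                dt_out[h][c][i] = v
                running += v * A[h]                  # decay increment dt * A
                dA_cumsum[h][c][i] = running         # cumulative within this chunk only
    return dt_out, dA_cumsum


if __name__ == "__main__":
    # One head, three tokens, chunks [0, 2) and [2, 3) with chunk_size = 2.
    dt_out, dA = chunk_cumsum_reference([[0.5], [1.0], [2.0]], [-1.0], 2, [0, 2, 3])
    print(dt_out)  # [[[0.5, 1.0], [2.0, 0.0]]]
    print(dA)      # [[[-0.5, -1.5], [-2.0, -2.0]]]
```

Because the padded positions carry `dt = 0`, the cumulative sum stays flat past the end of a short chunk, so reading the last slot still yields the chunk's total decay.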
core.ssm.ops.ssd_chunk_state._chunk_state_fwd(
    B,
    x,
    dt,
    dA_cumsum,
    cu_chunk_seqlens,
    states=None,
    states_in_fp32=True,
)
core.ssm.ops.ssd_chunk_state._chunk_state_varlen_kernel(
    x_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    chunk_states_ptr,
    cu_seqlens_ptr,
    last_chunk_indices_ptr,
    cu_chunk_seqlens_ptr,
    states_ptr,
    initstates_ptr,
    hdim: triton.language.constexpr,
    dstate: triton.language.constexpr,
    chunk_size: triton.language.constexpr,
    nheads_ngroups_ratio: triton.language.constexpr,
    stride_x_seqlen: triton.language.int64,
    stride_x_head: triton.language.int64,
    stride_x_hdim: triton.language.constexpr,
    stride_b_seqlen: triton.language.int64,
    stride_b_head: triton.language.int64,
    stride_b_dstate: triton.language.constexpr,
    stride_dt_head: triton.language.int64,
    stride_dt_chunk: triton.language.int64,
    stride_dt_csize: triton.language.constexpr,
    stride_dA_cs_head: triton.language.int64,
    stride_dA_cs_chunk: triton.language.int64,
    stride_dA_cs_csize: triton.language.constexpr,
    stride_chunk_states_chunk: triton.language.int64,
    stride_chunk_states_head: triton.language.int64,
    stride_chunk_states_hdim: triton.language.int64,
    stride_chunk_states_dstate: triton.language.constexpr,
    stride_states_batch: triton.language.int64,
    stride_states_head: triton.language.int64,
    stride_states_hdim: triton.language.int64,
    stride_states_dstate: triton.language.constexpr,
    stride_init_states_batch: triton.language.int64,
    stride_init_states_head: triton.language.int64,
    stride_init_states_hdim: triton.language.int64,
    stride_init_states_dstate: triton.language.constexpr,
    BLOCK_SIZE_M: triton.language.constexpr,
    BLOCK_SIZE_N: triton.language.constexpr,
    BLOCK_SIZE_K: triton.language.constexpr,
    HAS_INITSTATES: triton.language.constexpr,
    USE_LAST_CHUNK_INDICES: triton.language.constexpr,
)
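The module does not list tensor shapes, but the `stride_<tensor>_<dim>` parameter names in the kernel signatures suggest the layouts. The helper below records that reading as a hypothetical checklist; every shape in it is inferred from stride names (and `cu_seqlens` from the usual offsets convention), so treat it as an assumption to verify against the source. It also encodes the assumed head-to-group mapping behind `nheads_ngroups_ratio`, under which the "head" axis of `B` indexes groups.

```python
def inferred_shapes(seqlen, nheads, ngroups, hdim, dstate, nchunks, chunk_size, batch):
    """Hypothetical helper: shapes suggested by the kernels' stride names.

    Nothing here is read from the module; each entry is inferred from the
    stride_<tensor>_<dim> parameters and should be checked against the source.
    """
    if nheads % ngroups != 0:
        raise ValueError("nheads must be a multiple of ngroups")
    per_chunk = (nheads, nchunks, chunk_size)
    return {
        "x": (seqlen, nheads, hdim),                      # stride_x_seqlen/_head/_hdim
        "B": (seqlen, ngroups, dstate),                   # stride_b_*; "head" axis holds groups
        "dt": per_chunk,                                  # stride_dt_head/_chunk/_csize
        "dA_cumsum": per_chunk,                           # stride_dA_cs_head/_chunk/_csize
        "chunk_states": (nchunks, nheads, hdim, dstate),  # stride_chunk_states_*
        "states": (batch, nheads, hdim, dstate),          # stride_states_batch/_head/...
        "initial_states": (batch, nheads, hdim, dstate),  # stride_init_states_*
        "cu_seqlens": (batch + 1,),                       # offsets, one more than batch
    }


def group_of_head(h, nheads_ngroups_ratio):
    # Assumed mapping: each run of nheads_ngroups_ratio consecutive heads shares one group of B.
    return h // nheads_ngroups_ratio


if __name__ == "__main__":
    shapes = inferred_shapes(seqlen=10, nheads=4, ngroups=2, hdim=8,
                             dstate=16, nchunks=3, chunk_size=4, batch=2)
    for name, shape in shapes.items():
        print(name, shape)
```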
core.ssm.ops.ssd_chunk_state.chunk_state_varlen(
    B,
    x,
    dt,
    dA_cumsum,
    cu_seqlens,
    chunk_states,
    initial_states=None,
    last_chunk_indices=None,
    cu_chunk_seqlens=None,
)
Compute per-sequence final SSM state from chunk states.
The result remains correct when sequences share chunks, that is, when one chunk holds tokens from more than one sequence.
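What "correct when sequences share chunks" has to guard against can be shown with a scalar sketch. It assumes the Mamba-2 recurrence `h <- exp(dt*A) * h + dt * x * b` with `hdim = dstate = 1`, fixed-size chunks laid over the packed tokens, and computes each sequence's state at its last chunk boundary with a plain scan where the real function would read it from `chunk_states`. The function names are hypothetical. The key step is starting the in-chunk sum at `max(start, chunk_start)`, so tokens of a neighbouring sequence in the same chunk contribute nothing; a sequence that begins inside its last chunk starts from its initial state instead of a chunk-boundary state.

```python
import math


def scan(x, b, dt, A, h0=0.0):
    """Token-by-token recurrence h <- exp(dt*A) * h + dt * x * b (hdim = dstate = 1)."""
    h = h0
    for xl, bl, dtl in zip(x, b, dt):
        h = math.exp(dtl * A) * h + dtl * xl * bl
    return h


def final_states_from_chunks(x, b, dt, A, cu_seqlens, chunk_size, init=None):
    """Hypothetical sketch: final state of each packed sequence.

    Chunks are fixed windows [c * chunk_size, (c + 1) * chunk_size) over the
    packed tokens, so a chunk may hold the tail of one sequence and the head
    of the next.
    """
    finals = []
    for s in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[s], cu_seqlens[s + 1]
        assert end > start, "empty sequences are not handled in this sketch"
        h0 = 0.0 if init is None else init[s]
        c = (end - 1) // chunk_size        # chunk holding the sequence's last token
        chunk_start = c * chunk_size
        first = max(start, chunk_start)    # skip tokens of an earlier sequence
        if start < chunk_start:
            # Sequence began in an earlier chunk: its state at this chunk
            # boundary stands in for the chunk_states input (plain scan here).
            boundary = scan(x[start:chunk_start], b[start:chunk_start],
                            dt[start:chunk_start], A, h0)
        else:
            boundary = h0                  # sequence starts inside chunk c
        # Cumulative dt * A over this sequence's tokens in chunk c only.
        dA_cs, running = [], 0.0
        for l in range(first, end):
            running += dt[l] * A
            dA_cs.append(running)
        h = math.exp(dA_cs[-1]) * boundary
        for i, l in enumerate(range(first, end)):
            h += math.exp(dA_cs[-1] - dA_cs[i]) * dt[l] * x[l] * b[l]
        finals.append(h)
    return finals
```

With `cu_seqlens = [0, 5, 7]` and `chunk_size = 4`, chunk 1 holds token 4 of the first sequence and tokens 5 and 6 of the second; clipping the start index keeps token 4 out of the second sequence's final state, and the result matches a plain `scan` over each sequence on its own.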