core.ssm.ops.ssd_state_passing
Module Contents
Functions
API
core.ssm.ops.ssd_state_passing._state_passing_fwd_kernel(
    states_ptr,
    out_ptr,
    dA_cs_ptr,
    initstates_ptr,
    seq_idx_ptr,
    cu_chunk_seqlens_ptr,
    dim: triton.language.constexpr,
    nchunks,
    seqlen,
    chunk_size: triton.language.constexpr,
    stride_states_chunk: triton.language.int64,
    stride_states_head: triton.language.int64,
    stride_states_dim: triton.language.constexpr,
    stride_out_chunk: triton.language.int64,
    stride_out_head: triton.language.int64,
    stride_out_dim: triton.language.constexpr,
    stride_dA_cs_head: triton.language.int64,
    stride_dA_cs_chunk: triton.language.int64,
    stride_dA_cs_csize: triton.language.constexpr,
    stride_initstates_batch: triton.language.int64,
    stride_initstates_head: triton.language.int64,
    stride_initstates_dim: triton.language.constexpr,
    stride_seq_idx_chunk: triton.language.constexpr,
    HAS_INITSTATES: triton.language.constexpr,
    BLOCK_SIZE: triton.language.constexpr,
)
- core.ssm.ops.ssd_state_passing._state_passing_fwd(
- states,
- dA_cumsum,
- cu_chunk_seqlens,
- seq_idx,
- initial_states=None,
- out_dtype=None,