core.ssm.ops.mamba_ssm
Module Contents
Functions
selective_state_update | Selective state-space (SSM) state update; returns out with the same shape as x. Argument shapes are listed under API.
API
core.ssm.ops.mamba_ssm._selective_scan_update_kernel(
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    state_batch_indices_ptr, int_state_ptr,
    batch, seq_len, nheads, dim, dstate, nheads_ngroups_ratio,
    stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_seq, stride_x_head, stride_x_dim,
    stride_dt_batch, stride_dt_seq, stride_dt_head, stride_dt_dim,
    stride_dt_bias_head, stride_dt_bias_dim,
    stride_A_head, stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_seq, stride_B_group, stride_B_dstate,
    stride_C_batch, stride_C_seq, stride_C_group, stride_C_dstate,
    stride_D_head, stride_D_dim,
    stride_z_batch, stride_z_seq, stride_z_head, stride_z_dim,
    stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim,
    stride_int_batch, stride_int_seq, stride_int_head, stride_int_dim, stride_int_dstate,
    DT_SOFTPLUS: triton.language.constexpr,
    TIE_HDIM: triton.language.constexpr,
    BLOCK_SIZE_M: triton.language.constexpr,
    HAS_DT_BIAS: triton.language.constexpr,
    HAS_D: triton.language.constexpr,
    HAS_Z: triton.language.constexpr,
    HAS_STATE_BATCH_INDICES: triton.language.constexpr,
    HAS_INT_STATE: triton.language.constexpr,
    BLOCK_SIZE_DSTATE: triton.language.constexpr,
)
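The kernel carries no docstring. As a reading aid, here is a minimal pure-Python reference for one token of the update it is named after, assuming it follows the standard Mamba selective state update: dt receives the optional bias and softplus, the state decays by exp(dt * A) and accumulates dt * B * x, and the output reads the state with C, adds the optional D * x skip term, and is gated by SiLU(z). The mapping of the constexpr flags (HAS_DT_BIAS, DT_SOFTPLUS, HAS_D, HAS_Z) and of nheads_ngroups_ratio onto those terms is inferred from their names, and the helpers ssm_update_ref, softplus and silu are hypothetical, not part of core.ssm.ops.mamba_ssm.

```python
import math


def softplus(v):
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def silu(v):
    # v * sigmoid(v), written so exp() never receives a large positive argument.
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def ssm_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    """Hypothetical reference: one token of the update for one batch element.

    state: [nheads][dim][dstate], updated in place.
    x, dt, z: [nheads][dim];  A: [nheads][dim][dstate];  B, C: [ngroups][dstate];
    D, dt_bias: [nheads][dim].  Returns out: [nheads][dim].
    """
    nheads, ngroups = len(state), len(B)
    ratio = nheads // ngroups  # assumed meaning of nheads_ngroups_ratio
    out = []
    for h in range(nheads):
        g = h // ratio  # heads in [g*ratio, (g+1)*ratio) share B[g] and C[g]
        out_h = []
        for d in range(len(x[h])):
            dt_hd = dt[h][d]
            if dt_bias is not None:  # assumed HAS_DT_BIAS
                dt_hd += dt_bias[h][d]
            if dt_softplus:  # assumed DT_SOFTPLUS
                dt_hd = softplus(dt_hd)
            acc = 0.0
            for n in range(len(B[g])):
                # Decay the old state by exp(dt*A), then add the input term dt*B*x.
                state[h][d][n] = (state[h][d][n] * math.exp(dt_hd * A[h][d][n])
                                  + dt_hd * B[g][n] * x[h][d])
                acc += state[h][d][n] * C[g][n]  # read the state out with C
            if D is not None:  # assumed HAS_D: skip connection
                acc += D[h][d] * x[h][d]
            if z is not None:  # assumed HAS_Z: SiLU gate
                acc *= silu(z[h][d])
            out_h.append(acc)
        out.append(out_h)
    return out


# Two heads sharing one group, dim = dstate = 1, starting from a zero state.
state = [[[0.0]], [[0.0]]]
out = ssm_update_ref(state, x=[[2.0], [4.0]], dt=[[0.5], [0.25]],
                     A=[[[-1.0]], [[-1.0]]], B=[[1.0]], C=[[1.0]], D=[[1.0], [0.0]])
print(out)  # [[3.0], [1.0]]
```

The nested loops are only for readability. With a seqlen axis, the same step would be applied token by token, and an intermediate_ssm_states buffer would presumably receive a copy of the state after each token; this sketch does not model that.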
core.ssm.ops.mamba_ssm.selective_state_update(
    state, x, dt, A, B, C,
    D=None, z=None, dt_bias=None, dt_softplus=False,
    state_batch_indices=None, intermediate_ssm_states=None,
)
Arguments:
  - state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
  - x: (batch, dim), (batch, seqlen, dim), (batch, nheads, dim) or (batch, seqlen, nheads, dim)
  - dt: same shape as x
  - A: (dim, dstate) or (nheads, dim, dstate)
  - B: (batch, dstate), (batch, seqlen, dstate), (batch, ngroups, dstate) or (batch, seqlen, ngroups, dstate)
  - C: same shape as B
  - D: (dim,) or (nheads, dim)
  - z: same shape as x
  - dt_bias: (dim,) or (nheads, dim)
  - intermediate_ssm_states: optional buffer of shape (batch, seqlen, nheads, dim, dstate) or (batch, seqlen, dim, dstate)
- Returns:
out, a tensor with the same shape as x.
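The argument list accepts x (and dt, z) and B (and C) at several ranks, and a rank-3 tensor is ambiguous on its own: (batch, seqlen, dim) and (batch, nheads, dim) have the same rank. One unambiguous rule, assumed here rather than taken from the implementation, is to read the head axis off the rank of state (a 4-D state means a head or group axis is present) and to treat one extra axis after batch as seqlen. The hypothetical helper canonical_shape maps any documented x or B shape to (batch, seqlen, heads_or_groups, last); a wrapper using such a rule would reshape out back so that its shape matches x.

```python
def canonical_shape(shape, state_ndim):
    """Hypothetical: map a documented x/dt/z or B/C shape to
    (batch, seqlen, heads_or_groups, last).

    Assumption (not taken from the implementation): a head/group axis is present
    exactly when `state` is 4-D, and one extra axis after batch means seqlen.
    """
    has_heads = state_ndim == 4
    base_ndim = 3 if has_heads else 2  # (batch, [heads,] last)
    if len(shape) == base_ndim:
        has_seq = False
    elif len(shape) == base_ndim + 1:
        has_seq = True
    else:
        raise ValueError(f"unexpected shape {shape} for a {state_ndim}-D state")
    batch, last = shape[0], shape[-1]
    seqlen = shape[1] if has_seq else 1
    heads = shape[-2] if has_heads else 1
    return (batch, seqlen, heads, last)


print(canonical_shape((2, 8), state_ndim=3))        # (2, 1, 1, 8)
print(canonical_shape((2, 5, 8), state_ndim=3))     # (2, 5, 1, 8)
print(canonical_shape((2, 4, 8), state_ndim=4))     # (2, 1, 4, 8)
print(canonical_shape((2, 5, 4, 8), state_ndim=4))  # (2, 5, 4, 8)
```

The same rule covers B and C, with ngroups in the third slot and dstate last.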