core.ssm.ops.causal_conv1d_triton
Module Contents
Functions
| Function | Description |
| --- | --- |
| causal_conv1d_update_kernel | Triton implementation of causal_conv1d_update (kernel). |
| causal_conv1d_update | Triton implementation of causal_conv1d_update (entrypoint). |
API
core.ssm.ops.causal_conv1d_triton.causal_conv1d_update_kernel(
    x_ptr,
    x_b_stride,
    x_s_stride,
    x_c_stride,
    conv_state_ptr,
    conv_state_b_stride,
    conv_state_c_stride,
    conv_state_l_stride,
    int_state_ptr,
    int_state_b_stride,
    int_state_s_stride,
    int_state_c_stride,
    int_state_l_stride,
    weight_ptr,
    weight_c_stride,
    weight_width_stride,
    bias_ptr,
    bias_stride,
    out_ptr,
    out_b_stride,
    out_s_stride,
    out_c_stride,
    conv_state_indices_ptr,
    batch,
    seq_len,
    dim,
    state_len,
    WIDTH: triton.language.constexpr,
    BLOCK_DIM: triton.language.constexpr,
    HAS_BIAS: triton.language.constexpr,
    HAS_STATE_INDICES: triton.language.constexpr,
    HAS_INT_STATE: triton.language.constexpr,
    SILU_ACTIVATION: triton.language.constexpr,
)
Triton implementation of causal_conv1d_update (kernel).
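The kernel takes raw pointers plus one stride per logical dimension (for example `x_b_stride`, `x_s_stride`, `x_c_stride`) instead of assuming contiguous tensors. A minimal sketch of the addressing this implies, assuming (as is usual for Triton launches, but not stated above) that the strides are element strides like those returned by `torch.Tensor.stride()`, and that `x` is indexed as `(batch, seq, channel)` following the parameter names. NumPy stands in for PyTorch, and `element_offset` / `element_strides` are hypothetical helpers; NumPy reports strides in bytes, so they are divided by `itemsize`.

```python
import numpy as np


def element_offset(b, s, c, b_stride, s_stride, c_stride):
    """Flat element offset of logical element [b, s, c] given per-dimension element strides."""
    return b * b_stride + s * s_stride + c * c_stride


def element_strides(arr):
    """NumPy strides are in bytes; convert them to element strides (torch.Tensor.stride() style)."""
    return tuple(st // arr.itemsize for st in arr.strides)


# One flat buffer viewed two ways: a contiguous (batch, seq, channel) layout and a
# non-contiguous one obtained by transposing a (batch, channel, seq) array.
buf = np.arange(24)
x_contig = buf.reshape(2, 3, 4)                      # element strides (12, 4, 1)
x_strided = buf.reshape(2, 4, 3).transpose(0, 2, 1)  # shape (2, 3, 4), element strides (12, 1, 3)

for x in (x_contig, x_strided):
    b_st, s_st, c_st = element_strides(x)
    for b in range(x.shape[0]):
        for s in range(x.shape[1]):
            for c in range(x.shape[2]):
                # Pointer arithmetic on the shared base buffer reaches the same element as indexing.
                assert buf[element_offset(b, s, c, b_st, s_st, c_st)] == x[b, s, c]

print(element_strides(x_contig), element_strides(x_strided))  # prints (12, 4, 1) (12, 1, 3)
```

Passing strides rather than requiring `.contiguous()` lets a kernel of this shape read transposed or sliced views without a copy; whether the entrypoint relies on that is not stated in this reference.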
core.ssm.ops.causal_conv1d_triton.causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    conv_state_indices: torch.Tensor | None,
    intermediate_conv_states: torch.Tensor | None = None,
)
Triton implementation of causal_conv1d_update (entrypoint).
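When validating outputs of the Triton path, a slow reference can help. The sketch below is a hypothetical NumPy version (`causal_conv1d_update_ref` is not part of this module) assuming the usual causal-conv1d update semantics: each channel convolves its cached history plus the new input with its own filter, the cached window slides forward in place, and bias and SiLU are optional. The shapes are assumptions inferred from the kernel's stride names: `x` as `(batch, seq_len, dim)`, `conv_state` as `(num_states, dim, state_len)`, `conv_state_indices` selecting a state row per batch entry, and `intermediate_conv_states` receiving the state after every token.

```python
import numpy as np


def causal_conv1d_update_ref(x, conv_state, weight, bias=None, silu_activation=False,
                             conv_state_indices=None, intermediate_conv_states=None):
    """Hypothetical slow reference. Assumed shapes (not confirmed by the docs above):

    x:                        (batch, seq_len, dim)
    conv_state:               (num_states, dim, state_len), state_len >= width - 1; updated in place
    weight:                   (dim, width)
    bias:                     (dim,) or None
    conv_state_indices:       (batch,) ints picking a conv_state row per batch entry, or None
    intermediate_conv_states: (batch, seq_len, dim, state_len) or None; filled in place
    Returns an array shaped like x.
    """
    batch, seq_len, dim = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert state_len >= width - 1, "state must hold at least width - 1 past inputs"
    out = np.empty(x.shape, dtype=np.result_type(x, weight))
    for b in range(batch):
        idx = b if conv_state_indices is None else int(conv_state_indices[b])
        state = conv_state[idx]  # view of shape (dim, state_len), so writes hit conv_state
        for t in range(seq_len):
            # Append the newest input column to the cached history.
            history = np.concatenate([state, x[b, t][:, None]], axis=1)
            # Depthwise causal conv: each channel dots its last `width` inputs with its filter.
            y = (history[:, -width:] * weight).sum(axis=1)
            if bias is not None:
                y = y + bias
            if silu_activation:
                y = y / (1.0 + np.exp(-y))  # SiLU(y) = y * sigmoid(y)
            out[b, t] = y
            # Slide the window: keep only the most recent state_len inputs.
            state[...] = history[:, history.shape[1] - state_len:]
            if intermediate_conv_states is not None:
                intermediate_conv_states[b, t] = state
    return out


rng = np.random.default_rng(0)
dim, width, seq_len = 3, 4, 5
x = rng.standard_normal((1, seq_len, dim))
weight = rng.standard_normal((dim, width))

# Streaming token by token from an all-zero state...
state = np.zeros((1, dim, width - 1))
streamed = causal_conv1d_update_ref(x, state, weight)

# ...matches a causal conv over the whole sequence with width - 1 zeros of left padding.
for c in range(dim):
    padded = np.concatenate([np.zeros(width - 1), x[0, :, c]])
    assert np.allclose(streamed[0, :, c], np.correlate(padded, weight[c], mode="valid"))
```

Because the state carries the last `state_len` inputs forward, splitting a sequence across several calls should give the same output as one call; that property is a convenient check for a kernel as well.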