core.ssm.ops.causal_conv1d_triton

Module Contents

Functions

causal_conv1d_update_kernel

Triton implementation of causal_conv1d_update (kernel).

causal_conv1d_update

Triton implementation of causal_conv1d_update (entrypoint).

API

core.ssm.ops.causal_conv1d_triton.causal_conv1d_update_kernel(
x_ptr,
x_b_stride,
x_s_stride,
x_c_stride,
conv_state_ptr,
conv_state_b_stride,
conv_state_c_stride,
conv_state_l_stride,
int_state_ptr,
int_state_b_stride,
int_state_s_stride,
int_state_c_stride,
int_state_l_stride,
weight_ptr,
weight_c_stride,
weight_width_stride,
bias_ptr,
bias_stride,
out_ptr,
out_b_stride,
out_s_stride,
out_c_stride,
conv_state_indices_ptr,
batch,
seq_len,
dim,
state_len,
WIDTH: triton.language.constexpr,
BLOCK_DIM: triton.language.constexpr,
HAS_BIAS: triton.language.constexpr,
HAS_STATE_INDICES: triton.language.constexpr,
HAS_INT_STATE: triton.language.constexpr,
SILU_ACTIVATION: triton.language.constexpr,
)

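Triton kernel implementing causal_conv1d_update. Each tensor is passed as a base pointer (`*_ptr`) together with its per-axis strides (`*_stride`). The arguments annotated `triton.language.constexpr` (WIDTH, BLOCK_DIM, HAS_BIAS, HAS_STATE_INDICES, HAS_INT_STATE, SILU_ACTIVATION) are compile-time constants.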
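The pointer-plus-stride parameters follow the usual pattern for addressing a tensor element inside a flat buffer. The sketch below illustrates that pattern in plain Python. It assumes the suffixes `b`, `s`, and `c` stand for the batch, sequence, and channel axes, which is a reading of the parameter names rather than something the source states. `flat_offset` and `contiguous_strides` are hypothetical helpers, not part of this module.

```python
def contiguous_strides(shape):
    """Element strides of a row-major (C-contiguous) tensor of the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def flat_offset(b, s, c, b_stride, s_stride, c_stride):
    """Offset (in elements) of entry [b, s, c] relative to the base pointer."""
    return b * b_stride + s * s_stride + c * c_stride


# Assumed layout for illustration only: (batch, seq_len, dim) = (2, 3, 4).
x_b_stride, x_s_stride, x_c_stride = contiguous_strides((2, 3, 4))
offset = flat_offset(1, 2, 3, x_b_stride, x_s_stride, x_c_stride)
```

Passing strides explicitly lets a kernel read non-contiguous views (for example, transposed tensors) without copying them first.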

core.ssm.ops.causal_conv1d_triton.causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None,
silu_activation: bool,
conv_state_indices: torch.Tensor | None,
intermediate_conv_states: torch.Tensor | None = None,
) → torch.Tensor

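Python entrypoint for the Triton implementation of causal_conv1d_update. It takes `torch.Tensor` arguments and returns a `torch.Tensor`. `bias` and `conv_state_indices` may be `None`, and `intermediate_conv_states` is optional (default `None`).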
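For orientation, the sketch below shows what a single-token causal conv1d update conventionally computes: append the new input to a rolling per-channel state, take a depthwise dot product over the last `width` entries, then optionally apply SiLU. It is a dependency-free illustration in plain Python lists, not this module's implementation. The name `causal_conv1d_update_ref`, the list-based shapes, and the requirement that `state_len >= width - 1` are assumptions. It does not model `conv_state_indices` or `intermediate_conv_states`.

```python
import math


def silu(v):
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def causal_conv1d_update_ref(x, conv_state, weight, bias=None, silu_activation=False):
    """Hypothetical single-step reference.

    x:          list of length dim (one new value per channel)
    conv_state: list of dim lists, each of length state_len (updated in place)
    weight:     list of dim lists, each of length width
    bias:       list of length dim, or None
    """
    out = []
    for c, x_c in enumerate(x):
        state_len = len(conv_state[c])
        width = len(weight[c])
        # Concatenate the stored history with the new input.
        extended = conv_state[c] + [x_c]
        # Depthwise convolution over the most recent `width` values.
        window = extended[-width:]
        acc = bias[c] if bias is not None else 0.0
        for w, v in zip(weight[c], window):
            acc += w * v
        # Keep only the newest `state_len` values as the updated state.
        conv_state[c][:] = extended[-state_len:]
        out.append(silu(acc) if silu_activation else acc)
    return out


state = [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
weights = [[1.0, 1.0, 1.0, 1.0], [0.5, 0.5, 0.5, 0.5]]
y = causal_conv1d_update_ref([1.0, 2.0], state, weights)
```

Updating the state in place mirrors how decoding loops typically reuse one buffer across steps instead of allocating a new one per token.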