core.ssm.ops.causal_conv1d_varlen#

Triton varlen depthwise causal 1D convolution with per-sequence initial states and fused SiLU.

Covers the case that causal_conv1d_fn cannot: packed variable-length sequences where both seq_idx and initial_states must be supplied simultaneously.

Module Contents#

Functions#

_causal_conv1d_varlen_kernel

Depthwise causal conv1d over packed varlen sequences with initial states and SiLU.

causal_conv1d_varlen_fn

Depthwise causal 1D convolution over packed variable-length sequences.

_causal_conv1d_varlen_simple

Simple PyTorch implementation of varlen causal conv1d with initial states and SiLU.

API#

core.ssm.ops.causal_conv1d_varlen._causal_conv1d_varlen_kernel(
x_ptr,
weight_ptr,
bias_ptr,
seq_idx_ptr,
seq_start_ptr,
initial_states_ptr,
out_ptr,
total_tokens,
conv_dim: triton.language.constexpr,
initial_states_stride_req,
initial_states_stride_dim,
WIDTH: triton.language.constexpr,
BLOCK_T: triton.language.constexpr,
BLOCK_C: triton.language.constexpr,
HAS_INITIAL_STATES: triton.language.constexpr,
)#

Depthwise causal conv1d over packed varlen sequences with initial states and SiLU.

Fully vectorized over BLOCK_T tokens x BLOCK_C channels per thread block.

core.ssm.ops.causal_conv1d_varlen.causal_conv1d_varlen_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
cu_seqlens: torch.Tensor,
initial_states: torch.Tensor = None,
activation: str = 'silu',
precomputed_seq_idx: torch.Tensor = None,
precomputed_seq_start: torch.Tensor = None,
) -> torch.Tensor#

Depthwise causal 1D convolution over packed variable-length sequences.

Supports both cu_seqlens (sequence boundaries) and initial_states simultaneously, unlike causal_conv1d_fn which requires mutual exclusivity between seq_idx and initial_states.

Parameters:
  • x – Input tensor of shape (total_tokens, conv_dim), channels-last packed.

  • weight – Convolution weights of shape (conv_dim, d_conv).

  • bias – Bias of shape (conv_dim,).

  • cu_seqlens – Cumulative sequence lengths of shape (num_requests + 1,), int32.

  • initial_states – Per-request initial conv states of shape (num_requests, conv_dim, d_conv - 1). If None, uses zeros.

  • activation – Activation function; must be 'silu'.

  • precomputed_seq_idx – Precomputed per-token request ID of shape (total_tokens,). If provided, skips the repeat_interleave step (CUDA graph compatible). Padding tokens should use 0 as a sentinel.

  • precomputed_seq_start – Precomputed per-token request start position of shape (total_tokens,). Must be provided together with precomputed_seq_idx.

Returns:

Output tensor of shape (total_tokens, conv_dim).
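When precomputed_seq_idx and precomputed_seq_start are not supplied, they can be derived from cu_seqlens with repeat_interleave. A minimal sketch of that mapping (the helper name build_seq_maps is illustrative, not part of this module):

```python
import torch

def build_seq_maps(cu_seqlens: torch.Tensor):
    """Derive per-token request IDs and start offsets from cumulative
    sequence lengths, matching the shapes that precomputed_seq_idx and
    precomputed_seq_start expect (one entry per packed token)."""
    # Per-request sequence lengths, e.g. [0, 3, 5] -> [3, 2].
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
    req_ids = torch.arange(lengths.numel(), dtype=torch.int32)
    # Request ID repeated once per token of that request.
    seq_idx = torch.repeat_interleave(req_ids, lengths)
    # Start position of the owning request, repeated per token.
    seq_start = torch.repeat_interleave(cu_seqlens[:-1].to(torch.int32), lengths)
    return seq_idx, seq_start
```

For cu_seqlens = [0, 3, 5] this yields seq_idx = [0, 0, 0, 1, 1] and seq_start = [0, 0, 0, 3, 3]; precomputing these once outside the hot path avoids re-running repeat_interleave under CUDA graph capture.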

core.ssm.ops.causal_conv1d_varlen._causal_conv1d_varlen_simple(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
cu_seqlens: torch.Tensor,
initial_states: torch.Tensor,
out: torch.Tensor,
) -> None#

Simple PyTorch implementation of varlen causal conv1d with initial states and SiLU.

This is a reference implementation used for testing; it processes each request and token sequentially.
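The semantics the reference implementation checks can be sketched in pure PyTorch: for each request, prepend its (d_conv - 1)-wide initial state to the packed slice, run a depthwise (groups = conv_dim) causal convolution, and apply SiLU. This is a hedged sketch of those semantics, not the module's actual code:

```python
import torch
import torch.nn.functional as F

def causal_conv1d_varlen_ref(x, weight, bias, cu_seqlens, initial_states=None):
    """Sketch of varlen depthwise causal conv1d + SiLU.

    x: (total_tokens, conv_dim), weight: (conv_dim, d_conv),
    bias: (conv_dim,), cu_seqlens: (num_requests + 1,),
    initial_states: (num_requests, conv_dim, d_conv - 1) or None.
    """
    conv_dim, d_conv = weight.shape
    out = torch.empty_like(x)
    for r in range(cu_seqlens.numel() - 1):
        s, e = int(cu_seqlens[r]), int(cu_seqlens[r + 1])
        seq = x[s:e].T  # (conv_dim, seqlen)
        hist = (initial_states[r] if initial_states is not None
                else seq.new_zeros(conv_dim, d_conv - 1))
        # Prepend per-request state so the first token sees its history.
        padded = torch.cat([hist, seq], dim=1)
        # Depthwise conv: groups == conv_dim, one width-d_conv filter per channel.
        y = F.conv1d(padded.unsqueeze(0), weight.unsqueeze(1),
                     bias=bias, groups=conv_dim).squeeze(0)
        out[s:e] = F.silu(y).T
    return out
```

With zero initial states the first token of each request only sees zero-padding, so sequence boundaries never leak into one another, which is exactly the property the seq_idx mechanism enforces in the fused kernel.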