core.ssm.ops.causal_conv1d_varlen#
Triton varlen depthwise causal 1D convolution with per-sequence initial states and fused SiLU.
Supports packed variable-length sequences with per-sequence initial states, a combination `causal_conv1d_fn` cannot handle because it does not accept both `seq_idx` and `initial_states` simultaneously.
Module Contents#
Functions#
| `_causal_conv1d_varlen_kernel` | Depthwise causal conv1d over packed varlen sequences with initial states and SiLU. |
| `causal_conv1d_varlen_fn` | Depthwise causal 1D convolution over packed variable-length sequences. |
| `_causal_conv1d_varlen_simple` | Simple PyTorch implementation of varlen causal conv1d with initial states and SiLU. |
API#
- core.ssm.ops.causal_conv1d_varlen._causal_conv1d_varlen_kernel(
- x_ptr,
- weight_ptr,
- bias_ptr,
- seq_idx_ptr,
- seq_start_ptr,
- initial_states_ptr,
- out_ptr,
- total_tokens,
- conv_dim: triton.language.constexpr,
- initial_states_stride_req,
- initial_states_stride_dim,
- WIDTH: triton.language.constexpr,
- BLOCK_T: triton.language.constexpr,
- BLOCK_C: triton.language.constexpr,
- HAS_INITIAL_STATES: triton.language.constexpr,
- )
Depthwise causal conv1d over packed varlen sequences with initial states and SiLU.
Fully vectorized over BLOCK_T tokens x BLOCK_C channels per thread block.
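The causal gather this vectorization performs can be illustrated in plain Python (an illustrative helper, not the Triton source): for tap `k`, token `t` of a sequence starting at `seq_start` reads `x[t - k]` when `t - k >= seq_start`, and otherwise falls back to the cached initial state.

```python
def causal_tap_source(t, k, seq_start, width):
    """Where tap k of token t reads from (illustrative sketch, not the kernel).

    t: global token index; k: tap offset (0 = current token);
    seq_start: global index of the first token of t's sequence;
    width: kernel width (d_conv).
    """
    src = t - k
    if src >= seq_start:
        return ("x", src)  # regular in-sequence token
    # Before the sequence start: read the cached initial state, which
    # holds the last (width - 1) inputs, index 0 = oldest.
    return ("init", (width - 1) + (src - seq_start))

# Token 0 of a sequence starting at global index 5, width 4:
print(causal_tap_source(5, 0, 5, 4))  # ('x', 5)
print(causal_tap_source(5, 1, 5, 4))  # ('init', 2)
print(causal_tap_source(5, 3, 5, 4))  # ('init', 0)
```

In the kernel this per-tap branch corresponds to a masked load over the BLOCK_T x BLOCK_C tile rather than a scalar `if`.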
- core.ssm.ops.causal_conv1d_varlen.causal_conv1d_varlen_fn(
- x: torch.Tensor,
- weight: torch.Tensor,
- bias: torch.Tensor,
- cu_seqlens: torch.Tensor,
- initial_states: torch.Tensor = None,
- activation: str = 'silu',
- precomputed_seq_idx: torch.Tensor = None,
- precomputed_seq_start: torch.Tensor = None,
- )
Depthwise causal 1D convolution over packed variable-length sequences.
Supports both `cu_seqlens` (sequence boundaries) and `initial_states` simultaneously, unlike `causal_conv1d_fn`, which requires mutual exclusivity between `seq_idx` and `initial_states`.
- Parameters:
x – Input tensor of shape (total_tokens, conv_dim), channels-last packed.
weight – Convolution weights of shape (conv_dim, d_conv).
bias – Bias of shape (conv_dim,).
cu_seqlens – Cumulative sequence lengths of shape (num_requests + 1,), int32.
initial_states – Per-request initial conv states of shape (num_requests, conv_dim, d_conv - 1). If None, uses zeros.
activation – Activation function, must be “silu”.
precomputed_seq_idx – Precomputed per-token request ID of shape (total_tokens,). If provided, skips the repeat_interleave expansion (CUDA-graph compatible). Padding tokens should use 0 as a sentinel.
precomputed_seq_start – Precomputed per-token request start position of shape (total_tokens,). Must be provided together with precomputed_seq_idx.
- Returns:
Output tensor of shape (total_tokens, conv_dim).
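For reference, the per-token indices that `precomputed_seq_idx` and `precomputed_seq_start` supply can be derived from `cu_seqlens` as follows (a plain-Python sketch using lists instead of int32 tensors; the actual code paths use tensor ops such as `repeat_interleave`):

```python
def derive_seq_indices(cu_seqlens):
    """Expand cumulative sequence lengths into per-token request IDs
    and per-token sequence start positions (pure-Python sketch)."""
    seq_idx, seq_start = [], []
    for req in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[req], cu_seqlens[req + 1]
        for _ in range(start, end):
            seq_idx.append(req)      # which request this token belongs to
            seq_start.append(start)  # where that request begins in the pack
    return seq_idx, seq_start

# Two packed sequences of lengths 3 and 2:
idx, start = derive_seq_indices([0, 3, 5])
print(idx)    # [0, 0, 0, 1, 1]
print(start)  # [0, 0, 0, 3, 3]
```

Precomputing these once per batch is what makes the launch CUDA-graph compatible: the kernel then reads fixed-shape index tensors instead of running a data-dependent expansion.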
- core.ssm.ops.causal_conv1d_varlen._causal_conv1d_varlen_simple(
- x: torch.Tensor,
- weight: torch.Tensor,
- bias: torch.Tensor,
- cu_seqlens: torch.Tensor,
- initial_states: torch.Tensor,
- out: torch.Tensor,
- )
Simple PyTorch implementation of varlen causal conv1d with initial states and SiLU.
This is a reference implementation for testing; it processes each request and token sequentially.
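The computation such a reference implementation performs can be sketched in pure Python (scalar loops and lists instead of tensors; shapes follow the parameter docs above, and the tap convention, where the last weight column multiplies the current token, is an assumption):

```python
import math

def causal_conv1d_varlen_ref(x, weight, bias, cu_seqlens, initial_states):
    """Pure-Python sketch: depthwise causal conv1d over packed varlen
    sequences with per-request initial states and a fused SiLU.

    x:              total_tokens lists of conv_dim channel values
    weight:         conv_dim x d_conv taps; weight[c][-1] multiplies the current token
    bias:           conv_dim values
    cu_seqlens:     cumulative sequence lengths, len == num_requests + 1
    initial_states: num_requests x conv_dim x (d_conv - 1) cached inputs,
                    index -1 = most recent token before the sequence
    """
    conv_dim, width = len(weight), len(weight[0])
    out = [[0.0] * conv_dim for _ in x]
    for req in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[req], cu_seqlens[req + 1]
        for t in range(start, end):
            for c in range(conv_dim):
                acc = bias[c]
                for k in range(width):
                    src = t - (width - 1 - k)  # weight[c][k] pairs with x[src]
                    if src >= start:
                        acc += weight[c][k] * x[src][c]
                    else:
                        # Before the sequence: negative offset indexes the
                        # cached initial state from its newest entry backwards.
                        acc += weight[c][k] * initial_states[req][c][src - start]
                out[t][c] = acc / (1.0 + math.exp(-acc))  # SiLU
    return out
```

With a tap of 1.0 on the previous-token position, the first token of each request draws entirely from `initial_states`, which is exactly the case `causal_conv1d_fn` cannot combine with `seq_idx`.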