core.ssm.ops.ssd_combined#

Module Contents#

Functions#

is_int_pow_2

Return True if n is a positive integer power of 2.

_mamba_chunk_scan_combined_fwd

mamba_chunk_scan_combined_varlen

Takes packed variable-length inputs, e.g. x: (seqlen, nheads, headdim), together with per-chunk metadata such as cu_chunk_seqlens: (nchunks + 1,). Returns the final SSM states, or (final_states, intermediate_states) when intermediate_chunk_indices is provided. The full argument list appears in the API entry below.

Data#

API#

core.ssm.ops.ssd_combined.TRITON_22#

None

core.ssm.ops.ssd_combined.is_int_pow_2(n)#

Return True if n is a positive integer power of 2.
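For illustration, a check of this kind is often written with a bit trick: a positive power of 2 has exactly one bit set, so `n & (n - 1)` is zero. The sketch below is an assumption about how such a helper could look, not the module's actual source; the name `is_int_pow_2_sketch` is hypothetical.

```python
def is_int_pow_2_sketch(n):
    """Return True if n is a positive integer power of 2 (illustrative sketch)."""
    # Reject non-integers first so that floats such as 2.0 do not pass.
    if not isinstance(n, int):
        return False
    # A positive power of 2 has a single set bit, so clearing the lowest
    # set bit with n & (n - 1) must leave zero.
    return n > 0 and (n & (n - 1)) == 0
```

Under this definition 1 counts as a power of 2 (2 to the power 0), while 0 and negative values do not.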

core.ssm.ops.ssd_combined._mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
return_intermediate_states=False,
seq_idx=None,
cu_chunk_seqlens=None,
last_chunk_indices=None,
intermediate_chunk_indices=None,
dt_softplus=False,
dt_limit=(0.0, float('inf')),
state_dtype=None,
)#

core.ssm.ops.ssd_combined.mamba_chunk_scan_combined_varlen(
x,
dt,
A,
B,
C,
chunk_size,
cu_chunk_seqlens,
last_chunk_indices,
seq_idx,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
dt_softplus=False,
dt_limit=(0.0, float('inf')),
return_intermediate_states=False,
intermediate_chunk_indices=None,
state_dtype=None,
)#

Arguments:

x: (seqlen, nheads, headdim)
dt: (seqlen, nheads)
A: (nheads,)
B: (seqlen, ngroups, dstate)
C: (seqlen, ngroups, dstate)
chunk_size: int
cu_chunk_seqlens: (nchunks + 1,)
last_chunk_indices: (batch,)
seq_idx: (nchunks,)
out: (seqlen, nheads, headdim) preallocated output tensor
D: (nheads, headdim) or (nheads,)
z: (seqlen, nheads, headdim)
dt_bias: (nheads,)
initial_states: (batch, nheads, headdim, dstate)
dt_softplus: Whether to apply softplus to dt
intermediate_chunk_indices: (N,) optional int64 tensor of chunk indices at which to extract intermediate SSM states. When provided, returns (final_states, intermediate_states) instead of just final_states.
state_dtype: The data type of the SSM state
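The docstring does not spell out how dt_bias, dt_softplus, and dt_limit interact. The scalar sketch below assumes the order used in the upstream Mamba reference kernels (add the bias, optionally apply softplus, then clamp to dt_limit); this order is an assumption about this module, and `preprocess_dt_sketch` is a hypothetical name, not part of the API.

```python
import math


def preprocess_dt_sketch(dt, dt_bias=None, dt_softplus=False,
                         dt_limit=(0.0, float("inf"))):
    """Illustrative per-element dt preprocessing (assumed order, not verified)."""
    # Assumed step 1: add the per-head bias, if any.
    if dt_bias is not None:
        dt = dt + dt_bias
    # Assumed step 2: optional softplus, written in a numerically stable form:
    # softplus(v) = max(v, 0) + log1p(exp(-|v|)).
    if dt_softplus:
        dt = max(dt, 0.0) + math.log1p(math.exp(-abs(dt)))
    # Assumed step 3: clamp into [dt_limit[0], dt_limit[1]].
    lo, hi = dt_limit
    return min(max(dt, lo), hi)
```

With the default dt_limit of (0.0, inf), the clamp only removes negative step sizes.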

Returns:

varlen_states: (batch, nheads, headdim, dstate), or (varlen_states, intermediate_states) if intermediate_chunk_indices is provided
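The shapes of cu_chunk_seqlens, last_chunk_indices, and seq_idx suggest that a packed batch is split into chunks that never cross a sequence boundary. The sketch below shows one way such metadata could be derived from per-sequence lengths. It assumes every sequence starts a fresh chunk at its first token (no previously processed context), and it uses plain Python lists rather than tensors; `build_chunk_metadata` is a hypothetical helper, not part of this module.

```python
def build_chunk_metadata(seq_lens, chunk_size):
    """Derive per-chunk metadata for a packed batch (illustrative sketch).

    Returns (cu_chunk_seqlens, last_chunk_indices, seq_idx) with lengths
    nchunks + 1, batch, and nchunks respectively.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    cu_chunk_seqlens = [0]   # cumulative token offsets of chunk boundaries
    seq_idx = []             # owning sequence index for each chunk
    last_chunk_indices = []  # index of the final chunk of each sequence
    for i, n in enumerate(seq_lens):
        if n <= 0:
            raise ValueError("every sequence must contain at least one token")
        start = 0
        while start < n:
            # The last chunk of a sequence may be shorter than chunk_size.
            step = min(chunk_size, n - start)
            cu_chunk_seqlens.append(cu_chunk_seqlens[-1] + step)
            seq_idx.append(i)
            start += step
        last_chunk_indices.append(len(seq_idx) - 1)
    return cu_chunk_seqlens, last_chunk_indices, seq_idx


# Two sequences of 5 and 3 tokens packed together, chunk_size = 4.
cu, last, idx = build_chunk_metadata([5, 3], chunk_size=4)
print(cu)    # [0, 4, 5, 8]
print(last)  # [1, 2]
print(idx)   # [0, 0, 1]
```

In real use these lists would be converted to integer tensors on the same device as x before being passed to mamba_chunk_scan_combined_varlen; the required dtypes are not stated in the docstring apart from int64 for intermediate_chunk_indices.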