core.ssm.ops.ssd_combined
Module Contents
Functions
is_int_pow_2 | Return True if n is a positive integer power of 2. |
_mamba_chunk_scan_combined_fwd | |
mamba_chunk_scan_combined_varlen | Chunked SSD scan over packed variable-length sequences; returns the final SSM state of each sequence. |
Data
API
- core.ssm.ops.ssd_combined.TRITON_22
None
- core.ssm.ops.ssd_combined.is_int_pow_2(n)
Return True if n is a positive integer power of 2.
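The page documents only the contract of is_int_pow_2, not its body. A minimal sketch of an equivalent check, assuming the usual single-set-bit trick (the module's actual implementation may differ):

```python
def is_int_pow_2(n) -> bool:
    """Return True if n is a positive integer power of 2 (1, 2, 4, 8, ...)."""
    # Non-integers (e.g. 2.0) are rejected outright.
    if not isinstance(n, int):
        return False
    # A positive power of two has exactly one set bit, so clearing the
    # lowest set bit with n & (n - 1) leaves zero.
    return n > 0 and (n & (n - 1)) == 0
```

For example, 64 passes, while 0, -4, 96 and 2.0 do not.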
- core.ssm.ops.ssd_combined._mamba_chunk_scan_combined_fwd(
      x,
      dt,
      A,
      B,
      C,
      chunk_size,
      out,
      D=None,
      z=None,
      dt_bias=None,
      initial_states=None,
      return_intermediate_states=False,
      seq_idx=None,
      cu_chunk_seqlens=None,
      last_chunk_indices=None,
      intermediate_chunk_indices=None,
      dt_softplus=False,
      dt_limit=(0.0, float('inf')),
      state_dtype=None,
  )
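The dt_bias, dt_softplus and dt_limit parameters (shared with mamba_chunk_scan_combined_varlen) control how the raw dt is conditioned before the scan. A minimal pure-Python sketch of that step for one head, assuming the order "add bias, optionally apply softplus, then clamp to dt_limit" as in the Mamba-2 reference kernels; this page does not confirm the order, and preprocess_dt and its cutoff of 20.0 are hypothetical:

```python
import math


def preprocess_dt(dt, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
    """Hypothetical helper: condition one head's raw dt values before the scan.

    Assumed order: add dt_bias, optionally apply softplus, then clamp to dt_limit.
    """
    lo, hi = dt_limit
    out = []
    for v in dt:
        if dt_bias is not None:
            v = v + dt_bias
        if dt_softplus:
            # softplus(v) = log(1 + exp(v)) differs from v by about exp(-v), so above
            # the (assumed) cutoff it is returned unchanged; this also avoids an
            # OverflowError from math.exp for very large v.
            v = math.log1p(math.exp(v)) if v <= 20.0 else v
        # Clamp into [lo, hi]; the default (0.0, inf) only forbids negative dt.
        out.append(min(max(v, lo), hi))
    return out
```

With the defaults, a negative dt is clamped to 0.0 and a non-negative one passes through unchanged; with dt_softplus=True, an input of 0.0 becomes log 2.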
- core.ssm.ops.ssd_combined.mamba_chunk_scan_combined_varlen(
      x,
      dt,
      A,
      B,
      C,
      chunk_size,
      cu_chunk_seqlens,
      last_chunk_indices,
      seq_idx,
      out,
      D=None,
      z=None,
      dt_bias=None,
      initial_states=None,
      dt_softplus=False,
      dt_limit=(0.0, float('inf')),
      return_intermediate_states=False,
      intermediate_chunk_indices=None,
      state_dtype=None,
  )
Arguments:
  - x: (seqlen, nheads, headdim)
  - dt: (seqlen, nheads)
  - A: (nheads,)
  - B: (seqlen, ngroups, dstate)
  - C: (seqlen, ngroups, dstate)
  - chunk_size: int
  - cu_chunk_seqlens: (nchunks + 1,)
  - last_chunk_indices: (batch,)
  - seq_idx: (nchunks,)
  - out: (seqlen, nheads, headdim), preallocated output tensor
  - D: (nheads, headdim) or (nheads,)
  - z: (seqlen, nheads, headdim)
  - dt_bias: (nheads,)
  - initial_states: (batch, nheads, headdim, dstate)
  - dt_softplus: whether to apply softplus to dt
  - intermediate_chunk_indices: (N,) optional int64 tensor of chunk indices at which to extract intermediate SSM states. When provided, the function returns (final_states, intermediate_states) instead of only final_states.
  - state_dtype: the data type of the SSM state
- Returns:
varlen_states of shape (batch, nheads, headdim, dstate), the final SSM state of each sequence; or the tuple (varlen_states, intermediate_states) if intermediate_chunk_indices is provided.
- Return type:
Tensor, or a tuple of two Tensors
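The varlen entry point takes three pieces of chunk metadata: cu_chunk_seqlens, seq_idx and last_chunk_indices. A minimal pure-Python sketch of how they could be derived from per-sequence lengths, assuming that sequences are packed back to back along seqlen, that each sequence starts on a fresh chunk boundary, and that no chunk spans two sequences (consistent with seq_idx holding one sequence index per chunk). build_chunk_metadata is a hypothetical helper; the module's callers may build this metadata differently.

```python
def build_chunk_metadata(seq_lens, chunk_size):
    """Hypothetical helper: derive (cu_chunk_seqlens, seq_idx, last_chunk_indices)
    for sequences of the given lengths packed back to back."""
    cu_chunk_seqlens = [0]   # token offset of each chunk boundary, nchunks + 1 entries
    seq_idx = []             # owning sequence of each chunk, nchunks entries
    last_chunk_indices = []  # index of each sequence's final chunk, batch entries
    for b, n in enumerate(seq_lens):
        if n <= 0:
            raise ValueError("every sequence must contain at least one token")
        pos = cu_chunk_seqlens[-1]
        end = pos + n
        # Cut this sequence into chunks of at most chunk_size tokens,
        # never letting a chunk run into the next sequence.
        while pos < end:
            pos = min(pos + chunk_size, end)
            cu_chunk_seqlens.append(pos)
            seq_idx.append(b)
        last_chunk_indices.append(len(seq_idx) - 1)
    return cu_chunk_seqlens, seq_idx, last_chunk_indices
```

For sequence lengths [5, 3] and chunk_size=4 this yields cu_chunk_seqlens [0, 4, 5, 8], seq_idx [0, 0, 1] and last_chunk_indices [1, 2], i.e. nchunks + 1 and batch entries, matching the documented shapes. Before calling the function, these lists would be converted to integer tensors on the same device as x; this page does not state which integer dtype the kernels expect.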
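To make the argument shapes and the role of initial_states concrete, here is a naive, token-by-token NumPy reference for a single sequence. It assumes the Mamba-2 SSD recurrence h_t = exp(dt_t A) h_{t-1} + dt_t x_t B_t^T and y_t = h_t C_t + D x_t, optionally gated by SiLU(z). It also assumes that head h reads group h // (nheads // ngroups) and that dt has already been preprocessed. ssd_reference is a hypothetical helper for checking shapes and state chaining. It is not this module's kernel, which is assumed to compute the same quantities chunk by chunk.

```python
import numpy as np


def ssd_reference(x, dt, A, B, C, D=None, z=None, initial_state=None):
    """Hypothetical naive reference for ONE sequence (no chunking).

    x: (seqlen, nheads, headdim); dt: (seqlen, nheads), already preprocessed;
    A: (nheads,); B, C: (seqlen, ngroups, dstate); D: (nheads, headdim) or (nheads,);
    z: (seqlen, nheads, headdim); initial_state: (nheads, headdim, dstate).
    Returns (out, final_state).
    """
    seqlen, nheads, headdim = x.shape
    ngroups, dstate = B.shape[1], B.shape[2]
    heads_per_group = nheads // ngroups  # assumes nheads % ngroups == 0
    if initial_state is None:
        state = np.zeros((nheads, headdim, dstate))
    else:
        state = np.array(initial_state, dtype=np.float64)  # copy; caller's array untouched
    out = np.empty((seqlen, nheads, headdim))
    for t in range(seqlen):
        for h in range(nheads):
            g = h // heads_per_group
            decay = np.exp(dt[t, h] * A[h])
            # h_t = exp(dt * A) * h_{t-1} + dt * outer(x_t, B_t)
            state[h] = decay * state[h] + dt[t, h] * np.outer(x[t, h], B[t, g])
            y = state[h] @ C[t, g]  # (headdim,)
            if D is not None:
                y = y + D[h] * x[t, h]  # skip connection; D[h] is a scalar or (headdim,)
            if z is not None:
                y = y * z[t, h] / (1.0 + np.exp(-z[t, h]))  # SiLU gate: z * sigmoid(z)
            out[t, h] = y
    return out, state
```

Splitting a sequence, running the first part, and passing its final state as initial_state to the second part reproduces a single full run. This is the property that initial_states and per-sequence final states rely on. Running the sketch once per sequence slice of a packed batch and stacking the final states gives an array of the documented varlen_states shape (batch, nheads, headdim, dstate), which can then be compared with the fused kernel's output within floating-point tolerance.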