core.ssm.mamba_hybrid_layer_allocation

Module Contents

Classes

Symbols

Symbols for different layer types.

ParsedHybridPattern

Result of parsing a unified hybrid pattern string.

Functions

parse_hybrid_pattern

Parse a unified hybrid pattern string into main and MTP components.

_validate_pattern

Validate that a pattern contains only valid layer symbols.

_allocate_auto

_allocate_override

_layer_counts_match

allocate_layers

Allocates layers according to the requested distribution of layer types.

get_layer_maps_from_layer_type_list

Returns maps from global layer index to the corresponding layer index for each layer type in [Attention, Mamba, MLP, MoE] given a layer type list.

Data

API

core.ssm.mamba_hybrid_layer_allocation.logger

‘getLogger(…)’

class core.ssm.mamba_hybrid_layer_allocation.Symbols

Symbols for different layer types.

MAMBA

‘M’

ATTENTION

‘*’

MLP

‘-’

MOE

‘E’

MTP_SEPARATOR

‘/’

VALID

None
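Because every layer is encoded as a single character, a pattern string can be read left to right as a layer list. A minimal sketch mirroring the documented symbol values follows; `SymbolsSketch` and `count_layer_types` are hypothetical names, and the contents of `VALID` (rendered above as None) are an assumption.

```python
from collections import Counter


class SymbolsSketch:
    """Hypothetical mirror of the documented Symbols values."""

    MAMBA = "M"
    ATTENTION = "*"
    MLP = "-"
    MOE = "E"
    MTP_SEPARATOR = "/"
    # Assumption: VALID holds the four layer symbols; the separator is not a layer.
    VALID = {MAMBA, ATTENTION, MLP, MOE}


def count_layer_types(pattern: str) -> dict:
    """Count how many layers of each type a pattern string describes (hypothetical helper)."""
    counts = Counter(pattern)
    names = {
        SymbolsSketch.MAMBA: "mamba",
        SymbolsSketch.ATTENTION: "attention",
        SymbolsSketch.MLP: "mlp",
        SymbolsSketch.MOE: "moe",
    }
    # Counter returns 0 for symbols that never appear.
    return {name: counts[symbol] for symbol, name in names.items()}


print(count_layer_types("M*M-ME"))
# {'mamba': 3, 'attention': 1, 'mlp': 1, 'moe': 1}
```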

class core.ssm.mamba_hybrid_layer_allocation.ParsedHybridPattern

Result of parsing a unified hybrid pattern string.

A unified pattern encodes both the main decoder pattern and the MTP (multi-token prediction) pattern in a single string, using “/” as a separator.

Format: “<main_pattern>/<mtp_pattern>/<mtp_pattern>/…”

Examples

  • “MM” -> main=“MM”, mtp=None, depths=0 (no MTP)

  • “MM/MM/MM” -> main=“MM”, mtp=“MM”, depths=2

  • “MMMM/*M/*M/*M” -> main=“MMMM”, mtp=“*M”, depths=3

The “/” symbol introduces MTP patterns. Each repeated pattern after the main decoder represents one MTP prediction depth.
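The format can be illustrated by going in the opposite direction and assembling a unified string from its parts. `build_unified_pattern` is a hypothetical helper, not part of this module; it reproduces the three examples above.

```python
from typing import Optional


def build_unified_pattern(
    main_pattern: str, mtp_pattern: Optional[str], mtp_num_depths: int
) -> str:
    """Assemble '<main>/<mtp>/<mtp>/...' with one MTP segment per prediction depth."""
    if mtp_pattern is None or mtp_num_depths == 0:
        # No MTP: the unified string is just the main decoder pattern.
        return main_pattern
    return "/".join([main_pattern] + [mtp_pattern] * mtp_num_depths)


print(build_unified_pattern("MMMM", "*M", 3))  # MMMM/*M/*M/*M
```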

Attributes:
  • main_pattern – The main decoder layer pattern (e.g., “MM”)

  • mtp_pattern – The MTP layer pattern per depth (e.g., “MM”), or None if no MTP

  • mtp_num_depths – Number of MTP prediction depths (0 if no MTP)

main_pattern: Optional[str]

None

mtp_pattern: Optional[str]

None

mtp_num_depths: int

None

core.ssm.mamba_hybrid_layer_allocation.parse_hybrid_pattern(
pattern: Optional[str],
) → core.ssm.mamba_hybrid_layer_allocation.ParsedHybridPattern

Parse a unified hybrid pattern string into main and MTP components.

The pattern uses “/” as a separator between the main decoder pattern and MTP patterns. Each MTP pattern after the separator represents one prediction depth.

Format: “<main_pattern>/<mtp_pattern>/<mtp_pattern>/…”

Parameters:

pattern – Unified pattern string, e.g., “MM/MM/MM” or just “MM”

Returns:

ParsedHybridPattern with main_pattern, mtp_pattern, and mtp_num_depths

Raises:
  • ValueError – If MTP patterns are inconsistent (all must be identical)

  • ValueError – If pattern contains invalid layer symbols

Examples

>>> parse_hybrid_pattern("MM")
ParsedHybridPattern(main_pattern='MM', mtp_pattern=None, mtp_num_depths=0)

>>> parse_hybrid_pattern("MM/MM/MM")
ParsedHybridPattern(main_pattern='MM', mtp_pattern='MM', mtp_num_depths=2)

>>> parse_hybrid_pattern("MMMM/*M/*M/*M")
ParsedHybridPattern(main_pattern='MMMM', mtp_pattern='*M', mtp_num_depths=3)
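One way to implement the parsing described above is sketched below, assuming the function splits on “/” and requires every MTP segment to match. `ParsedPatternSketch` and `parse_pattern_sketch` are hypothetical stand-ins, the handling of a None pattern is a guess, and the invalid-symbol check listed under Raises is left to a separate validation step.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ParsedPatternSketch:
    main_pattern: Optional[str]
    mtp_pattern: Optional[str]
    mtp_num_depths: int


def parse_pattern_sketch(pattern: Optional[str]) -> ParsedPatternSketch:
    # Assumption: a missing pattern parses to an empty result with no MTP.
    if pattern is None:
        return ParsedPatternSketch(None, None, 0)
    # Everything before the first "/" is the main decoder pattern;
    # each remaining segment is one MTP prediction depth.
    main, *mtp_segments = pattern.split("/")
    if not mtp_segments:
        return ParsedPatternSketch(main, None, 0)
    # All MTP segments must be identical, mirroring the documented ValueError.
    if len(set(mtp_segments)) != 1:
        raise ValueError(f"All MTP patterns must be identical, got {mtp_segments}")
    return ParsedPatternSketch(main, mtp_segments[0], len(mtp_segments))


print(parse_pattern_sketch("MMMM/*M/*M/*M"))
# ParsedPatternSketch(main_pattern='MMMM', mtp_pattern='*M', mtp_num_depths=3)
```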

core.ssm.mamba_hybrid_layer_allocation._validate_pattern(pattern: str, pattern_name: str) → None

Validate that a pattern contains only valid layer symbols.

Parameters:
  • pattern – Layer pattern string to validate

  • pattern_name – Name of pattern for error messages (e.g., “main” or “MTP”)

Raises:

ValueError – If pattern contains invalid symbols
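A check of this kind can be sketched as a set difference against the allowed symbols. `validate_pattern_sketch` is hypothetical, and `VALID_LAYER_SYMBOLS` assumes the valid set is exactly the four layer symbols; since “/” is not among them, the sketch is meant for the main and MTP segments after splitting, not for the unified string.

```python
VALID_LAYER_SYMBOLS = frozenset("M*-E")  # assumed contents of Symbols.VALID


def validate_pattern_sketch(pattern: str, pattern_name: str) -> None:
    """Raise ValueError naming every character that is not a layer symbol."""
    invalid = sorted({ch for ch in pattern if ch not in VALID_LAYER_SYMBOLS})
    if invalid:
        raise ValueError(
            f"Invalid symbols {invalid} in {pattern_name} pattern {pattern!r}; "
            f"valid symbols are {sorted(VALID_LAYER_SYMBOLS)}"
        )


validate_pattern_sketch("M*-E", "main")  # passes silently
```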

core.ssm.mamba_hybrid_layer_allocation._allocate_auto(
total_layers_count: int,
target_attention_ratio: float,
target_mlp_ratio: float,
) → list

core.ssm.mamba_hybrid_layer_allocation._allocate_override(
total_layers_count: int,
override_pattern: str,
) → list

core.ssm.mamba_hybrid_layer_allocation._layer_counts_match(a: list, b: list) → bool

core.ssm.mamba_hybrid_layer_allocation.allocate_layers(
total_layers_count: int,
target_attention_ratio: float,
target_mlp_ratio: float,
override_pattern: str = None,
silent: bool = False,
) → list

Allocates layers according to the requested distribution of layer types.
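The placement rule is not documented here, so the following is one plausible scheme, not a description of `allocate_layers`: with no override, it rounds each target ratio to a layer count and spaces attention layers, then MLP layers, evenly among the Mamba layers; with an override, it checks the length and returns the pattern as a list. `allocate_layers_sketch` is hypothetical, and it ignores MoE layers and the `silent` flag.

```python
from typing import List, Optional


def allocate_layers_sketch(
    total_layers_count: int,
    target_attention_ratio: float,
    target_mlp_ratio: float,
    override_pattern: Optional[str] = None,
) -> List[str]:
    """Hypothetical allocator returning one symbol per layer ('M', '*', '-')."""
    if override_pattern is not None:
        # Assumption: an override must describe exactly total_layers_count layers.
        if len(override_pattern) != total_layers_count:
            raise ValueError("override pattern length must equal total_layers_count")
        return list(override_pattern)

    n_attention = round(total_layers_count * target_attention_ratio)
    n_mlp = round(total_layers_count * target_mlp_ratio)
    if n_attention + n_mlp > total_layers_count:
        raise ValueError("attention and MLP ratios exceed the total layer count")

    layers = ["M"] * total_layers_count
    # Place each attention layer at the midpoint of one of n_attention equal chunks.
    for i in range(n_attention):
        layers[int((i + 0.5) * total_layers_count / n_attention)] = "*"
    # Spread MLP layers the same way over the slots still holding Mamba layers.
    mamba_slots = [i for i, s in enumerate(layers) if s == "M"]
    for i in range(n_mlp):
        layers[mamba_slots[int((i + 0.5) * len(mamba_slots) / n_mlp)]] = "-"
    return layers


print("".join(allocate_layers_sketch(8, 0.25, 0.25)))  # M-*MM-*M
```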

core.ssm.mamba_hybrid_layer_allocation.get_layer_maps_from_layer_type_list(
layer_type_list: List[str],
) → Tuple[Dict[int, int], Dict[int, int], Dict[int, int]]

Given a layer type list, returns, for each layer type in [Attention, Mamba, MLP, MoE], a map from global layer index to that layer's index among the layers of the same type.
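The mapping can be computed in a single pass that tracks how many layers of each type have been seen so far. A minimal sketch; `layer_maps_sketch` is hypothetical. The signature above annotates a tuple of three dicts while the description names four layer types, so returning one dict per listed type, in the listed order, is an assumption.

```python
from typing import Dict, List


def layer_maps_sketch(layer_type_list: List[str]) -> List[Dict[int, int]]:
    """For each of [Attention, Mamba, MLP, MoE], map global index -> index within that type."""
    order = ["*", "M", "-", "E"]  # Attention, Mamba, MLP, MoE symbols
    maps: Dict[str, Dict[int, int]] = {symbol: {} for symbol in order}
    for global_idx, symbol in enumerate(layer_type_list):
        type_map = maps[symbol]
        # The next index within this type equals the number of such layers seen so far.
        type_map[global_idx] = len(type_map)
    return [maps[symbol] for symbol in order]


attention_map, mamba_map, mlp_map, moe_map = layer_maps_sketch(list("M*M-M*"))
print(attention_map)  # {1: 0, 5: 1}
print(mamba_map)      # {0: 0, 2: 1, 4: 2}
print(mlp_map)        # {3: 0}
```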