core.post_training.modelopt.mamba.model_specs#
Module Contents#
Functions#
| get_mamba_stack_modelopt_spec | Get the Mamba stack spec for ModelOpt PTQ and TensorRT-LLM export. |
| _get_mamba_stack_local_spec | Get the Mamba stack spec with local (non-TE) modules. |
API#
- core.post_training.modelopt.mamba.model_specs.get_mamba_stack_modelopt_spec(
- local_core_attention: bool = False,
- remap_te_layernorm: bool = False,
- use_default_te_spec: bool = False,
- )
Get the Mamba stack spec for ModelOpt PTQ and TensorRT-LLM export.
When use_default_te_spec=False (the default), this returns the native local spec, except that TENorm from Transformer-Engine is used for the layernorm implementation (Apex's FusedLayerNorm no longer supports the RMSNorm variant required by Llama). The remap_te_layernorm flag adds sharded state_dict key remapping so that checkpoints remain compatible with TE-based specs when saving and loading.
When use_default_te_spec=True, this returns the standard mamba_stack_spec from mamba_layer_specs.py, which uses full TE modules (TELayerNormColumnParallelLinear, TERowParallelLinear, TEDotProductAttention, TENorm, and moe_grouped_gemm=True).
- Parameters:
local_core_attention – whether to use the local DotProductAttention implementation (only applies when use_default_te_spec=False)
remap_te_layernorm – whether to perform sharded state_dict prefix remapping on layernorm keys (only applies when use_default_te_spec=False)
use_default_te_spec – whether to return the default Transformer-Engine spec instead of the local spec
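For orientation, a minimal call sketch is shown below. The megatron. import prefix is an assumption inferred from the module path in the heading, and the chosen flag values are illustrative only.

```python
# Minimal usage sketch; the `megatron.` prefix is an assumption about
# the package layout, and the flag values are illustrative.
from megatron.core.post_training.modelopt.mamba.model_specs import (
    get_mamba_stack_modelopt_spec,
)

# Local (non-TE) spec with TE-compatible layernorm key remapping,
# suitable for ModelOpt PTQ followed by TensorRT-LLM export.
mamba_stack_spec = get_mamba_stack_modelopt_spec(
    local_core_attention=False,
    remap_te_layernorm=True,
    use_default_te_spec=False,
)
```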
- core.post_training.modelopt.mamba.model_specs._get_mamba_stack_local_spec(
- local_core_attention: bool = False,
- remap_te_layernorm: bool = False,
- )
Get the Mamba stack spec with local (non-TE) modules.
This is essentially the native local spec, except that the layernorm implementation uses TENorm from Transformer-Engine.
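Taken together, the two entries imply the following dispatch between the default TE spec and the local spec. This is a hedged sketch of that relationship (the mamba_layer_specs import path is an assumption), not the module's actual source.

```python
# Sketch of the dispatch implied by the docstrings above -- an
# assumption about the implementation, not code copied from the source.
def get_mamba_stack_modelopt_spec(
    local_core_attention: bool = False,
    remap_te_layernorm: bool = False,
    use_default_te_spec: bool = False,
):
    if use_default_te_spec:
        # Standard full-TE spec; the import path is an assumption.
        from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
        return mamba_stack_spec
    # Native local spec with TENorm; local_core_attention and
    # remap_te_layernorm only take effect on this branch.
    return _get_mamba_stack_local_spec(
        local_core_attention=local_core_attention,
        remap_te_layernorm=remap_te_layernorm,
    )
```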