nemo_automodel.components.models.common.gated_delta_net_fp32

Shared checkpoint helpers for fp32 GatedDeltaNet (GDN) params.

GDN layers carry intrinsically-fp32 bare parameters (A_log and dt_bias) that feed the decay gate g = -exp(A_log) * softplus(a + dt_bias). Under FSDP2 mixed precision with fp32 master weights, the bulk of a model computes in bf16 (param_dtype=bf16), but these two parameters must stay in fp32. Because A_log is exponentiated, bf16 rounding of A_log becomes a proportional error on the decay rate, and the recurrence compounds that error across the sequence.
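To make the precision argument concrete, here is a minimal numeric sketch using only the Python standard library. It emulates bf16 rounding by keeping the top 16 bits of the float32 encoding, then compares the gate computed from the exact and the rounded A_log. The input values, the constant-gate simplification, and the helper names (round_to_bf16, decay_gate) are assumptions made for illustration; none of them come from the module.

```python
import math
import struct


def round_to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    bfloat16 keeps the top 16 bits of the float32 encoding. This helper is
    only meant for finite, normal-range inputs.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias, ties go to even
    bits &= 0xFFFF0000                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def decay_gate(a_log: float, a: float, dt_bias: float) -> float:
    """g = -exp(A_log) * softplus(a + dt_bias), as in the text above."""
    return -math.exp(a_log) * softplus(a + dt_bias)


# Hypothetical values, chosen only for illustration.
a_log, a, dt_bias = -2.71, 0.3, 0.1
a_log_bf16 = round_to_bf16(a_log)

g_fp32 = decay_gate(a_log, a, dt_bias)
g_bf16 = decay_gate(a_log_bf16, a, dt_bias)

# Rounding A_log by delta scales g by exp(delta): a relative error,
# independent of the magnitude of g itself.
rel_err = abs(g_bf16 - g_fp32) / abs(g_fp32)

# Simplification: with a constant gate, the cumulative decay over T steps is
# exp(T * g), so the per-step difference is multiplied by T in the exponent.
T = 1024
decay_ratio = math.exp(T * (g_bf16 - g_fp32))

print(f"A_log: {a_log} -> {a_log_bf16} after bf16 rounding")
print(f"relative error in g per step: {rel_err:.3%}")
print(f"cumulative decay ratio after {T} steps: {decay_ratio:.3f}")
```

In this sketch the per-step error stays under one percent, yet the cumulative decay after 1024 steps differs by a large factor, which is the compounding effect the paragraph describes.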

Each model owns the runtime construction of its fp32 holder. This module only centralizes the checkpoint contract: hide _fp32_params in saved HF-compatible keys, route bare HF keys back into the holder for native load, and upcast these params to fp32 when checkpoint tensors arrive in a lower precision.

Module Contents

Functions

| Name | Description |
| --- | --- |
| has_gated_delta_net_fp32_checkpoint_contract | Return whether hf_config belongs to an architecture with fp32 GDN params. |
| is_gated_delta_net_fp32_param_key | Return whether key names an intrinsically-fp32 GDN parameter. |
| route_fp32_holder_key | Rewrite a bare ...linear_attn.X GDN param key into the _fp32_params holder. |
| strip_fp32_holder_key | Rewrite ...linear_attn._fp32_params.X -> ...linear_attn.X. |
| upcast_gated_delta_net_fp32_state_tensor | Cast loaded GDN fp32-param tensors to fp32 while leaving other state untouched. |

Data

FP32_GDN_PARAM_NAMES

GDN_FP32_CHECKPOINT_ARCHITECTURES

HOLDER_NAME

_FP32_HOLDER_KEY_RE

API

nemo_automodel.components.models.common.gated_delta_net_fp32.has_gated_delta_net_fp32_checkpoint_contract(
hf_config: object
) -> bool

Return whether hf_config belongs to an architecture with fp32 GDN params.
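As a rough illustration of what such a check might look like, the sketch below tests a config's architectures attribute against a set of names. That the real function reads hf_config.architectures is an assumption; the page does not show the implementation. KNOWN_ARCHITECTURES and has_contract_sketch are hypothetical names, and the set contains only the two architecture names that appear in full on this page (the real GDN_FP32_CHECKPOINT_ARCHITECTURES has more entries).

```python
from types import SimpleNamespace

# Only the names fully visible on this page; the real frozenset is longer.
KNOWN_ARCHITECTURES = frozenset(("Qwen3NextForCausalLM", "Qwen3_5ForCausalLM"))


def has_contract_sketch(hf_config: object) -> bool:
    """Assumed logic: True if any listed architecture is in the known set."""
    # Tolerate configs with a missing or None `architectures` attribute.
    architectures = getattr(hf_config, "architectures", None) or ()
    return any(name in KNOWN_ARCHITECTURES for name in architectures)


# Stand-in config objects; a real caller would pass a Hugging Face config.
gdn_config = SimpleNamespace(architectures=["Qwen3NextForCausalLM"])
other_config = SimpleNamespace(architectures=["LlamaForCausalLM"])
empty_config = SimpleNamespace()

print(has_contract_sketch(gdn_config))
print(has_contract_sketch(other_config))
print(has_contract_sketch(empty_config))
```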

nemo_automodel.components.models.common.gated_delta_net_fp32.is_gated_delta_net_fp32_param_key(
key: str,
param_names: tuple[str, ...] = FP32_GDN_PARAM_NAMES
) -> bool

Return whether key names an intrinsically-fp32 GDN parameter.

nemo_automodel.components.models.common.gated_delta_net_fp32.route_fp32_holder_key(
key: str,
param_names: tuple[str, ...] = FP32_GDN_PARAM_NAMES
) -> str

Rewrite a bare ...linear_attn.X GDN param key into the _fp32_params holder.

Inverse of strip_fp32_holder_key for the param names in param_names. No-op when the key is already routed, is not under linear_attn, or is not a tracked fp32 GDN param.
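The routing rule and its three no-op cases can be sketched in a few lines. This is not the module's implementation, only one plausible reading of the description; route_sketch is a hypothetical name, and the two constants mirror the values listed in the Data section of this page.

```python
FP32_GDN_PARAM_NAMES = ("A_log", "dt_bias")  # value shown on this page
HOLDER_NAME = "_fp32_params"                 # value shown on this page


def route_sketch(key: str, param_names: tuple = FP32_GDN_PARAM_NAMES) -> str:
    """Insert the holder between ...linear_attn and a tracked param name."""
    prefix, sep, leaf = key.rpartition(".")
    if not sep or leaf not in param_names:
        return key  # not a tracked fp32 GDN param
    if not prefix.endswith(".linear_attn"):
        # Either not under linear_attn, or already routed (in which case the
        # direct parent is the holder rather than linear_attn).
        return key
    return f"{prefix}.{HOLDER_NAME}.{leaf}"


bare = "model.layers.0.linear_attn.A_log"
routed = route_sketch(bare)
print(routed)

# The three no-op cases from the description above.
print(route_sketch(routed))                                       # already routed
print(route_sketch("model.layers.0.self_attn.dt_bias"))           # not under linear_attn
print(route_sketch("model.layers.0.linear_attn.in_proj.weight"))  # not a tracked param
```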

nemo_automodel.components.models.common.gated_delta_net_fp32.strip_fp32_holder_key(
key: str
) -> str

Rewrite ...linear_attn._fp32_params.X -> ...linear_attn.X.

Used by state-dict adapters so saved checkpoints hide the _fp32_params wrapping and stay directly HF-loadable.
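A minimal sketch of the stripping step, built on the regular expression listed for _FP32_HOLDER_KEY_RE in the Data section. The pattern is copied from this page; the substitution string and the name strip_sketch are assumptions about how the pattern might be applied.

```python
import re

# Pattern as listed on this page for _FP32_HOLDER_KEY_RE.
_FP32_HOLDER_KEY_RE = re.compile(r"(\.linear_attn)\._fp32_params\.")


def strip_sketch(key: str) -> str:
    """Drop the holder segment, keeping the captured '.linear_attn' prefix."""
    return _FP32_HOLDER_KEY_RE.sub(r"\1.", key)


# Hypothetical live-model keys, as a state-dict adapter might see them on save.
live_keys = [
    "model.layers.0.linear_attn._fp32_params.A_log",
    "model.layers.0.linear_attn._fp32_params.dt_bias",
    "model.layers.0.linear_attn.in_proj.weight",
]
saved_keys = [strip_sketch(k) for k in live_keys]
for key in saved_keys:
    print(key)
```

Keys that do not contain the holder segment pass through unchanged, so an adapter can apply the rewrite to every key in a state dict without filtering first.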

nemo_automodel.components.models.common.gated_delta_net_fp32.upcast_gated_delta_net_fp32_state_tensor(
key: str,
tensor: object,
param_names: tuple[str, ...] = FP32_GDN_PARAM_NAMES
) -> object

Cast loaded GDN fp32-param tensors to fp32 while leaving other state untouched.

Construction-time upcasting is not enough for checkpoint and HF load paths that replace or carry tensor values from disk. This helper preserves the fp32 GDN contract at adapter boundaries before tensors enter the live model state dict.
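The selective cast can be illustrated without any tensor library. In the sketch below, FakeTensor is a hypothetical stand-in that carries only a dtype label; with PyTorch, the target would presumably be torch.float32 and the object a real tensor. The key-matching rule, the fp32_dtype parameter, and the name upcast_sketch are all assumptions, not the module's actual code.

```python
from dataclasses import dataclass, replace

FP32_GDN_PARAM_NAMES = ("A_log", "dt_bias")  # value shown on this page


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: carries only a dtype label."""
    dtype: str

    def to(self, dtype: str) -> "FakeTensor":
        return replace(self, dtype=dtype)


def upcast_sketch(key, tensor, param_names=FP32_GDN_PARAM_NAMES, fp32_dtype="float32"):
    """Cast tracked GDN params to fp32; return everything else untouched."""
    leaf = key.rpartition(".")[2]
    is_gdn_fp32 = ".linear_attn." in key and leaf in param_names
    if not is_gdn_fp32 or not hasattr(tensor, "to"):
        return tensor  # other state, or not tensor-like
    if getattr(tensor, "dtype", None) == fp32_dtype:
        return tensor  # already fp32, nothing to do
    return tensor.to(fp32_dtype)


# Hypothetical checkpoint contents as they might arrive from disk.
loaded = {
    "model.layers.0.linear_attn.A_log": FakeTensor("bfloat16"),
    "model.layers.0.linear_attn.dt_bias": FakeTensor("float32"),
    "model.layers.0.linear_attn.in_proj.weight": FakeTensor("bfloat16"),
}
fixed = {k: upcast_sketch(k, v) for k, v in loaded.items()}
for k, v in fixed.items():
    print(k, v.dtype)
```

Only A_log changes dtype here: dt_bias is already fp32 and is returned as the same object, and the projection weight keeps its bf16 dtype.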

nemo_automodel.components.models.common.gated_delta_net_fp32.FP32_GDN_PARAM_NAMES = ('A_log', 'dt_bias')
nemo_automodel.components.models.common.gated_delta_net_fp32.GDN_FP32_CHECKPOINT_ARCHITECTURES = frozenset(('Qwen3NextForCausalLM', 'Qwen3_5ForCausalLM', 'Qwen3_5ForConditionalG...
nemo_automodel.components.models.common.gated_delta_net_fp32.HOLDER_NAME = '_fp32_params'
nemo_automodel.components.models.common.gated_delta_net_fp32._FP32_HOLDER_KEY_RE = re.compile('(\\.linear_attn)\\._fp32_params\\.')