nemo_automodel.components.models.common.gated_delta_net_fp32
Shared checkpoint helpers for fp32 GatedDeltaNet (GDN) params.
GDN layers carry intrinsically-fp32 bare parameters (A_log and dt_bias)
that feed the decay gate g = -exp(A_log) * softplus(a + dt_bias). Under FSDP2
mixed precision with fp32 master weights, the bulk of a model computes in bf16
(param_dtype=bf16) while these parameters must stay in fp32. A_log is
exponentiated, so an absolute bf16 rounding error in A_log becomes a
proportional error on the decay rate, which the recurrence then compounds
across the sequence.
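The effect of rounding A_log can be illustrated with a small standard-library sketch. It is not part of this module: to_bf16, softplus and decay_gate are hypothetical helpers written for the illustration, the input values are made up, and bf16 is emulated by rounding an fp32 bit pattern to its upper 16 bits.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even), emulated via fp32 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add the rounding bias, then drop the 16 low mantissa bits that bf16 does not keep.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def decay_gate(a_log: float, a: float, dt_bias: float) -> float:
    """g = -exp(A_log) * softplus(a + dt_bias), as in the formula above."""
    return -math.exp(a_log) * softplus(a + dt_bias)


# Made-up example values (assumptions, not taken from any checkpoint).
a_log, a, dt_bias = -1.23, 0.3, 0.1

g_ref = decay_gate(a_log, a, dt_bias)               # full-precision reference
g_rounded = decay_gate(to_bf16(a_log), a, dt_bias)  # A_log rounded to bf16 first

abs_err_a_log = abs(to_bf16(a_log) - a_log)
rel_err_gate = abs(g_rounded - g_ref) / abs(g_ref)

# Ratio of the cumulative decay factors exp(g * T) after T steps, holding the gate constant.
seq_len = 1024
drift = math.exp((g_rounded - g_ref) * seq_len)

print(f"absolute error in A_log : {abs_err_a_log:.6f}")
print(f"relative error in gate  : {rel_err_gate:.6f}")
print(f"decay ratio after {seq_len} steps: {drift:.4f}")
```

The relative error in the gate is approximately the absolute error in A_log, because exp(x + d) = exp(x) * exp(d) and exp(d) is close to 1 + d for small d. The last line shows how that small per-step error grows once it is accumulated over a long sequence.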
Each model owns the runtime construction of its fp32 holder. This module only
centralizes the checkpoint contract: hide _fp32_params in saved HF-compatible
keys, route bare HF keys back into the holder for native load, and upcast these
params to fp32 when checkpoint tensors arrive in a lower precision.
Module Contents
Functions
Data
GDN_FP32_CHECKPOINT_ARCHITECTURES
API
Return whether hf_config belongs to an architecture with fp32 GDN params.
Return whether key names an intrinsically-fp32 GDN parameter.
Rewrite a bare ...linear_attn.X GDN param key into the _fp32_params holder.
Inverse of strip_fp32_holder_key for the param names in param_names.
No-op when the key is already routed, is not under linear_attn, or is not a
tracked fp32 GDN param.
Rewrite ...linear_attn._fp32_params.X -> ...linear_attn.X.
Used by state-dict adapters so saved checkpoints hide the _fp32_params
wrapping and stay directly HF-loadable.
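The two key rewrites described above can be sketched with plain string handling. This is an illustrative approximation and not the module's implementation: the names route_key_sketch and strip_key_sketch, the constants, and the default parameter tuple are assumptions chosen to match the behavior the docstrings describe.

```python
# Assumed names, taken from the descriptions above (A_log, dt_bias, _fp32_params, linear_attn).
FP32_PARAM_NAMES = ("A_log", "dt_bias")
HOLDER = "_fp32_params"
MARKER = "linear_attn"


def route_key_sketch(key: str, param_names=FP32_PARAM_NAMES) -> str:
    """Rewrite ...linear_attn.X -> ...linear_attn._fp32_params.X for tracked X."""
    head, sep, leaf = key.rpartition(".")
    if not sep or leaf not in param_names:
        return key  # not a tracked fp32 GDN param
    if not (head == MARKER or head.endswith("." + MARKER)):
        return key  # already routed, or not directly under linear_attn
    return f"{head}.{HOLDER}.{leaf}"


def strip_key_sketch(key: str) -> str:
    """Rewrite ...linear_attn._fp32_params.X -> ...linear_attn.X."""
    return key.replace(f"{MARKER}.{HOLDER}.", f"{MARKER}.")


bare = "model.layers.0.linear_attn.A_log"
routed = route_key_sketch(bare)
print(routed)
print(strip_key_sketch(routed))
```

In this sketch the two functions are inverses on tracked keys, and both leave unrelated keys (for example a projection weight under linear_attn) unchanged.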
Cast loaded GDN fp32-param tensors to fp32 while leaving other state untouched.
Construction-time upcasting is not enough for checkpoint and HF load paths that replace or carry tensor values from disk. This helper preserves the fp32 GDN contract at adapter boundaries before tensors enter the live model state dict.
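A minimal sketch of such an upcast step, assuming PyTorch and a flat state dict of tensors, might look like the following. The function names is_fp32_key_sketch and upcast_sketch and the key-matching rule are hypothetical and are not the module's API.

```python
import torch

# Assumed parameter names, taken from the module description above.
FP32_PARAM_NAMES = ("A_log", "dt_bias")


def is_fp32_key_sketch(key: str) -> bool:
    """True if key names A_log or dt_bias somewhere under a linear_attn module."""
    parts = key.split(".")
    return "linear_attn" in parts[:-1] and parts[-1] in FP32_PARAM_NAMES


def upcast_sketch(state_dict: dict) -> dict:
    """Return a new dict with fp32 GDN tensors cast to float32; other entries are passed through."""
    out = {}
    for key, value in state_dict.items():
        if is_fp32_key_sketch(key) and torch.is_tensor(value) and value.dtype != torch.float32:
            value = value.to(torch.float32)
        out[key] = value
    return out


loaded = {
    "model.layers.0.linear_attn.A_log": torch.zeros(4, dtype=torch.bfloat16),
    "model.layers.0.linear_attn._fp32_params.dt_bias": torch.zeros(4, dtype=torch.bfloat16),
    "model.layers.0.mlp.up_proj.weight": torch.zeros(4, 4, dtype=torch.bfloat16),
}
fixed = upcast_sketch(loaded)
for name, tensor in fixed.items():
    print(name, tensor.dtype)
```

Casting a bf16 tensor to fp32 restores the dtype, not the bits that were already rounded away. The point is that the live parameter, and every update applied to it afterwards, stays in fp32.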