nemo_automodel.components.models.common.utils
Module Contents
Classes
Functions
Data
API
Backend configuration for model components.
Bases: Module
RMSNorm with explicit fp32 computation for training stability.
Weights stay in the model’s dtype (e.g. bf16) for FSDP2 compatibility. Inputs are upcast to fp32, norm is computed in fp32, and the output is cast back to the original input dtype.
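Written out as a formula, the fp32 path is roughly the following. Here d is the normalized dimension, ε is the epsilon, and w is the weight kept in the model dtype. This sketch assumes the weight multiply happens before the final cast. The text above only fixes two things: the normalization runs in fp32, and the output returns to the input dtype.

```latex
\hat{x} = \operatorname{fp32}(x), \qquad
\operatorname{rms}(\hat{x}) = \sqrt{\tfrac{1}{d}\textstyle\sum_{j=1}^{d} \hat{x}_j^{2} + \epsilon}, \qquad
y = \operatorname{cast}_{\operatorname{dtype}(x)}\!\left(\frac{\hat{x}}{\operatorname{rms}(\hat{x})} \odot \operatorname{fp32}(w)\right)
```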
Configuration for Transformer Engine FP8 quantization.
When present (not None) in BackendConfig, FP8 is enabled.
The recipe field accepts either a string shorthand ("current", "block",
or "mxfp8") or a pre-built TE recipe object (e.g. Float8CurrentScaling(fp8_dpa=True)).
"mxfp8" selects TE’s :class:MXFP8BlockScaling recipe (e4m3 data + e8m0 block
scales). Unlike torchao’s MXFP8 grouped GEMM, TE’s MXFP8 backward is mature (no
e8m0-overflow NaN), which is why GPT-OSS experts (grouped + bias) use the
experts="te" path with this recipe instead of experts="torch_mm_mxfp8".
Build and return the TE FP8 recipe object.
If recipe is already a TE recipe object (e.g. Float8CurrentScaling(...)),
it is returned directly. String values "current", "block", and
"mxfp8" are mapped to the corresponding TE recipe class.
Return a te_autocast context manager for FP8.
Compiled fp32 RMSNorm forward — standalone function to minimize dynamo guards.
Collect module name patterns that must remain in fp32.
Reads _keep_in_fp32_modules and _keep_in_fp32_modules_strict
from the model (the same attributes HuggingFace transformers uses).
Parameters:
The model to inspect.
Returns: list[str]
De-duplicated list of module-name keywords to keep in fp32.
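Here is a minimal sketch of the collection step. It assumes each of the two attributes holds an iterable of strings, or is missing or None. collect_fp32_keywords is a hypothetical name. Keeping first-seen order is this sketch's choice; the description above only promises de-duplication.

```python
from typing import Any


def collect_fp32_keywords(model: Any) -> list[str]:
    """Sketch: merge _keep_in_fp32_modules(_strict) into one de-duplicated list."""
    keywords: list[str] = []
    for attr in ("_keep_in_fp32_modules", "_keep_in_fp32_modules_strict"):
        # Either attribute may be missing or None on a given model class.
        for kw in getattr(model, attr, None) or []:
            if kw not in keywords:  # first occurrence wins, order preserved
                keywords.append(kw)
    return keywords
```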
Check if any model parameter is a DTensor (FSDP2 sharded).
Return a callable that patches TE modules exactly once.
Uses a closure instead of module-level global state to track whether the
patch has already been applied. The actual transformer_engine import
is deferred until the first call so that importing this module never
triggers heavy native-library loads (flash-attn, CUDA kernels, etc.).
Two patches are applied:
- Unallocated tensor handling: TE kernels don’t support meta/fake tensors, so we short-circuit with empty tensors for PP shape inference.
- is_first_microbatch injection: Reads the global IS_FIRST_MICROBATCH flag and passes it to TE Linear/GroupedLinear for FP8 weight caching during gradient accumulation (quantize on first microbatch, reuse cached on rest).
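The closure-based once-only pattern can be sketched without touching TE at all. make_patch_once is a hypothetical name. apply_patch stands in for the function that would lazily import transformer_engine and install the two patches.

```python
from typing import Callable


def make_patch_once(apply_patch: Callable[[], None]) -> Callable[[], bool]:
    """Sketch: return a callable that runs apply_patch on its first call only."""
    applied = False  # state lives in the closure, not in module-level globals

    def patch() -> bool:
        nonlocal applied
        if applied:
            return False  # already patched: nothing to do
        apply_patch()  # e.g. deferred import of transformer_engine + monkey-patching
        applied = True
        return True

    return patch
```

The sketch marks the patch as applied only after apply_patch returns, so a failed import is retried on the next call. It is not thread-safe; calling it from several threads would require a lock.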
Cast only matching buffers (not parameters) back to float32.
Safe for FSDP2-sharded models because buffers are plain tensors, not DTensors managed by FSDP2.
Parameters:
The model (already cast to the target dtype).
Substrings matched against dot-separated module names.
Cast modules or individual tensors matching fp32_keywords back to float32.
Only safe for unsharded models (plain tensors). FSDP2 requires uniform
dtype within each parameter group, so this must not be called on DTensor-sharded
models. Keywords may name modules (for example norm) or individual
parameters (for example attn_hc.fn), matching HuggingFace’s strict fp32
module declarations.
Parameters:
The model (already cast to the target dtype).
Substrings matched against dot-separated module names.
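Both cast-back functions select tensors by substring-matching keywords against qualified names. They differ only in whether parameters are eligible. The sketch below models tensors as (name, kind) pairs instead of real tensors. select_fp32_tensors is a hypothetical helper, not part of this module.

```python
from typing import Iterable


def select_fp32_tensors(
    named_tensors: Iterable[tuple[str, str]],
    fp32_keywords: list[str],
    buffers_only: bool,
) -> list[str]:
    """Sketch: pick tensors to cast back to fp32.

    named_tensors yields (qualified_name, kind) with kind "param" or "buffer".
    buffers_only=True mirrors the FSDP2-safe buffer path; False mirrors the
    unsharded path that also restores parameters.
    """
    selected = []
    for name, kind in named_tensors:
        if buffers_only and kind != "buffer":
            continue  # parameters may be DTensors in a mixed-dtype FSDP unit: skip
        if any(kw in name for kw in fp32_keywords):
            selected.append(name)
    return selected
```

Because the match is a plain substring test, a keyword such as norm also matches any name that merely contains it (for example a module called normalizer), so keywords should be specific.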
Restore fp32-preserved tensors from pre-cast snapshots.
Clone fp32-preserved tensors before a broad dtype cast.
Casting fp32 -> bf16 -> fp32 restores the dtype but not the original
values. Snapshot the matching tensors first so strict fp32 state such as
router correction bias or recurrent-decay parameters is restored exactly.
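The sketch below shows why a snapshot is needed. It emulates a bf16 round-trip on Python floats by rounding the fp32 bit pattern to nearest-even on the 16 dropped bits. It then restores the matching entries from a pre-cast copy. to_bf16, snapshot, restore, and the state-dict keys are hypothetical. Real code would clone tensors rather than copy floats.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even; finite inputs only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # fp32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)  # round the 16 bits about to be dropped
    bits &= 0xFFFF0000  # keep sign, 8 exponent bits, 7 explicit mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def snapshot(state: dict[str, float], keywords: list[str]) -> dict[str, float]:
    # Copy the fp32-preserved entries *before* the broad cast.
    return {name: v for name, v in state.items() if any(kw in name for kw in keywords)}


def restore(state: dict[str, float], snap: dict[str, float]) -> None:
    # Write the exact pre-cast values back; re-casting to fp32 alone cannot recover them.
    state.update(snap)


if __name__ == "__main__":
    state = {"router.correction_bias": 0.1, "mlp.up_proj.weight": 0.3}
    snap = snapshot(state, ["correction_bias"])
    state = {name: to_bf16(v) for name, v in state.items()}  # broad fp32 -> bf16 cast
    print(state["router.correction_bias"])  # 0.10009765625: precision lost
    restore(state, snap)
    print(state["router.correction_bias"])  # 0.1 again
```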
Cast the floating-point tensors of frozen submodules to compute_dtype.
When parameters are stored in fp32 (the fp32-master-weights pattern) while compute runs
in bf16, a fully frozen submodule — such as a frozen vision tower — can still produce
fp32 values that flow into bf16 trainable modules and raise a dtype mismatch in the next
matmul. This walks each maximal fully-frozen submodule and casts its parameters and
buffers to compute_dtype, handling the two tensor kinds differently:
- Parameters are cast only when they are plain (unsharded) tensors. Sharded (DTensor) params are left as-is: FSDP all-gathers them to the compute dtype during forward, and changing a sharded param’s dtype in place would desync FSDP’s flat-parameter and orig_dtype bookkeeping.
- Buffers are always cast. Buffers are never sharded, so they stay in their stored dtype regardless of the wrapper; an fp32 buffer (for example a standardization constant) used in a forward op promotes the surrounding bf16 activations to fp32.
Tensors whose qualified name matches _keep_in_fp32_modules or
_keep_in_fp32_modules_strict are left in fp32. The function is a no-op when
compute_dtype is None and for tensors already in compute_dtype. Frozen modules are
never updated, so casting them does not affect training accuracy.
Parameters:
The model, already materialized, checkpoint-loaded, and sharded.
The compute dtype (mp_policy.param_dtype); None disables the cast.
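The walk over maximal fully-frozen submodules can be sketched on a toy module tree. Node, is_fully_frozen, and maximal_frozen are hypothetical names. The sketch only finds the subtrees to cast. It omits the per-tensor rules: plain parameters only, buffers always, fp32 keywords skipped.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy stand-in for an nn.Module: requires_grad flags of its own params plus children."""
    trainable: list[bool] = field(default_factory=list)
    children: dict[str, Node] = field(default_factory=dict)


def is_fully_frozen(node: Node) -> bool:
    # Frozen means no parameter anywhere in the subtree requires grad.
    # A subtree with no parameters at all counts as frozen (vacuous truth).
    return not any(node.trainable) and all(is_fully_frozen(c) for c in node.children.values())


def maximal_frozen(node: Node, prefix: str = "") -> list[str]:
    """Qualified names of the largest fully-frozen subtrees (the root is named "")."""
    if is_fully_frozen(node):
        return [prefix]  # stop: descendants are covered by this subtree
    names: list[str] = []
    for name, child in node.children.items():
        names.extend(maximal_frozen(child, f"{prefix}.{name}" if prefix else name))
    return names
```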
Cast model parameters to the target dtype, keeping fp32 modules in full precision.
Respects _keep_in_fp32_modules / _keep_in_fp32_modules_strict on
the model (the same attributes HuggingFace transformers uses).
Uses nn.Module.to() which is safe for both plain tensors and DTensors
(FSDP2 sharded parameters). When the model is already FSDP2-sharded
(parameters are DTensors), strict fp32 modules are restored to fp32 because
they are expected to be isolated as uniform fp32 FSDP units. Non-strict fp32
hints only restore matching buffers, since their parameters may share an
FSDP unit with lower-precision parameters.
Parameters:
The model whose parameters should be cast.
Target dtype (e.g. torch.bfloat16).
Names of immediate submodules to leave entirely untouched
(kept at their current dtype). Unlike the _keep_in_fp32_modules
restore path, these are detached during the cast so model.to()
never visits them — the only reliable way to preserve an fp32
parameter once it is FSDP2-sharded (post-shard .data reassignment
does not stick). The caller must guarantee each skipped submodule is
its own dtype-uniform FSDP group (e.g. Qwen3.5’s _fp32_params
holder, sharded separately in fp32), so leaving it fp32 cannot break
FSDP’s uniform-dtype rule.
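Detaching skipped children so the cast never visits them can be sketched with a toy module class. ToyModule and cast_skipping are hypothetical. With a real nn.Module, the equivalent step would temporarily remove entries from its child-module registry. That is an assumption about mechanics the text above does not spell out.

```python
from __future__ import annotations


class ToyModule:
    """Minimal stand-in for nn.Module: a dtype tag plus ordered child modules."""

    def __init__(self, dtype: str, **children: ToyModule) -> None:
        self.dtype = dtype
        self.children = dict(children)

    def to(self, dtype: str) -> ToyModule:
        # Like nn.Module.to(), recurse into every attached child.
        self.dtype = dtype
        for child in self.children.values():
            child.to(dtype)
        return self


def cast_skipping(model: ToyModule, dtype: str, skip: set[str]) -> ToyModule:
    """Cast model to dtype while leaving the named immediate children untouched."""
    original = list(model.children.items())
    # Detach the skipped children so .to() never visits them.
    model.children = {name: child for name, child in original if name not in skip}
    try:
        model.to(dtype)
    finally:
        # Re-attach everything in the original order, even if the cast raised.
        model.children = dict(original)
    return model
```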
Project hidden states through lm_head and wrap the result.
Centralizes the lm_head projection and output packaging shared by every
custom *ForCausalLM / *ForConditionalGeneration forward(). The
returned CausalLMOutputWithPast carries the projected logits and,
when requested, the final hidden_states; callers that also need loss,
past_key_values, etc. read .logits and build their own output.
- lm_head is None (e.g. a non-final pipeline-parallel stage that does not own the head): hidden_states is passed through as logits so the next stage receives it.
- logits_to_keep == 0 (training default): every position is projected. The full range is deliberately not sliced, because slice(0, None) on a DTensor is unsupported (it raises on the aten.alias op under tensor parallelism with sequence parallelism).
- logits_to_keep as a positive int or a tensor of indices: only the requested positions are projected. Both 2D [T, H] (THD/packed) and 3D [B, S, H] (BSHD) hidden states are handled.
- is_thd: THD/packed inputs yield 2D [T, V] logits; the leading batch dim is restored (unsqueeze(0) -> [1, T, V]) so downstream code sees a uniform [B, S, V] layout. The same restoration is applied to the hidden_states field. It is only applied while the tensor is still 2D, so an inputs_embeds path that already produced [1, T, *] is left untouched.
- fp32_lm_head: the projection runs in fp32 and the logits are cast back to the input dtype. This is used by models whose lm_head.weight has been promoted to fp32 (e.g. via the MoE lm_head_precision setting). The matmul goes through lm_head (nn.Linear, DTensor-aware under FSDP2) rather than F.linear so DTensor redistribution is preserved.
- output_hidden_states: when set, the (full-sequence, THD-restored) hidden_states are attached to the output so the fused cross-entropy path can recompute logits over every position; otherwise the field is None.
Parameters:
The language-model head module, or None on a pipeline stage
that does not own it.
Final hidden states, shaped [T, H] or [B, S, H].
0 to project every position; a positive int to keep
the last N positions; or a tensor of position indices.
Whether the inputs were THD/packed; if so, a 2D logits (and hidden-states) result is unsqueezed back to a leading batch dim of 1.
Project in fp32 and cast the result back to the input
dtype. Ignored when lm_head is None.
Attach the final hidden states to the output.
Returns: CausalLMOutputWithPast
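A CausalLMOutputWithPast whose logits are the projected logits and whose hidden_states holds the final hidden states when output_hidden_states is set (otherwise None).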
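The index and shape bookkeeping for logits_to_keep and is_thd can be sketched without tensors. positions_to_project and logits_shape are hypothetical helpers. A Python sequence of ints stands in for the tensor of indices. The sketch does not perform the actual projection.

```python
from __future__ import annotations

from typing import Sequence, Union

KeepSpec = Union[int, Sequence[int]]


def positions_to_project(seq_len: int, logits_to_keep: KeepSpec) -> list[int]:
    """Which sequence positions go through lm_head (sketch)."""
    if isinstance(logits_to_keep, int):
        if logits_to_keep == 0:
            # Every position; the real code avoids slice(0, None) here for DTensor safety.
            return list(range(seq_len))
        # Positive int: keep the last N positions.
        return list(range(max(0, seq_len - logits_to_keep), seq_len))
    # Otherwise: explicit position indices (a tensor in the real API).
    return list(logits_to_keep)


def logits_shape(hidden_shape: tuple[int, ...], n_kept: int, vocab: int, is_thd: bool) -> tuple[int, ...]:
    """Output shape for [T, H] (THD/packed) or [B, S, H] (BSHD) hidden states."""
    if len(hidden_shape) == 2:
        shape = (n_kept, vocab)  # [T, V]
        # THD: restore a leading batch dim of 1 so callers always see [B, S, V].
        return (1, *shape) if is_thd else shape
    # 3D input (including an inputs_embeds path already at [1, T, H]) is left as-is.
    return (hidden_shape[0], n_kept, vocab)
```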
Get the global IS_FIRST_MICROBATCH flag.
Returns: bool | None
True/False/None indicating microbatch position for FP8 weight caching.
Get the global IS_OPTIM_STEP flag.
Returns: bool
Whether we are in an optimization step.
Extract rope configuration from config.rope_parameters.
Parameters:
A HuggingFace model config object.
Returns: tuple[float, dict, float]
Tuple of (rope_theta, rope_parameters, partial_rotary_factor).
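Here is a minimal sketch of reading rope settings from rope_parameters. get_rope_config is a hypothetical name. Its fallbacks (rope_theta of 10000.0, partial_rotary_factor of 1.0) are assumptions of this sketch; the real helper may consult other config attributes or raise instead.

```python
from typing import Any


def get_rope_config(config: Any, default_theta: float = 10000.0) -> tuple[float, dict, float]:
    """Sketch: pull (rope_theta, rope_parameters, partial_rotary_factor) from a config."""
    # rope_parameters may be absent or None; treat that as empty and copy the dict.
    rope_parameters = dict(getattr(config, "rope_parameters", None) or {})
    rope_theta = float(rope_parameters.get("rope_theta", default_theta))
    # Assumption: a missing partial_rotary_factor means the full head dim is rotated.
    partial_rotary_factor = float(rope_parameters.get("partial_rotary_factor", 1.0))
    return rope_theta, rope_parameters, partial_rotary_factor
```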
Initialize Linear module with the specified backend.
For TE backend, creates TE module directly on specified device. Call reset_parameters() to materialize weights if created on meta device.
Parameters:
Backend implementation (“te” or “torch”)
Input features
Output features
Whether to use bias
Device to create module on (None uses PyTorch default, typically CPU)
Parameter dtype
Returns: nn.Module
Linear module
Initialize RMSNorm module with the specified backend.
For TE backend, creates TE module directly on specified device. Call reset_parameters() to materialize weights if created on meta device.
Parameters:
Backend implementation (“te”, “torch”, or “torch_fp32”)
- “te”: Transformer Engine fused RMSNorm kernel
- “torch”: PyTorch native nn.RMSNorm (computes in input dtype)
- “torch_fp32”: torch.compiled fp32 RMSNorm for training stability
Normalized dimension
Epsilon for numerical stability
Device to create module on (None uses PyTorch default, typically CPU)
Parameter dtype
Returns: nn.Module
RMSNorm module
Check if tensor is unallocated (meta tensor, fake tensor, etc.).
TE kernels don’t support meta tensors, fake tensors, or unallocated tensors. This helper detects such cases for fallback handling.
Parameters:
Tensor to check
Returns: bool
True if tensor is unallocated or cannot be accessed
Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching.
Parameters:
True for first microbatch (quantize+cache), False for subsequent (use cached), None to disable caching.
Set the global IS_OPTIM_STEP flag.
Parameters:
Whether we are in an optimization step.
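Here is a sketch of how a gradient-accumulation loop might drive the first-microbatch flag. The module-level variable and the set/get functions are hypothetical stand-ins that mirror the documented getters and setters. Resetting the flag to None after the loop is this sketch's choice, not documented behavior.

```python
from typing import Any, Callable, Iterable, Optional

_IS_FIRST_MICROBATCH: Optional[bool] = None  # hypothetical module-level flag


def set_is_first_microbatch(value: Optional[bool]) -> None:
    global _IS_FIRST_MICROBATCH
    _IS_FIRST_MICROBATCH = value


def get_is_first_microbatch() -> Optional[bool]:
    return _IS_FIRST_MICROBATCH


def train_step(microbatches: Iterable[Any], forward_backward: Callable[[Any], None]) -> None:
    """Accumulation loop that tells FP8 layers when to re-quantize weights."""
    try:
        for i, mb in enumerate(microbatches):
            # True: quantize and cache FP8 weights; False: reuse the cached copy.
            set_is_first_microbatch(i == 0)
            forward_backward(mb)
    finally:
        set_is_first_microbatch(None)  # disable caching outside the loop
```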
Run a block with the model temporarily in fp32, then cast it to restore_dtype.
On entry the whole model is cast to fp32; on exit it is cast to restore_dtype
(which defaults to the model’s pre-context floating-point dtype, so by default the
original dtype is restored). The exit cast is a no-op when the target is already fp32.
The motivating use is from-scratch weight initialization. Sampling a random init directly
in a reduced-precision dtype (e.g. bf16) distorts the init’s variance/mean schedule: bf16’s
8-bit mantissa quantizes the small init magnitudes and biases the truncation/scaling
arithmetic used by normal_ / trunc_normal_. In a deep residual stack this compounds
and produces genuinely huge gradients on the first optimization steps of from-scratch
pretraining (flat / diverging loss). Sampling in fp32 and then casting back avoids this while
keeping reduced-precision storage: the round-to-bf16 of a correct fp32 sample is an unbiased
per-element perturbation that preserves the init statistics. Wrap the body of a model’s
initialize_weights to keep that round-trip in one place.
Works whether or not the model is already FSDP2-sharded: both casts are uniform whole-model
casts, so FSDP2’s invariant that every parameter in a group shares one dtype is preserved. In
the AutoModel pipeline, initialize_weights actually runs after sharding (via
checkpointer.initialize_model_weights), i.e. on DTensor params, which is supported.
_keep_in_fp32_modules / _keep_in_fp32_modules_strict handling is delegated to
cast_model_to_dtype: on an unsharded model those modules’ params and buffers are restored
to fp32 on exit; on a sharded model, strict fp32 modules are restored while non-strict modules
only have their buffers restored.
Parameters:
The model to run in fp32 within the context.
The dtype to cast the model to on exit. Defaults to the model’s current floating-point dtype (captured before the fp32 cast), i.e. the original dtype.
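The context manager's cast-in, cast-out shape can be sketched with a toy model that records every cast. ToyModel and run_in_fp32 are hypothetical. The sketch restores the dtype even if the body raises, which is a choice not documented above. It also leaves out the _keep_in_fp32_modules handling that the real function delegates to cast_model_to_dtype.

```python
from contextlib import contextmanager
from typing import Iterator, Optional


class ToyModel:
    """Stand-in for a model: one floating-point dtype plus a log of casts."""

    def __init__(self, dtype: str) -> None:
        self.dtype = dtype
        self.casts: list[str] = []

    def to(self, dtype: str) -> "ToyModel":
        self.casts.append(dtype)
        self.dtype = dtype
        return self


@contextmanager
def run_in_fp32(model: ToyModel, restore_dtype: Optional[str] = None) -> Iterator[ToyModel]:
    # Capture the target *before* the fp32 cast so the default restores the original dtype.
    target = restore_dtype if restore_dtype is not None else model.dtype
    model.to("float32")
    try:
        yield model  # e.g. sample random init values here, in fp32
    finally:
        if target != "float32":
            model.to(target)  # uniform whole-model cast back
```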