nemo_automodel.components.models.common.utils


Module Contents

Classes

Name | Description
--- | ---
BackendConfig | Backend configuration for model components.
Float32RMSNorm | RMSNorm with explicit fp32 computation for training stability.
TEFp8Config | Configuration for Transformer Engine FP8 quantization.

Functions

Name | Description
--- | ---
_float32_rms_norm_fwd | Compiled fp32 RMSNorm forward; a standalone function to minimize dynamo guards.
_get_fp32_module_keywords | Collect module name patterns that must remain in fp32.
_get_strict_fp32_module_keywords | -
_has_dtensor_params | Check if any model parameter is a DTensor (FSDP2 sharded).
_make_lazy_te_patcher | Return a callable that patches TE modules exactly once.
_restore_fp32_buffers | Cast only matching buffers (not parameters) back to float32.
_restore_fp32_modules | Cast modules or individual tensors matching fp32_keywords back to float32.
_restore_fp32_tensor_snapshots | Restore fp32-preserved tensors from pre-cast snapshots.
_snapshot_fp32_tensors | Clone fp32-preserved tensors before a broad dtype cast.
cast_frozen_modules_to_compute_dtype | Cast the floating-point tensors of frozen submodules to compute_dtype.
cast_model_to_dtype | Cast model parameters to the target dtype, keeping fp32 modules in full precision.
compute_lm_head_logits | Project hidden states through lm_head and wrap the result.
get_is_first_microbatch | Get the global IS_FIRST_MICROBATCH flag.
get_is_optim_step | Get the global IS_OPTIM_STEP flag.
get_rope_config | Extract rope configuration from config.rope_parameters.
initialize_linear_module | Initialize Linear module with the specified backend.
initialize_rms_norm_module | Initialize RMSNorm module with the specified backend.
is_tensor_unallocated | Check if tensor is unallocated (meta tensor, fake tensor, etc.).
set_is_first_microbatch | Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching.
set_is_optim_step | Set the global IS_OPTIM_STEP flag.
yield_fp32_model | Run a block with the model temporarily in fp32, then cast it to restore_dtype.

Data

HAVE_DEEP_EP

HAVE_GMM

HAVE_TE

HAVE_UCCL_EP

IS_FIRST_MICROBATCH

IS_OPTIM_STEP

__all__

_patch_te_modules

logger

API

class nemo_automodel.components.models.common.utils.BackendConfig(
attn: typing.Literal['te', 'sdpa', 'flex', 'eager', 'tilelang'] = 'te' if HAVE_TE and torch.c...,
linear: typing.Literal['torch', 'te'] = 'te' if HAVE_TE and torch.c...,
rms_norm: typing.Literal['torch', 'torch_fp32', 'te'] = 'torch_fp32',
rope_fusion: bool = HAVE_TE and torch.cuda.is_a...,
experts: typing.Literal['torch', 'te', 'gmm', 'torch_mm', 'torch_mm_mxfp8'] = 'torch_mm' if torch.cuda.is...,
dispatcher: typing.Literal['torch', 'deepep', 'hybridep', 'uccl_ep'] = 'deepep' if HAVE_DEEP_EP an...,
dispatcher_num_sms: int = 20,
dispatcher_share_token_dispatcher: bool = True,
dispatcher_async_dispatch: bool = False,
enable_deepep: bool | None = None,
fake_balanced_gate: bool = False,
fake_gate_noise: float = 0.0,
enable_hf_state_dict_adapter: bool = True,
enable_fsdp_optimizations: bool = False,
te_fp8: nemo_automodel.components.models.common.utils.TEFp8Config | None = None,
gate_precision: str | torch.dtype | None = None,
compile_attn: bool = False
)
Dataclass

Backend configuration for model components.

attn
Literal['te', 'sdpa', 'flex', 'eager', 'tilelang']
compile_attn
bool = False
dispatcher
Literal['torch', 'deepep', 'hybridep', 'uccl_ep']
dispatcher_async_dispatch
bool = False
dispatcher_num_sms
int = 20
dispatcher_share_token_dispatcher
bool = True
enable_deepep
bool | None = None
enable_fsdp_optimizations
bool = False
enable_hf_state_dict_adapter
bool = True
experts
Literal['torch', 'te', 'gmm', 'torch_mm', 'torch_mm_mxfp8']
fake_balanced_gate
bool = False
fake_gate_noise
float = 0.0
gate_precision
str | dtype | None = None
linear
Literal['torch', 'te']
rms_norm
Literal['torch', 'torch_fp32', 'te'] = 'torch_fp32'
rope_fusion
bool = HAVE_TE and torch.cuda.is_available()
te_fp8
TEFp8Config | None = None
nemo_automodel.components.models.common.utils.BackendConfig.__post_init__()
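The attn, linear, rope_fusion, experts, and dispatcher defaults above are conditional expressions over availability flags such as HAVE_TE. A minimal sketch of that pattern follows. ToyBackendConfig and the package name hypothetical_fast_lib are hypothetical, and the Literal check in __post_init__ is an illustration only, since what BackendConfig.__post_init__ actually validates is not documented here.

```python
import importlib.util
from dataclasses import dataclass
from typing import Literal, get_args

# Hypothetical optional dependency; find_spec returns None when it is absent
# and does not import a top-level package when it is present.
HAVE_FAST_LIB = importlib.util.find_spec("hypothetical_fast_lib") is not None

LinearImpl = Literal["torch", "fast"]
RmsNormImpl = Literal["torch", "torch_fp32", "fast"]


@dataclass
class ToyBackendConfig:
    # Evaluated once, when the class body runs (i.e. at import time).
    linear: LinearImpl = "fast" if HAVE_FAST_LIB else "torch"
    rms_norm: RmsNormImpl = "torch_fp32"

    def __post_init__(self) -> None:
        # Literal annotations are not enforced at runtime, so check explicitly.
        for name, allowed in (("linear", LinearImpl), ("rms_norm", RmsNormImpl)):
            value = getattr(self, name)
            if value not in get_args(allowed):
                raise ValueError(f"unknown {name} backend: {value!r}")


cfg = ToyBackendConfig()
print(cfg.linear)  # "torch" unless hypothetical_fast_lib is installed
```

Because dataclass defaults are evaluated when the class is defined, installing an optional package later in the same process does not change them.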
class nemo_automodel.components.models.common.utils.Float32RMSNorm(
dim,
eps = 1e-05,
device = None,
dtype = torch.bfloat16
)

Bases: Module

RMSNorm with explicit fp32 computation for training stability.

Weights stay in the model’s dtype (e.g. bf16) for FSDP2 compatibility. Inputs are upcast to fp32, the norm is computed in fp32, and the output is cast back to the original input dtype.

weight
nemo_automodel.components.models.common.utils.Float32RMSNorm.forward(
x
)
nemo_automodel.components.models.common.utils.Float32RMSNorm.reset_parameters()
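The upcast, compute, and downcast steps of Float32RMSNorm can be sketched in plain Python, with bf16 and fp32 storage emulated through struct bit manipulation. This is an illustration under stated assumptions, not the module's forward: Python floats are float64, so to_fp32 only emulates fp32 storage of intermediates rather than bit-exact fp32 arithmetic, and applying the weight in fp32 before the downcast is an assumed ordering.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to bfloat16: round-to-nearest-even on the top 16 bits of fp32."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias, ties to even
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def rms_norm_fp32(x_bf16, weight_bf16, eps=1e-5):
    # 1. Upcast: every bf16 value is exactly representable in fp32.
    xs = [to_fp32(v) for v in x_bf16]
    # 2. Norm statistics in (emulated) fp32.
    mean_sq = to_fp32(sum(v * v for v in xs) / len(xs))
    inv_rms = to_fp32(1.0 / math.sqrt(mean_sq + eps))
    # 3. Normalize and apply the weight, still in fp32 (assumed ordering).
    out = [to_fp32(v * inv_rms * w) for v, w in zip(xs, weight_bf16)]
    # 4. Cast back to the input dtype (bf16 here).
    return [to_bf16(v) for v in out]


out = rms_norm_fp32([3.0, 4.0], [1.0, 1.0])
```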
class nemo_automodel.components.models.common.utils.TEFp8Config(
recipe: typing.Literal['current', 'block', 'mxfp8'] | typing.Any = 'current'
)
Dataclass

Configuration for Transformer Engine FP8 quantization.

When present (not None) in BackendConfig, FP8 is enabled. The recipe field accepts either a string shorthand ("current", "block", or "mxfp8") or a pre-built TE recipe object (e.g. Float8CurrentScaling(fp8_dpa=True)).

"mxfp8" selects TE’s MXFP8BlockScaling recipe (e4m3 data + e8m0 block scales). Unlike torchao’s MXFP8 grouped GEMM, TE’s MXFP8 backward is mature (no e8m0-overflow NaN), which is why GPT-OSS experts (grouped + bias) use the experts="te" path with this recipe instead of experts="torch_mm_mxfp8".

recipe
Literal['current', 'block', 'mxfp8'] | Any = 'current'
nemo_automodel.components.models.common.utils.TEFp8Config.build_recipe()

Build and return the TE FP8 recipe object.

If recipe is already a TE recipe object (e.g. Float8CurrentScaling(...)), it is returned directly. String values "current", "block", and "mxfp8" are mapped to the corresponding TE recipe class.
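The string-or-object handling of build_recipe can be sketched as follows. The StandIn* classes are hypothetical placeholders for the Transformer Engine recipe classes (no TE import is made), and raising ValueError for an unknown string is an assumption; the library's behavior for invalid values is not documented here.

```python
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-ins for the TE recipe classes.
class StandInCurrentScaling: ...
class StandInBlockScaling: ...
class StandInMXFP8BlockScaling: ...


_RECIPE_BY_NAME = {
    "current": StandInCurrentScaling,
    "block": StandInBlockScaling,
    "mxfp8": StandInMXFP8BlockScaling,
}


@dataclass
class ToyFp8Config:
    recipe: Any = "current"  # string shorthand or a pre-built recipe object

    def build_recipe(self) -> Any:
        if not isinstance(self.recipe, str):
            return self.recipe  # pre-built object: returned unchanged
        try:
            return _RECIPE_BY_NAME[self.recipe]()
        except KeyError:
            raise ValueError(f"unknown FP8 recipe: {self.recipe!r}") from None
```

Passing a pre-built object through untouched lets callers set recipe options (such as fp8_dpa) that the string shorthands cannot express.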

nemo_automodel.components.models.common.utils.TEFp8Config.maybe_te_autocast()

Return the te_autocast context manager for FP8.

nemo_automodel.components.models.common.utils._float32_rms_norm_fwd(
x: torch.Tensor,
weight: torch.Tensor,
eps: float
) -> torch.Tensor

Compiled fp32 RMSNorm forward — standalone function to minimize dynamo guards.

nemo_automodel.components.models.common.utils._get_fp32_module_keywords(
model: torch.nn.Module
) -> list[str]

Collect module name patterns that must remain in fp32.

Reads _keep_in_fp32_modules and _keep_in_fp32_modules_strict from the model (the same attributes HuggingFace transformers uses).

Parameters:

model
nn.Module

The model to inspect.

Returns: list[str]

De-duplicated list of module-name keywords to keep in fp32.
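A minimal sketch of the collection step, assuming both attributes are optional lists of strings that may be absent or None. collect_fp32_keywords is a hypothetical name, and the first-seen ordering is an assumption; the docstring only promises de-duplication.

```python
from types import SimpleNamespace


def collect_fp32_keywords(model) -> list[str]:
    keywords: list[str] = []
    for attr in ("_keep_in_fp32_modules", "_keep_in_fp32_modules_strict"):
        # Either attribute may be missing or set to None on a given model.
        keywords.extend(getattr(model, attr, None) or [])
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(keywords))


model = SimpleNamespace(
    _keep_in_fp32_modules=["norm", "router"],
    _keep_in_fp32_modules_strict=["router", "attn_hc.fn"],
)
print(collect_fp32_keywords(model))  # ['norm', 'router', 'attn_hc.fn']
```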

nemo_automodel.components.models.common.utils._get_strict_fp32_module_keywords(
model: torch.nn.Module
) -> list[str]
nemo_automodel.components.models.common.utils._has_dtensor_params(
model: torch.nn.Module
) -> bool

Check if any model parameter is a DTensor (FSDP2 sharded).

nemo_automodel.components.models.common.utils._make_lazy_te_patcher()

Return a callable that patches TE modules exactly once.

Uses a closure instead of module-level global state to track whether the patch has already been applied. The actual transformer_engine import is deferred until the first call so that importing this module never triggers heavy native-library loads (flash-attn, CUDA kernels, etc.).

Two patches are applied:

  1. Unallocated tensor handling: TE kernels don’t support meta/fake tensors, so the patch short-circuits with empty tensors during pipeline-parallel (PP) shape inference.
  2. is_first_microbatch injection: Reads the global IS_FIRST_MICROBATCH flag and passes it to TE Linear/GroupedLinear for FP8 weight caching during gradient accumulation (quantize on the first microbatch, reuse the cached weights on the rest).
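The once-only closure pattern described above can be sketched without Transformer Engine. make_lazy_patcher and apply_patch are hypothetical names; apply_patch stands in for the deferred import and the two patches.

```python
from typing import Callable


def make_lazy_patcher(apply_patch: Callable[[], None]) -> Callable[[], bool]:
    """Return a callable that runs apply_patch at most once."""
    applied = False  # state lives in the closure, not in a module-level global

    def patch_once() -> bool:
        nonlocal applied
        if applied:
            return False  # already patched: nothing to do
        # The real helper would import transformer_engine here, on first use;
        # apply_patch stands in for that import plus the patches.
        apply_patch()
        applied = True
        return True

    return patch_once


calls = []
patch = make_lazy_patcher(lambda: calls.append("patched"))
patch()
patch()
print(calls)  # ['patched']
```

Setting applied only after apply_patch returns means a failed patch is retried on the next call. The sketch is not thread-safe; guarding patch_once with a threading.Lock would be one way to make it so.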
nemo_automodel.components.models.common.utils._restore_fp32_buffers(
model: torch.nn.Module,
fp32_keywords: list[str]
) -> None

Cast only matching buffers (not parameters) back to float32.

Safe for FSDP2-sharded models because buffers are plain tensors, not DTensors managed by FSDP2.

Parameters:

model
nn.Module

The model (already cast to the target dtype).

fp32_keywords
list[str]

Substrings matched against dot-separated module names.

nemo_automodel.components.models.common.utils._restore_fp32_modules(
model: torch.nn.Module,
fp32_keywords: list[str]
) -> None

Cast modules or individual tensors matching fp32_keywords back to float32.

Only safe for unsharded models (plain tensors). FSDP2 requires uniform dtype within each parameter group, so this must not be called on DTensor-sharded models. Keywords may name modules (for example norm) or individual parameters (for example attn_hc.fn), matching HuggingFace’s strict fp32 module declarations.

Parameters:

model
nn.Module

The model (already cast to the target dtype).

fp32_keywords
list[str]

Substrings matched against dot-separated module names.
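The keyword matching shared by _restore_fp32_buffers and _restore_fp32_modules can be sketched with dtypes modeled as strings in dicts keyed by qualified name. restore_fp32 and the tensor names are hypothetical; buffers_only=True mirrors the buffer-only variant that is safe on FSDP2-sharded models.

```python
def matches(name: str, keywords: list[str]) -> bool:
    # Plain substring test against the dot-separated qualified name.
    return any(keyword in name for keyword in keywords)


def restore_fp32(params: dict[str, str], buffers: dict[str, str],
                 keywords: list[str], buffers_only: bool) -> None:
    """Set matching entries to "float32" in place."""
    for name in buffers:
        if matches(name, keywords):
            buffers[name] = "float32"
    if buffers_only:
        return  # sharded case: leave FSDP2-managed parameters alone
    for name in params:
        if matches(name, keywords):
            params[name] = "float32"


keywords = ["norm", "attn_hc.fn", "rotary_emb"]
params = {
    "layers.0.input_norm.weight": "bfloat16",
    "layers.0.attn_hc.fn": "bfloat16",
    "layers.0.mlp.up_proj.weight": "bfloat16",
}
buffers = {"rotary_emb.inv_freq": "bfloat16"}

sharded_params, sharded_buffers = dict(params), dict(buffers)
restore_fp32(sharded_params, sharded_buffers, keywords, buffers_only=True)
print(sharded_params["layers.0.input_norm.weight"])  # bfloat16

restore_fp32(params, buffers, keywords, buffers_only=False)
print(params["layers.0.input_norm.weight"])  # float32
```

Substring matching is broad: the keyword norm matches layers.0.input_norm.weight here and would match any other name that merely contains norm.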

nemo_automodel.components.models.common.utils._restore_fp32_tensor_snapshots(
model: torch.nn.Module,
parameter_snapshots: dict[str, torch.Tensor],
buffer_snapshots: dict[str, torch.Tensor]
) -> None

Restore fp32-preserved tensors from pre-cast snapshots.

nemo_automodel.components.models.common.utils._snapshot_fp32_tensors(
model: torch.nn.Module,
parameter_keywords: list[str],
buffer_keywords: list[str]
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]

Clone fp32-preserved tensors before a broad dtype cast.

Casting fp32 -> bf16 -> fp32 restores the dtype but not the original values. Snapshot the matching tensors first so strict fp32 state such as router correction bias or recurrent-decay parameters is restored exactly.
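The value loss from the fp32 -> bf16 -> fp32 round trip, and the snapshot-then-restore fix, can be shown in plain Python. Lists of floats stand in for tensors, to_bf16/to_fp32 emulate the dtypes, and the tensor names are hypothetical; the sketch also merges the separate parameter and buffer snapshots into one dict for brevity.

```python
import struct


def to_fp32(x: float) -> float:
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    # Round-to-nearest-even onto the top 16 bits of the fp32 encoding.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


tensors = {
    "router.correction_bias": [to_fp32(0.1), to_fp32(-0.3)],  # strict fp32 state
    "mlp.up_proj.weight": [to_fp32(0.5)],
}
keywords = ["correction_bias"]

# 1. Clone the matching fp32 tensors before the broad cast.
snapshots = {n: list(v) for n, v in tensors.items() if any(k in n for k in keywords)}

# 2. Broad cast to bf16 and back: fp32 again, but the values have changed.
for name, values in tensors.items():
    tensors[name] = [to_fp32(to_bf16(v)) for v in values]
print(tensors["router.correction_bias"][0] == to_fp32(0.1))  # False

# 3. Restore the preserved tensors exactly from the snapshots.
tensors.update({n: list(v) for n, v in snapshots.items()})
print(tensors["router.correction_bias"][0] == to_fp32(0.1))  # True
```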

nemo_automodel.components.models.common.utils.cast_frozen_modules_to_compute_dtype(
model: torch.nn.Module,
compute_dtype: torch.dtype | None
) -> None

Cast the floating-point tensors of frozen submodules to compute_dtype.

When parameters are stored in fp32 (the fp32-master-weights pattern) while compute runs in bf16, a fully frozen submodule — such as a frozen vision tower — can still produce fp32 values that flow into bf16 trainable modules and raise a dtype mismatch in the next matmul. This walks each maximal fully-frozen submodule and casts its parameters and buffers to compute_dtype, handling the two tensor kinds differently:

  • Parameters are cast only when they are plain (unsharded) tensors. Sharded (DTensor) params are left as-is: FSDP all-gathers them to the compute dtype during forward, and changing a sharded param’s dtype in place would desync FSDP’s flat-parameter and orig_dtype bookkeeping.
  • Buffers are always cast. Buffers are never sharded, so they stay in their stored dtype regardless of the wrapper; an fp32 buffer (for example a standardization constant) used in a forward op promotes the surrounding bf16 activations to fp32.

Tensors whose qualified name matches _keep_in_fp32_modules or _keep_in_fp32_modules_strict are left in fp32. The function is a no-op when compute_dtype is None and for tensors already in compute_dtype. Frozen modules are never updated, so casting them does not affect training accuracy.
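Selecting the maximal fully-frozen submodules is a tree walk that stops descending as soon as a whole subtree is frozen. The sketch below covers only that selection step on a hypothetical ToyModule tree. It treats a module as fully frozen when its subtree has at least one parameter and none require gradients, which is an assumption; how parameter-free modules are handled is not documented here.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModule:
    # requires_grad flag of each parameter owned directly by this module
    requires_grad: list[bool] = field(default_factory=list)
    children: dict[str, "ToyModule"] = field(default_factory=dict)


def _all_flags(module: ToyModule):
    yield from module.requires_grad
    for child in module.children.values():
        yield from _all_flags(child)


def is_fully_frozen(module: ToyModule) -> bool:
    flags = list(_all_flags(module))
    return bool(flags) and not any(flags)  # has params, none trainable


def maximal_frozen(module: ToyModule, qualname: str = "") -> list[str]:
    if is_fully_frozen(module):
        return [qualname]  # stop here: every descendant is frozen too
    found: list[str] = []
    for name, child in module.children.items():
        child_qualname = f"{qualname}.{name}" if qualname else name
        found.extend(maximal_frozen(child, child_qualname))
    return found


model = ToyModule(children={
    "vision_tower": ToyModule(children={
        "patch_embed": ToyModule([False]),
        "blocks": ToyModule([False, False]),
    }),
    "language_model": ToyModule(children={
        "embed_tokens": ToyModule([False]),
        "layers": ToyModule([True, True]),
    }),
})
print(maximal_frozen(model))  # ['vision_tower', 'language_model.embed_tokens']
```

Stopping at the first fully-frozen ancestor means each frozen tensor is visited by exactly one cast.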

Parameters:

model
nn.Module

The model, already materialized, checkpoint-loaded, and sharded.

compute_dtype
torch.dtype | None

The compute dtype (mp_policy.param_dtype); None disables the cast.

nemo_automodel.components.models.common.utils.cast_model_to_dtype(
model: torch.nn.Module,
dtype: torch.dtype = torch.bfloat16,
skip_modules: tuple[str, ...] = ()
) -> None

Cast model parameters to the target dtype, keeping fp32 modules in full precision.

Respects _keep_in_fp32_modules / _keep_in_fp32_modules_strict on the model (the same attributes HuggingFace transformers uses).

Uses nn.Module.to() which is safe for both plain tensors and DTensors (FSDP2 sharded parameters). When the model is already FSDP2-sharded (parameters are DTensors), strict fp32 modules are restored to fp32 because they are expected to be isolated as uniform fp32 FSDP units. Non-strict fp32 hints only restore matching buffers, since their parameters may share an FSDP unit with lower-precision parameters.
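The restore rules in the paragraph above (restated in the yield_fp32_model description) can be summarized as a small decision function. plan_fp32_restore is hypothetical and returns keyword lists rather than touching tensors; it encodes the documented rules, not the implementation.

```python
def plan_fp32_restore(is_sharded: bool, strict: list[str],
                      non_strict: list[str]) -> dict[str, list[str]]:
    """Which fp32 keywords get params+buffers restored vs buffers only."""
    if not is_sharded:
        # Plain tensors: every fp32 keyword can be restored in full.
        full = list(dict.fromkeys(strict + non_strict))
        return {"params_and_buffers": full, "buffers_only": []}
    # FSDP2-sharded: strict modules are expected to be their own uniform fp32
    # FSDP units; non-strict ones may share a unit with bf16 parameters.
    buffers_only = [k for k in dict.fromkeys(non_strict) if k not in strict]
    return {"params_and_buffers": list(dict.fromkeys(strict)),
            "buffers_only": buffers_only}


print(plan_fp32_restore(True, ["router"], ["norm"]))
# {'params_and_buffers': ['router'], 'buffers_only': ['norm']}
```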

Parameters:

model
nn.Module

The model whose parameters should be cast.

dtype
torch.dtype, defaults to torch.bfloat16

Target dtype (e.g. torch.bfloat16).

skip_modules
tuple[str, ...], defaults to ()

Names of immediate submodules to leave entirely untouched (kept at their current dtype). Unlike the _keep_in_fp32_modules restore path, these are detached during the cast so model.to() never visits them — the only reliable way to preserve an fp32 parameter once it is FSDP2-sharded (post-shard .data reassignment does not stick). The caller must guarantee each skipped submodule is its own dtype-uniform FSDP group (e.g. Qwen3.5’s _fp32_params holder, sharded separately in fp32), so leaving it fp32 cannot break FSDP’s uniform-dtype rule.
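The detach-cast-reattach idea behind skip_modules can be sketched with a toy module tree whose to() recurses into children, as nn.Module.to() does. ToyModule and cast_skipping are hypothetical; a real implementation would work through PyTorch's submodule registry rather than a plain dict.

```python
class ToyModule:
    def __init__(self, dtype="float32", children=None):
        self.dtype = dtype
        self.children = dict(children or {})

    def to(self, dtype):
        # Like nn.Module.to(): recurse into every registered child.
        self.dtype = dtype
        for child in self.children.values():
            child.to(dtype)
        return self


def cast_skipping(model, dtype, skip_modules=()):
    order = list(model.children)
    detached = {n: model.children.pop(n) for n in skip_modules if n in model.children}
    try:
        model.to(dtype)  # never visits the detached children
    finally:
        # Reattach, preserving the original registration order.
        model.children.update(detached)
        model.children = {n: model.children[n] for n in order}


m = ToyModule(children={"layers": ToyModule(), "_fp32_params": ToyModule()})
cast_skipping(m, "bfloat16", skip_modules=("_fp32_params",))
print(m.children["layers"].dtype, m.children["_fp32_params"].dtype)  # bfloat16 float32
```

Reattaching in a finally block keeps the model intact even if the cast raises, and rebuilding the dict restores the original child order; both are design choices of the sketch.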

nemo_automodel.components.models.common.utils.compute_lm_head_logits(
lm_head: torch.nn.Module | None,
hidden_states: torch.Tensor,
logits_to_keep: int | torch.Tensor = 0,
is_thd: bool = False,
fp32_lm_head: bool = False,
output_hidden_states: bool = False
) -> transformers.modeling_outputs.CausalLMOutputWithPast

Project hidden states through lm_head and wrap the result.

Centralizes the lm_head projection and output packaging shared by every custom *ForCausalLM / *ForConditionalGeneration forward(). The returned CausalLMOutputWithPast carries the projected logits and, when requested, the final hidden_states; callers that also need loss, past_key_values, etc. read .logits and build their own output.

  • lm_head is None (e.g. a non-final pipeline-parallel stage that does not own the head): hidden_states is passed through as logits so the next stage receives it.
  • logits_to_keep == 0 (training default): every position is projected. The full range is deliberately not sliced, because slice(0, None) on a DTensor is unsupported (it raises on the aten.alias op under tensor parallel with sequence parallelism).
  • logits_to_keep as a positive int or a tensor of indices: only the requested positions are projected. Both 2D [T, H] (THD/packed) and 3D [B, S, H] (BSHD) hidden states are handled.
  • is_thd: THD/packed inputs yield 2D [T, V] logits; the leading batch dim is restored (unsqueeze(0) -> [1, T, V]) so downstream code sees a uniform [B, S, V] layout. The same restoration is applied to the hidden_states field. Only applied while the tensor is still 2D, so an inputs_embeds path that already produced [1, T, *] is left untouched.
  • fp32_lm_head: run the projection in fp32 and cast the logits back to the input dtype. Used by models whose lm_head.weight has been promoted to fp32 (e.g. via the MoE lm_head_precision setting). The matmul goes through lm_head (nn.Linear, DTensor-aware under FSDP2) rather than F.linear so DTensor redistribution is preserved.
  • output_hidden_states: when set, the (full-sequence, THD-restored) hidden_states are attached to the output so the fused cross-entropy path can recompute logits over every position; otherwise the field is None.
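The position selection and THD shape handling in the list above can be sketched with index lists and shape tuples in place of tensors. select_positions and logits_shape are hypothetical; clamping a too-large logits_to_keep to the sequence length is an assumption that mirrors Python slice semantics. The projection itself, DTensor handling, and the fp32 path are not modeled.

```python
from typing import Sequence, Union


def select_positions(seq_len: int, logits_to_keep: Union[int, Sequence[int]]) -> list[int]:
    if isinstance(logits_to_keep, int):
        if logits_to_keep == 0:
            return list(range(seq_len))  # project every position, no slicing
        start = max(seq_len - logits_to_keep, 0)  # assumed clamp, like slice(-n, None)
        return list(range(start, seq_len))  # the last N positions
    return list(logits_to_keep)  # explicit position indices


def logits_shape(hidden_shape: tuple, vocab_size: int,
                 logits_to_keep: Union[int, Sequence[int]] = 0,
                 is_thd: bool = False) -> tuple:
    *lead, seq_len, _hidden = hidden_shape  # [T, H] or [B, S, H]
    kept = len(select_positions(seq_len, logits_to_keep))
    shape = (*lead, kept, vocab_size)
    if is_thd and len(shape) == 2:
        shape = (1, *shape)  # [T, V] -> [1, T, V]; already-3D shapes are untouched
    return shape


print(logits_shape((7, 16), 100, is_thd=True))          # (1, 7, 100)
print(logits_shape((2, 5, 16), 100, logits_to_keep=1))  # (2, 1, 100)
```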

Parameters:

lm_head
nn.Module | None

The language-model head module, or None on a pipeline stage that does not own it.

hidden_states
torch.Tensor

Final hidden states, shaped [T, H] or [B, S, H].

logits_to_keep
int | torch.Tensor, defaults to 0

0 to project every position; a positive int to keep the last N positions; or a tensor of position indices.

is_thd
bool, defaults to False

Whether the inputs were THD/packed; if so, a 2D logits (and hidden-states) result is unsqueezed back to a leading batch dim of 1.

fp32_lm_head
bool, defaults to False

Project in fp32 and cast the result back to the input dtype. Ignored when lm_head is None.

output_hidden_states
bool, defaults to False

Attach the final hidden states to the output.

Returns: CausalLMOutputWithPast

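A CausalLMOutputWithPast whose logits field holds the projected logits and whose hidden_states field holds the final hidden states when output_hidden_states is set (None otherwise).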

nemo_automodel.components.models.common.utils.get_is_first_microbatch() -> bool | None

Get the global IS_FIRST_MICROBATCH flag.

Returns: bool | None

True for the first microbatch, False for subsequent microbatches, or None when FP8 weight caching is disabled.

nemo_automodel.components.models.common.utils.get_is_optim_step() -> bool

Get the global IS_OPTIM_STEP flag.

Returns: bool

Whether we are in an optimization step.

nemo_automodel.components.models.common.utils.get_rope_config(
config
) -> tuple[float, dict, float]

Extract rope configuration from config.rope_parameters.

Parameters:

config

A HuggingFace model config object.

Returns: tuple[float, dict, float]

Tuple of (rope_theta, rope_parameters, partial_rotary_factor).

nemo_automodel.components.models.common.utils.initialize_linear_module(
linear_impl: str,
in_features: int,
out_features: int,
bias: bool = False,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.bfloat16
) -> torch.nn.Module

Initialize Linear module with the specified backend.

For the TE backend, the module is created directly on the specified device. Call reset_parameters() to materialize the weights if it was created on the meta device.

Parameters:

linear_impl
str

Backend implementation (“te” or “torch”)

in_features
int

Number of input features

out_features
int

Number of output features

bias
bool, defaults to False

Whether to include a bias term

device
torch.device | str | None, defaults to None

Device to create module on (None uses PyTorch default, typically CPU)

dtype
torch.dtype, defaults to torch.bfloat16

Parameter dtype

Returns: nn.Module

Linear module

nemo_automodel.components.models.common.utils.initialize_rms_norm_module(
rms_norm_impl: str,
dim: int,
eps: float = 1e-05,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.bfloat16
) -> torch.nn.Module

Initialize RMSNorm module with the specified backend.

For the TE backend, the module is created directly on the specified device. Call reset_parameters() to materialize the weights if it was created on the meta device.

Parameters:

rms_norm_impl
str

Backend implementation (“te”, “torch”, or “torch_fp32”)

  • “te”: Transformer Engine fused RMSNorm kernel
  • “torch”: PyTorch native nn.RMSNorm (computes in input dtype)
  • “torch_fp32”: torch.compiled fp32 RMSNorm for training stability
dim
int

Size of the normalized dimension

eps
float, defaults to 1e-05

Epsilon for numerical stability

device
torch.device | str | None, defaults to None

Device to create module on (None uses PyTorch default, typically CPU)

dtype
torch.dtype, defaults to torch.bfloat16

Parameter dtype

Returns: nn.Module

RMSNorm module

nemo_automodel.components.models.common.utils.is_tensor_unallocated(
tensor: torch.Tensor
) -> bool

Check if tensor is unallocated (meta tensor, fake tensor, etc.).

TE kernels don’t support meta tensors, fake tensors, or unallocated tensors. This helper detects such cases for fallback handling.

Parameters:

tensor
torch.Tensor

Tensor to check

Returns: bool

True if tensor is unallocated or cannot be accessed

nemo_automodel.components.models.common.utils.set_is_first_microbatch(
value: bool | None
) -> None

Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching.

Parameters:

value
bool | None

True for first microbatch (quantize+cache), False for subsequent (use cached), None to disable caching.
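A gradient-accumulation loop would set the flag once per microbatch. In the sketch below, set_first, get_first, and accumulate are hypothetical stand-ins for set_is_first_microbatch, get_is_first_microbatch, and a trainer loop; resetting the flag to None after the loop is an assumed usage, not documented behavior.

```python
from typing import Callable, Optional

IS_FIRST_MICROBATCH: Optional[bool] = None  # module-level flag


def set_first(value: Optional[bool]) -> None:
    global IS_FIRST_MICROBATCH
    IS_FIRST_MICROBATCH = value


def get_first() -> Optional[bool]:
    return IS_FIRST_MICROBATCH


def accumulate(num_microbatches: int, forward_backward: Callable[[], object]) -> list:
    results = []
    for i in range(num_microbatches):
        set_first(i == 0)  # True: quantize and cache; False: reuse the cache
        results.append(forward_backward())
    set_first(None)  # assumption: disable caching outside the accumulation loop
    return results


# get_first stands in for a TE Linear that reads the flag it would receive.
seen = accumulate(4, get_first)
print(seen)  # [True, False, False, False]
```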

nemo_automodel.components.models.common.utils.set_is_optim_step(
value: bool
) -> None

Set the global IS_OPTIM_STEP flag.

Parameters:

value
bool

Whether we are in an optimization step.

nemo_automodel.components.models.common.utils.yield_fp32_model(
model: torch.nn.Module,
restore_dtype: torch.dtype | None = None
)

Run a block with the model temporarily in fp32, then cast it to restore_dtype.

On entry the whole model is cast to fp32; on exit it is cast to restore_dtype (which defaults to the model’s pre-context floating-point dtype, so by default the original dtype is restored). The exit cast is a no-op when the target is already fp32.

The motivating use is from-scratch weight initialization. Sampling a random init directly in a reduced-precision dtype (e.g. bf16) distorts the init’s variance/mean schedule: bf16’s 8-bit mantissa quantizes the small init magnitudes and biases the truncation/scaling arithmetic used by normal_ / trunc_normal_. In a deep residual stack this compounds and produces genuinely huge gradients on the first optimization steps of from-scratch pretraining (flat / diverging loss). Sampling in fp32 and then casting back avoids this while keeping reduced-precision storage: the round-to-bf16 of a correct fp32 sample is an unbiased per-element perturbation that preserves the init statistics. Wrap the body of a model’s initialize_weights to keep that round-trip in one place.

Works whether or not the model is already FSDP2-sharded: both casts are uniform whole-model casts, so FSDP2’s invariant that every parameter in a group shares one dtype is preserved. In the AutoModel pipeline initialize_weights actually runs after sharding (via checkpointer.initialize_model_weights), i.e. on DTensor params, which is supported.

_keep_in_fp32_modules / _keep_in_fp32_modules_strict handling is delegated to cast_model_to_dtype: on an unsharded model those modules’ params and buffers are restored to fp32 on exit; on a sharded model, strict fp32 modules are restored while non-strict modules only have their buffers restored.
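The entry and exit behavior of yield_fp32_model can be sketched as a contextlib context manager over a toy model whose dtypes are strings. toy_yield_fp32 and ToyModel are hypothetical; the try/finally (restoring the dtype even if the body raises) is a choice of the sketch, and the _keep_in_fp32_modules handling is omitted.

```python
from contextlib import contextmanager


class ToyModel:
    """Dtypes are modeled as strings keyed by parameter name."""

    def __init__(self, dtypes: dict[str, str]):
        self.dtypes = dict(dtypes)

    def to(self, dtype: str) -> "ToyModel":
        for name in self.dtypes:
            self.dtypes[name] = dtype  # uniform whole-model cast
        return self


@contextmanager
def toy_yield_fp32(model: ToyModel, restore_dtype=None):
    if restore_dtype is None:
        # Capture the pre-context dtype before the fp32 cast; the first entry
        # stands in for the model's floating-point dtype.
        restore_dtype = next(iter(model.dtypes.values()), "float32")
    model.to("float32")
    try:
        yield model
    finally:
        if restore_dtype != "float32":  # exit cast is a no-op for fp32
            model.to(restore_dtype)


model = ToyModel({"embed.weight": "bfloat16", "lm_head.weight": "bfloat16"})
with toy_yield_fp32(model):
    inside = set(model.dtypes.values())  # weight initialization would run here
print(inside, set(model.dtypes.values()))  # {'float32'} {'bfloat16'}
```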

Parameters:

model
nn.Module

The model to run in fp32 within the context.

restore_dtype
torch.dtype | None, defaults to None

The dtype to cast the model to on exit. Defaults to the model’s current floating-point dtype (captured before the fp32 cast), i.e. the original dtype.

nemo_automodel.components.models.common.utils.HAVE_DEEP_EP = importlib.util.find_spec('deep_ep') is not None
nemo_automodel.components.models.common.utils.HAVE_GMM = importlib.util.find_spec('grouped_gemm') is not None
nemo_automodel.components.models.common.utils.HAVE_TE = importlib.util.find_spec('transformer_engine') is not None
nemo_automodel.components.models.common.utils.HAVE_UCCL_EP = importlib.util.find_spec('uccl') is not None or importlib.util.find_spec('ep') i...
nemo_automodel.components.models.common.utils.IS_FIRST_MICROBATCH: bool | None = None
nemo_automodel.components.models.common.utils.IS_OPTIM_STEP = False
nemo_automodel.components.models.common.utils.__all__ = ['BackendConfig', 'Float32RMSNorm', 'TEFp8Config', 'cast_frozen_modules_to_compu...
nemo_automodel.components.models.common.utils._patch_te_modules = _make_lazy_te_patcher()
nemo_automodel.components.models.common.utils.logger = logging.getLogger(__name__)