nemo_automodel.components.models.common.utils#

Module Contents#

Classes#

TEFp8Config

Configuration for Transformer Engine FP8 quantization.

BackendConfig

Backend configuration for model components.

Float32RMSNorm

RMSNorm with explicit fp32 computation for training stability.

Functions#

set_is_optim_step

Set the global IS_OPTIM_STEP flag.

get_is_optim_step

Get the global IS_OPTIM_STEP flag.

set_is_first_microbatch

Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching.

get_is_first_microbatch

Get the global IS_FIRST_MICROBATCH flag.

is_tensor_unallocated

Check if tensor is unallocated (meta tensor, fake tensor, etc.).

initialize_rms_norm_module

Initialize RMSNorm module with the specified backend.

initialize_linear_module

Initialize Linear module with the specified backend.

_make_lazy_te_patcher

Return a callable that patches TE modules exactly once.

get_rope_config

Extract rope configuration from config.rope_parameters.

cast_model_to_dtype

Cast model parameters to the target dtype, keeping fp32 modules in full precision.

_get_fp32_module_keywords

Collect module name patterns that must remain in fp32.

_has_dtensor_params

Check if any model parameter is a DTensor (FSDP2 sharded).

_restore_fp32_modules

Cast modules matching fp32_keywords back to float32.

_restore_fp32_buffers

Cast only buffers (not parameters) of matching modules back to float32.

Data#

API#

nemo_automodel.components.models.common.utils.logger#

‘getLogger(…)’

nemo_automodel.components.models.common.utils.HAVE_TE#

None

nemo_automodel.components.models.common.utils.HAVE_DEEP_EP#

None

nemo_automodel.components.models.common.utils.HAVE_GMM#

None

nemo_automodel.components.models.common.utils.IS_OPTIM_STEP#

False

nemo_automodel.components.models.common.utils.IS_FIRST_MICROBATCH: bool | None#

None

nemo_automodel.components.models.common.utils.set_is_optim_step(value: bool) → None#

Set the global IS_OPTIM_STEP flag.

Parameters:

value – Whether we are in an optimization step.

nemo_automodel.components.models.common.utils.get_is_optim_step() → bool#

Get the global IS_OPTIM_STEP flag.

Returns:

Whether we are in an optimization step.

nemo_automodel.components.models.common.utils.set_is_first_microbatch(value: bool | None) → None#

Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching.

Parameters:

value – True for first microbatch (quantize+cache), False for subsequent (use cached), None to disable caching.

nemo_automodel.components.models.common.utils.get_is_first_microbatch() → bool | None#

Get the global IS_FIRST_MICROBATCH flag.

Returns:

True/False/None indicating microbatch position for FP8 weight caching.
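The flag pair above is a simple global-state pattern for gradient accumulation. A minimal sketch in plain Python (the real flag is consumed by patched TE Linear modules, which is omitted here):

```python
# Sketch of the IS_FIRST_MICROBATCH global-flag pattern described above.
IS_FIRST_MICROBATCH = None

def set_is_first_microbatch(value):
    global IS_FIRST_MICROBATCH
    IS_FIRST_MICROBATCH = value

def get_is_first_microbatch():
    return IS_FIRST_MICROBATCH

def run_accumulation_window(num_microbatches):
    """Return the flag value each microbatch in the window would see."""
    seen = []
    for i in range(num_microbatches):
        set_is_first_microbatch(i == 0)  # quantize+cache on first, reuse cached after
        seen.append(get_is_first_microbatch())
    set_is_first_microbatch(None)  # disable caching outside the window
    return seen
```

With three accumulation microbatches, only the first sees `True`, so FP8 weights are quantized once and the cached copies are reused for the rest of the window.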

nemo_automodel.components.models.common.utils.is_tensor_unallocated(tensor: torch.Tensor) → bool#

Check if tensor is unallocated (meta tensor, fake tensor, etc.).

TE kernels don’t support meta tensors, fake tensors, or unallocated tensors. This helper detects such cases for fallback handling.

Parameters:

tensor – Tensor to check

Returns:

True if tensor is unallocated or cannot be accessed

class nemo_automodel.components.models.common.utils.TEFp8Config#

Configuration for Transformer Engine FP8 quantization.

When present (not None) in BackendConfig, FP8 is enabled. The recipe field accepts either a string shorthand ("current" or "block") or a pre-built TE recipe object (e.g. Float8CurrentScaling(fp8_dpa=True)).

recipe: Literal[current, block] | Any#

‘current’

build_recipe()#

Build and return the TE FP8 recipe object.

If recipe is already a TE recipe object (e.g. Float8CurrentScaling(...)), it is returned directly. String values "current" and "block" are mapped to the corresponding TE recipe class.

maybe_te_autocast()#

Return te_autocast context manager for FP8.
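The string-shorthand dispatch in build_recipe() can be sketched as follows; `current_cls` and `block_cls` are stand-ins for the TE recipe classes (e.g. Float8CurrentScaling) so the sketch runs without transformer_engine installed:

```python
def build_recipe_sketch(recipe, current_cls, block_cls):
    """Map the "current"/"block" shorthand to a recipe class, or pass
    a pre-built recipe object through unchanged (sketch of build_recipe)."""
    if isinstance(recipe, str):
        mapping = {"current": current_cls, "block": block_cls}
        if recipe not in mapping:
            raise ValueError(f"unknown recipe shorthand: {recipe!r}")
        return mapping[recipe]()
    return recipe  # already a pre-built TE recipe object
```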

class nemo_automodel.components.models.common.utils.BackendConfig#

Backend configuration for model components.

.. attribute:: attn

Attention backend (“te”, “sdpa”, or “flex”).

.. attribute:: linear

Linear layer backend (“torch” or “te”).

.. attribute:: rms_norm

RMSNorm backend (“torch”, “torch_fp32”, or “te”).

.. attribute:: rope_fusion

Whether to use fused RoPE (requires TE).

.. attribute:: experts

MoE expert GEMM backend. “torch” uses per-expert loop, “te” uses TE GroupedLinear, “gmm” uses grouped_gemm.ops.gmm, “torch_mm” uses torch._grouped_mm.

.. attribute:: dispatcher

MoE token dispatcher. “torch” uses DTensor all-gather/reduce-scatter, “deepep” uses DeepEP for token dispatch.

.. attribute:: enable_deepep

Deprecated. Use dispatcher=”deepep” and experts=”gmm” instead.

.. attribute:: fake_balanced_gate

If True, replace the learned Gate with FakeBalancedGate that assigns tokens to experts without learned routing weights.

.. attribute:: fake_gate_noise

Noise level [0, 1] for FakeBalancedGate. When > 0, uses biased top-k selection seeded from the input content, so routing varies across training steps (like the real Gate) while remaining deterministic for activation-checkpointing recompute (the same input yields the same routing). Only used when fake_balanced_gate=True.

.. attribute:: enable_hf_state_dict_adapter

Whether to enable HuggingFace state dict adapter.

.. attribute:: enable_fsdp_optimizations

Whether to enable FSDP2 optimizations.

.. attribute:: gate_precision

Optional dtype override for the gate computation. Accepts torch.dtype or string (e.g., “torch.float32”, “float32”).

attn: Literal[te, sdpa, flex]#

None

linear: Literal[torch, te]#

None

rms_norm: Literal[torch, torch_fp32, te]#

‘torch_fp32’

rope_fusion: bool#

None

experts: Literal[torch, te, gmm, torch_mm]#

None

dispatcher: Literal[torch, deepep]#

None

enable_deepep: bool | None#

None

fake_balanced_gate: bool#

False

fake_gate_noise: float#

0.0

enable_hf_state_dict_adapter: bool#

True

enable_fsdp_optimizations: bool#

False

te_fp8: nemo_automodel.components.models.common.utils.TEFp8Config | None#

None

gate_precision: str | torch.dtype | None#

None

__post_init__()#
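As a reading aid, a hypothetical dataclass stand-in that mirrors the field list above (defaults copied from the rendered values; the real class adds validation in __post_init__, which is omitted here):

```python
from dataclasses import dataclass
from typing import Any, Optional

# Hypothetical stand-in for BackendConfig, reproducing only the
# documented fields and defaults shown above.
@dataclass
class BackendConfigSketch:
    attn: Optional[str] = None            # "te", "sdpa", or "flex"
    linear: Optional[str] = None          # "torch" or "te"
    rms_norm: str = "torch_fp32"          # "torch", "torch_fp32", or "te"
    rope_fusion: Optional[bool] = None
    experts: Optional[str] = None         # "torch", "te", "gmm", "torch_mm"
    dispatcher: Optional[str] = None      # "torch" or "deepep"
    enable_deepep: Optional[bool] = None  # deprecated
    fake_balanced_gate: bool = False
    fake_gate_noise: float = 0.0
    enable_hf_state_dict_adapter: bool = True
    enable_fsdp_optimizations: bool = False
    te_fp8: Optional[Any] = None          # a TEFp8Config enables FP8 when set
    gate_precision: Optional[Any] = None  # str or torch.dtype
```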
class nemo_automodel.components.models.common.utils.Float32RMSNorm(dim, eps=1e-05, device=None, dtype=torch.bfloat16)#

Bases: torch.nn.Module

RMSNorm with explicit fp32 computation for training stability.

Weights stay in the model’s dtype (e.g. bf16) for FSDP2 compatibility. Inputs are upcast to fp32, norm is computed in fp32, and the output is cast back to the original input dtype.

Initialization

reset_parameters()#
forward(x)#
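The fp32 compute path described above can be sketched in plain Python, with double-precision floats standing in for the fp32 upcast (the real module operates on torch tensors and carries a learnable weight parameter):

```python
import math

def rms_norm_fp32(x, weight, eps=1e-5):
    # Square, mean, rsqrt in full precision, then scale by the learned
    # weight; the caller would cast the result back to the input dtype.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]
```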
nemo_automodel.components.models.common.utils.initialize_rms_norm_module(
rms_norm_impl: str,
dim: int,
eps: float = 1e-05,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.bfloat16,
) → torch.nn.Module#

Initialize RMSNorm module with the specified backend.

For TE backend, creates TE module directly on specified device. Call reset_parameters() to materialize weights if created on meta device.

Parameters:
  • rms_norm_impl

    Backend implementation (“te”, “torch”, or “torch_fp32”)

    • “te”: Transformer Engine fused RMSNorm kernel

    • “torch”: PyTorch native nn.RMSNorm (computes in input dtype)

    • “torch_fp32”: fp32 RMSNorm compiled with torch.compile for training stability

  • dim – Normalized dimension

  • eps – Epsilon for numerical stability

  • device – Device to create module on (None uses PyTorch default, typically CPU)

  • dtype – Parameter dtype

Returns:

RMSNorm module

nemo_automodel.components.models.common.utils.initialize_linear_module(
linear_impl: str,
in_features: int,
out_features: int,
bias: bool = False,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.bfloat16,
) → torch.nn.Module#

Initialize Linear module with the specified backend.

For TE backend, creates TE module directly on specified device. Call reset_parameters() to materialize weights if created on meta device.

Parameters:
  • linear_impl – Backend implementation (“te” or “torch”)

  • in_features – Input features

  • out_features – Output features

  • bias – Whether to use bias

  • device – Device to create module on (None uses PyTorch default, typically CPU)

  • dtype – Parameter dtype

Returns:

Linear module

nemo_automodel.components.models.common.utils._make_lazy_te_patcher()#

Return a callable that patches TE modules exactly once.

Uses a closure instead of module-level global state to track whether the patch has already been applied. The actual transformer_engine import is deferred until the first call so that importing this module never triggers heavy native-library loads (flash-attn, CUDA kernels, etc.).

Two patches are applied:

  1. Unallocated tensor handling: TE kernels don’t support meta/fake tensors, so we short-circuit with empty tensors for pipeline-parallel (PP) shape inference.

  2. is_first_microbatch injection: Reads the global IS_FIRST_MICROBATCH flag and passes it to TE Linear/GroupedLinear for FP8 weight caching during gradient accumulation (quantize on first microbatch, reuse cached on rest).
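The run-once closure pattern can be sketched as follows (the patch body here is a placeholder; the real patcher defers the transformer_engine import and monkey-patches TE modules):

```python
def make_lazy_patcher():
    """Return a callable that applies an expensive patch exactly once.

    State lives in the closure rather than a module-level global, and
    the heavy import would be deferred to the first call.
    """
    applied = False

    def patch():
        nonlocal applied
        if applied:
            return False  # already patched; no-op
        # Heavy imports / monkey-patching would happen here.
        applied = True
        return True

    return patch
```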

nemo_automodel.components.models.common.utils._patch_te_modules#

‘_make_lazy_te_patcher(…)’

nemo_automodel.components.models.common.utils.get_rope_config(config) → tuple[float, dict, float]#

Extract rope configuration from config.rope_parameters.

Parameters:

config – A HuggingFace model config object.

Returns:

Tuple of (rope_theta, rope_parameters, partial_rotary_factor).

nemo_automodel.components.models.common.utils.cast_model_to_dtype(
model: torch.nn.Module,
dtype: torch.dtype = torch.bfloat16,
) → None#

Cast model parameters to the target dtype, keeping fp32 modules in full precision.

Respects _keep_in_fp32_modules / _keep_in_fp32_modules_strict on the model (the same attributes HuggingFace transformers uses).

Uses nn.Module.to() which is safe for both plain tensors and DTensors (FSDP2 sharded parameters). When the model is already FSDP2-sharded (parameters are DTensors), only buffers of matching modules are restored to fp32 (parameters are left as-is since FSDP2 requires uniform dtype).

Parameters:
  • model – The model whose parameters should be cast.

  • dtype – Target dtype (e.g. torch.bfloat16).
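The branching between the two restore helpers can be sketched as a pure decision function (names are hypothetical, and the trivial no-keywords case is an assumption):

```python
def choose_restore_step(has_dtensor_params, fp32_keywords):
    """Which fp32-restore step follows the dtype cast.

    FSDP2-sharded models (DTensor parameters) only get their buffers
    restored, since FSDP2 requires a uniform parameter dtype; unsharded
    models get whole matching modules restored.
    """
    if not fp32_keywords:
        return "none"          # nothing to keep in fp32 (assumed)
    if has_dtensor_params:
        return "buffers_only"  # corresponds to _restore_fp32_buffers
    return "modules"           # corresponds to _restore_fp32_modules
```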

nemo_automodel.components.models.common.utils._get_fp32_module_keywords(model: torch.nn.Module) → list[str]#

Collect module name patterns that must remain in fp32.

Reads _keep_in_fp32_modules and _keep_in_fp32_modules_strict from the model (the same attributes HuggingFace transformers uses).

Parameters:

model – The model to inspect.

Returns:

De-duplicated list of module-name keywords to keep in fp32.
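A sketch of the keep-list collection (attribute names follow the HuggingFace convention referenced above; order-preserving de-duplication is an assumption):

```python
def get_fp32_module_keywords(model):
    """Collect fp32 keep-list keywords from HF-style model attributes,
    de-duplicating while preserving first-seen order."""
    keywords = []
    for attr in ("_keep_in_fp32_modules", "_keep_in_fp32_modules_strict"):
        for kw in (getattr(model, attr, None) or []):
            if kw not in keywords:
                keywords.append(kw)
    return keywords
```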

nemo_automodel.components.models.common.utils._has_dtensor_params(model: torch.nn.Module) → bool#

Check if any model parameter is a DTensor (FSDP2 sharded).

nemo_automodel.components.models.common.utils._restore_fp32_modules(
model: torch.nn.Module,
fp32_keywords: list[str],
) → None#

Cast modules matching fp32_keywords back to float32.

Only safe for unsharded models (plain tensors). FSDP2 requires uniform dtype within each parameter group, so this must not be called on DTensor-sharded models.

Parameters:
  • model – The model (already cast to the target dtype).

  • fp32_keywords – Substrings matched against dot-separated module names.

nemo_automodel.components.models.common.utils._restore_fp32_buffers(
model: torch.nn.Module,
fp32_keywords: list[str],
) → None#

Cast only buffers (not parameters) of matching modules back to float32.

Safe for FSDP2-sharded models because buffers are plain tensors, not DTensors managed by FSDP2.

Parameters:
  • model – The model (already cast to the target dtype).

  • fp32_keywords – Substrings matched against dot-separated module names.

nemo_automodel.components.models.common.utils.__all__#

[‘BackendConfig’, ‘Float32RMSNorm’, ‘TEFp8Config’, ‘cast_model_to_dtype’, ‘get_is_first_microbatch’,…