nemo_automodel.components.models.common.utils#
Module Contents#
Classes#
- TEFp8Config – Configuration for Transformer Engine FP8 quantization.
- BackendConfig – Backend configuration for model components.
- Float32RMSNorm – RMSNorm with explicit fp32 computation for training stability.
Functions#
- set_is_optim_step – Set the global IS_OPTIM_STEP flag.
- get_is_optim_step – Get the global IS_OPTIM_STEP flag.
- set_is_first_microbatch – Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching.
- get_is_first_microbatch – Get the global IS_FIRST_MICROBATCH flag.
- is_tensor_unallocated – Check if tensor is unallocated (meta tensor, fake tensor, etc.).
- initialize_rms_norm_module – Initialize RMSNorm module with the specified backend.
- initialize_linear_module – Initialize Linear module with the specified backend.
- _make_lazy_te_patcher – Return a callable that patches TE modules exactly once.
- get_rope_config – Extract rope configuration from config.rope_parameters.
- cast_model_to_dtype – Cast model parameters to the target dtype, keeping fp32 modules in full precision.
- _get_fp32_module_keywords – Collect module name patterns that must remain in fp32.
- _has_dtensor_params – Check if any model parameter is a DTensor (FSDP2 sharded).
- _restore_fp32_modules – Cast modules matching fp32_keywords back to float32.
- _restore_fp32_buffers – Cast only buffers (not parameters) of matching modules back to float32.
Data#
API#
- nemo_automodel.components.models.common.utils.logger#
‘getLogger(…)’
- nemo_automodel.components.models.common.utils.HAVE_TE#
None
- nemo_automodel.components.models.common.utils.HAVE_DEEP_EP#
None
- nemo_automodel.components.models.common.utils.HAVE_GMM#
None
- nemo_automodel.components.models.common.utils.IS_OPTIM_STEP#
False
- nemo_automodel.components.models.common.utils.IS_FIRST_MICROBATCH: bool | None#
None
- nemo_automodel.components.models.common.utils.set_is_optim_step(value: bool) None#
Set the global IS_OPTIM_STEP flag.
- Parameters:
value – Whether we are in an optimization step.
- nemo_automodel.components.models.common.utils.get_is_optim_step() bool#
Get the global IS_OPTIM_STEP flag.
- Returns:
Whether we are in an optimization step.
- nemo_automodel.components.models.common.utils.set_is_first_microbatch(value: bool | None) None#
Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching.
- Parameters:
value – True for first microbatch (quantize+cache), False for subsequent (use cached), None to disable caching.
- nemo_automodel.components.models.common.utils.get_is_first_microbatch() bool | None#
Get the global IS_FIRST_MICROBATCH flag.
- Returns:
True/False/None indicating microbatch position for FP8 weight caching.
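A hedged usage sketch of these flags in a gradient-accumulation loop; `microbatches`, `model`, and `optimizer` are assumed to come from the caller's training script and are not part of this module:

```python
from nemo_automodel.components.models.common.utils import (
    set_is_first_microbatch,
    set_is_optim_step,
)

# Illustrative accumulation window (training objects assumed to exist).
for i, batch in enumerate(microbatches):
    # First microbatch: quantize + cache FP8 weights; later ones reuse the cache.
    set_is_first_microbatch(i == 0)
    loss = model(**batch).loss
    loss.backward()

set_is_optim_step(True)        # weights are about to change
optimizer.step()
optimizer.zero_grad()
set_is_optim_step(False)
set_is_first_microbatch(None)  # disable caching outside the accumulation window
```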
- nemo_automodel.components.models.common.utils.is_tensor_unallocated(tensor: torch.Tensor) bool#
Check if tensor is unallocated (meta tensor, fake tensor, etc.).
TE kernels don’t support meta tensors, fake tensors, or unallocated tensors. This helper detects such cases for fallback handling.
- Parameters:
tensor – Tensor to check
- Returns:
True if tensor is unallocated or cannot be accessed
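An illustrative sketch of this kind of check (not the library's implementation): treat meta tensors, and tensors whose storage cannot be accessed, as unallocated.

```python
import torch

def is_unallocated_sketch(tensor: torch.Tensor) -> bool:
    # Meta tensors carry no storage at all.
    if tensor.is_meta:
        return True
    try:
        # Fake / not-yet-materialized tensors typically raise, or report a null
        # data pointer, when their storage is queried.
        return tensor.untyped_storage().data_ptr() == 0
    except Exception:
        return True
```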
- class nemo_automodel.components.models.common.utils.TEFp8Config#
Configuration for Transformer Engine FP8 quantization.
When present (not None) in BackendConfig, FP8 is enabled. The recipe field accepts either a string shorthand ("current" or "block") or a pre-built TE recipe object (e.g. Float8CurrentScaling(fp8_dpa=True)).
- recipe: Literal[current, block] | Any#
‘current’
- build_recipe()#
Build and return the TE FP8 recipe object.
If recipe is already a TE recipe object (e.g. Float8CurrentScaling(...)), it is returned directly. String values "current" and "block" are mapped to the corresponding TE recipe class.
- maybe_te_autocast()#
Return te_autocast context manager for FP8.
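A hedged construction example: the string-shorthand path follows the field description above, while the commented-out path assumes Transformer Engine is installed.

```python
from nemo_automodel.components.models.common.utils import TEFp8Config

# String shorthand: mapped to the corresponding TE recipe class by build_recipe().
fp8_cfg = TEFp8Config(recipe="current")
recipe = fp8_cfg.build_recipe()

# Or pass a pre-built TE recipe object directly (assumes transformer_engine is available):
# from transformer_engine.common.recipe import Float8CurrentScaling
# fp8_cfg = TEFp8Config(recipe=Float8CurrentScaling(fp8_dpa=True))
```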
- class nemo_automodel.components.models.common.utils.BackendConfig#
Backend configuration for model components.
.. attribute:: attn
Attention backend (“te”, “sdpa”, or “flex”).
.. attribute:: linear
Linear layer backend (“torch” or “te”).
.. attribute:: rms_norm
RMSNorm backend (“torch”, “torch_fp32”, or “te”).
.. attribute:: rope_fusion
Whether to use fused RoPE (requires TE).
.. attribute:: experts
MoE expert GEMM backend. “torch” uses per-expert loop, “te” uses TE GroupedLinear, “gmm” uses grouped_gemm.ops.gmm, “torch_mm” uses torch._grouped_mm.
.. attribute:: dispatcher
MoE token dispatcher. “torch” uses DTensor all-gather/reduce-scatter, “deepep” uses DeepEP for token dispatch.
.. attribute:: enable_deepep
Deprecated. Use dispatcher=”deepep” and experts=”gmm” instead.
.. attribute:: fake_balanced_gate
If True, replace the learned Gate with FakeBalancedGate that assigns tokens to experts without learned routing weights.
.. attribute:: fake_gate_noise
Noise level [0, 1] for FakeBalancedGate. When > 0, uses biased topk selection seeded from the input content so routing varies dynamically across training steps (like real Gate) while remaining deterministic for activation checkpointing recompute (same input = same routing). Only used when fake_balanced_gate=True.
.. attribute:: enable_hf_state_dict_adapter
Whether to enable HuggingFace state dict adapter.
.. attribute:: enable_fsdp_optimizations
Whether to enable FSDP2 optimizations.
.. attribute:: gate_precision
Optional dtype override for the gate computation. Accepts torch.dtype or string (e.g., “torch.float32”, “float32”).
- attn: Literal[te, sdpa, flex]#
None
- linear: Literal[torch, te]#
None
- rms_norm: Literal[torch, torch_fp32, te]#
‘torch_fp32’
- rope_fusion: bool#
None
- experts: Literal[torch, te, gmm, torch_mm]#
None
- dispatcher: Literal[torch, deepep]#
None
- enable_deepep: bool | None#
None
- fake_balanced_gate: bool#
False
- fake_gate_noise: float#
0.0
- enable_hf_state_dict_adapter: bool#
True
- enable_fsdp_optimizations: bool#
False
- te_fp8: nemo_automodel.components.models.common.utils.TEFp8Config | None#
None
- gate_precision: str | torch.dtype | None#
None
- __post_init__()#
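A hedged example of constructing a backend configuration; the field names come from the attribute list above, the particular value combination is only illustrative, and `__post_init__` may apply additional validation:

```python
from nemo_automodel.components.models.common.utils import BackendConfig

backend = BackendConfig(
    attn="sdpa",            # attention backend
    linear="torch",         # plain nn.Linear layers
    rms_norm="torch_fp32",  # fp32-compute RMSNorm for training stability
    rope_fusion=False,      # fused RoPE requires TE
    experts="torch",        # per-expert loop for MoE GEMMs
    dispatcher="torch",     # DTensor all-gather/reduce-scatter token dispatch
    te_fp8=None,            # set to TEFp8Config(...) to enable FP8
)
```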
- class nemo_automodel.components.models.common.utils.Float32RMSNorm(dim, eps=1e-05, device=None, dtype=torch.bfloat16)#
Bases: torch.nn.Module
RMSNorm with explicit fp32 computation for training stability.
Weights stay in the model’s dtype (e.g. bf16) for FSDP2 compatibility. Inputs are upcast to fp32, norm is computed in fp32, and the output is cast back to the original input dtype.
Initialization
- reset_parameters()#
- forward(x)#
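A minimal sketch of the computation pattern described above (illustrative, not the module's source): the weight stays in bf16, the reduction runs in fp32, and the result is cast back to the input dtype.

```python
import torch
import torch.nn as nn

class Float32RMSNormSketch(nn.Module):
    def __init__(self, dim, eps=1e-5, device=None, dtype=torch.bfloat16):
        super().__init__()
        self.eps = eps
        # Weight stays in the model dtype (e.g. bf16) for FSDP2 compatibility.
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def forward(self, x):
        in_dtype = x.dtype
        x32 = x.float()  # upcast input to fp32
        rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x32 * rms * self.weight.float()).to(in_dtype)  # cast back
```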
- nemo_automodel.components.models.common.utils.initialize_rms_norm_module(
- rms_norm_impl: str,
- dim: int,
- eps: float = 1e-05,
- device: torch.device | str | None = None,
- dtype: torch.dtype = torch.bfloat16,
Initialize RMSNorm module with the specified backend.
For TE backend, creates TE module directly on specified device. Call reset_parameters() to materialize weights if created on meta device.
- Parameters:
rms_norm_impl –
Backend implementation (“te”, “torch”, or “torch_fp32”)
”te”: Transformer Engine fused RMSNorm kernel
”torch”: PyTorch native nn.RMSNorm (computes in input dtype)
”torch_fp32”: torch.compiled fp32 RMSNorm for training stability
dim – Normalized dimension
eps – Epsilon for numerical stability
device – Device to create module on (None uses PyTorch default, typically CPU)
dtype – Parameter dtype
- Returns:
RMSNorm module
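A hedged usage sketch (the CUDA device and tensor shapes are assumptions, not requirements of the API):

```python
import torch
from nemo_automodel.components.models.common.utils import initialize_rms_norm_module

# "torch_fp32" backend: fp32-compute RMSNorm with bf16 parameters (see Float32RMSNorm).
norm = initialize_rms_norm_module(
    "torch_fp32", dim=4096, eps=1e-5, device="cuda", dtype=torch.bfloat16
)
y = norm(torch.randn(2, 8, 4096, device="cuda", dtype=torch.bfloat16))
```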
- nemo_automodel.components.models.common.utils.initialize_linear_module(
- linear_impl: str,
- in_features: int,
- out_features: int,
- bias: bool = False,
- device: torch.device | str | None = None,
- dtype: torch.dtype = torch.bfloat16,
Initialize Linear module with the specified backend.
For TE backend, creates TE module directly on specified device. Call reset_parameters() to materialize weights if created on meta device.
- Parameters:
linear_impl – Backend implementation (“te” or “torch”)
in_features – Input features
out_features – Output features
bias – Whether to use bias
device – Device to create module on (None uses PyTorch default, typically CPU)
dtype – Parameter dtype
- Returns:
Linear module
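A hedged usage sketch; the feature sizes are illustrative:

```python
import torch
from nemo_automodel.components.models.common.utils import initialize_linear_module

# "torch" backend returns a plain nn.Linear; "te" would build a Transformer Engine
# Linear instead (and, if created on the meta device, needs reset_parameters()
# afterwards to materialize weights).
proj = initialize_linear_module(
    "torch", in_features=4096, out_features=11008, bias=False, dtype=torch.bfloat16
)
```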
- nemo_automodel.components.models.common.utils._make_lazy_te_patcher()#
Return a callable that patches TE modules exactly once.
Uses a closure instead of module-level global state to track whether the patch has already been applied. The actual transformer_engine import is deferred until the first call so that importing this module never triggers heavy native-library loads (flash-attn, CUDA kernels, etc.).
Two patches are applied:
1. Unallocated tensor handling: TE kernels don’t support meta/fake tensors, so we short-circuit with empty tensors for PP shape inference.
2. is_first_microbatch injection: Reads the global IS_FIRST_MICROBATCH flag and passes it to TE Linear/GroupedLinear for FP8 weight caching during gradient accumulation (quantize on first microbatch, reuse cached on rest).
- nemo_automodel.components.models.common.utils._patch_te_modules#
‘_make_lazy_te_patcher(…)’
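A hypothetical sketch of the "patch exactly once" closure pattern described above; the names are illustrative and the real patcher applies the two TE patches listed in the docstring.

```python
def _make_lazy_patcher_sketch():
    applied = False  # closure state instead of a module-level global

    def patch():
        nonlocal applied
        if applied:
            return
        import transformer_engine  # heavy import deferred until the first call
        # ... monkey-patch TE Linear/GroupedLinear forward here ...
        applied = True

    return patch

patch_te_once = _make_lazy_patcher_sketch()
patch_te_once()  # applies the patches
patch_te_once()  # no-op on subsequent calls
```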
- nemo_automodel.components.models.common.utils.get_rope_config(config) tuple[float, dict, float]#
Extract rope configuration from config.rope_parameters.
- Parameters:
config – A HuggingFace model config object.
- Returns:
Tuple of (rope_theta, rope_parameters, partial_rotary_factor).
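A hedged usage sketch; the checkpoint name is only an example, and it assumes a transformers version whose configs expose rope_parameters:

```python
from transformers import AutoConfig  # assumes HF transformers is installed
from nemo_automodel.components.models.common.utils import get_rope_config

config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")  # illustrative checkpoint
rope_theta, rope_parameters, partial_rotary_factor = get_rope_config(config)
```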
- nemo_automodel.components.models.common.utils.cast_model_to_dtype(
- model: torch.nn.Module,
- dtype: torch.dtype = torch.bfloat16,
Cast model parameters to the target dtype, keeping fp32 modules in full precision.
Respects _keep_in_fp32_modules / _keep_in_fp32_modules_strict on the model (the same attributes HuggingFace transformers uses).
Uses nn.Module.to(), which is safe for both plain tensors and DTensors (FSDP2 sharded parameters). When the model is already FSDP2-sharded (parameters are DTensors), only buffers of matching modules are restored to fp32 (parameters are left as-is since FSDP2 requires uniform dtype).
- Parameters:
model – The model whose parameters should be cast.
dtype – Target dtype (e.g. torch.bfloat16).
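A hedged usage sketch with a toy model standing in for a HuggingFace model:

```python
import torch
import torch.nn as nn
from nemo_automodel.components.models.common.utils import cast_model_to_dtype

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))  # toy stand-in for a HF model
cast_model_to_dtype(model, dtype=torch.bfloat16)
# If the model defined _keep_in_fp32_modules / _keep_in_fp32_modules_strict,
# the matching modules (or, under FSDP2, just their buffers) would stay in fp32.
```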
- nemo_automodel.components.models.common.utils._get_fp32_module_keywords(model: torch.nn.Module) list[str]#
Collect module name patterns that must remain in fp32.
Reads _keep_in_fp32_modules and _keep_in_fp32_modules_strict from the model (the same attributes HuggingFace transformers uses).
- Parameters:
model – The model to inspect.
- Returns:
De-duplicated list of module-name keywords to keep in fp32.
- nemo_automodel.components.models.common.utils._has_dtensor_params(model: torch.nn.Module) bool#
Check if any model parameter is a DTensor (FSDP2 sharded).
- nemo_automodel.components.models.common.utils._restore_fp32_modules(
- model: torch.nn.Module,
- fp32_keywords: list[str],
Cast modules matching fp32_keywords back to float32.
Only safe for unsharded models (plain tensors). FSDP2 requires uniform dtype within each parameter group, so this must not be called on DTensor-sharded models.
- Parameters:
model – The model (already cast to the target dtype).
fp32_keywords – Substrings matched against dot-separated module names.
- nemo_automodel.components.models.common.utils._restore_fp32_buffers(
- model: torch.nn.Module,
- fp32_keywords: list[str],
Cast only buffers (not parameters) of matching modules back to float32.
Safe for FSDP2-sharded models because buffers are plain tensors, not DTensors managed by FSDP2.
- Parameters:
model – The model (already cast to the target dtype).
fp32_keywords – Substrings matched against dot-separated module names.
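The three private helpers compose roughly as follows inside cast_model_to_dtype; this is an illustrative sketch of the flow described in the docstrings, with a toy model standing in for a real one:

```python
import torch.nn as nn
from nemo_automodel.components.models.common.utils import (
    _get_fp32_module_keywords,
    _has_dtensor_params,
    _restore_fp32_buffers,
    _restore_fp32_modules,
)

model = nn.Linear(8, 8)  # toy stand-in; a real HF model may define _keep_in_fp32_modules
fp32_keywords = _get_fp32_module_keywords(model)
if fp32_keywords:
    if _has_dtensor_params(model):
        _restore_fp32_buffers(model, fp32_keywords)   # FSDP2-sharded: buffers only
    else:
        _restore_fp32_modules(model, fp32_keywords)   # unsharded: full modules back to fp32
```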
- nemo_automodel.components.models.common.utils.__all__#
[‘BackendConfig’, ‘Float32RMSNorm’, ‘TEFp8Config’, ‘cast_model_to_dtype’, ‘get_is_first_microbatch’,…