nemo_automodel.components.models.common.utils#
Module Contents#
Classes#
| Class | Description |
|---|---|
| TEFp8Config | Configuration for Transformer Engine FP8 quantization. |
| BackendConfig | Backend configuration for model components. |
Functions#
| Function | Description |
|---|---|
| set_is_optim_step | Set the global IS_OPTIM_STEP flag. |
| get_is_optim_step | Get the global IS_OPTIM_STEP flag. |
| set_is_first_microbatch | Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching. |
| get_is_first_microbatch | Get the global IS_FIRST_MICROBATCH flag. |
| is_tensor_unallocated | Check if tensor is unallocated (meta tensor, fake tensor, etc.). |
| initialize_rms_norm_module | Initialize RMSNorm module with the specified backend. |
| initialize_linear_module | Initialize Linear module with the specified backend. |
| _make_lazy_te_patcher | Return a callable that patches TE modules exactly once. |
Data#
API#
- nemo_automodel.components.models.common.utils.logger#
‘getLogger(…)’
- nemo_automodel.components.models.common.utils.HAVE_TE#
None
- nemo_automodel.components.models.common.utils.HAVE_DEEP_EP#
None
- nemo_automodel.components.models.common.utils.HAVE_GMM#
None
- nemo_automodel.components.models.common.utils.IS_OPTIM_STEP#
False
- nemo_automodel.components.models.common.utils.IS_FIRST_MICROBATCH: bool | None#
None
- nemo_automodel.components.models.common.utils.set_is_optim_step(value: bool) None#
Set the global IS_OPTIM_STEP flag.
- Parameters:
value – Whether we are in an optimization step.
- nemo_automodel.components.models.common.utils.get_is_optim_step() bool#
Get the global IS_OPTIM_STEP flag.
- Returns:
Whether we are in an optimization step.
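A minimal usage sketch of the two flag accessors; the `run_optimizer_step` helper and its `optimizer` argument are illustrative, not part of the module:

```python
from nemo_automodel.components.models.common.utils import (
    get_is_optim_step,
    set_is_optim_step,
)

def run_optimizer_step(optimizer):
    set_is_optim_step(True)       # backend code can now branch on get_is_optim_step()
    try:
        optimizer.step()
        optimizer.zero_grad()
    finally:
        set_is_optim_step(False)  # always restore the flag
    assert get_is_optim_step() is False
```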
- nemo_automodel.components.models.common.utils.set_is_first_microbatch(value: bool | None) None#
Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching.
- Parameters:
value – True for first microbatch (quantize+cache), False for subsequent (use cached), None to disable caching.
- nemo_automodel.components.models.common.utils.get_is_first_microbatch() bool | None#
Get the global IS_FIRST_MICROBATCH flag.
- Returns:
True/False/None indicating microbatch position for FP8 weight caching.
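A hypothetical gradient-accumulation loop illustrating the documented semantics; only the flag calls come from this module:

```python
from nemo_automodel.components.models.common.utils import set_is_first_microbatch

def accumulate_microbatches(model, microbatches, loss_fn):
    for i, batch in enumerate(microbatches):
        # True: quantize + cache FP8 weights; False: reuse the cached weights
        set_is_first_microbatch(i == 0)
        loss = loss_fn(model(batch))
        loss.backward()
    # Disable caching outside the accumulation window
    set_is_first_microbatch(None)
```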
- nemo_automodel.components.models.common.utils.is_tensor_unallocated(tensor: torch.Tensor) bool#
Check if tensor is unallocated (meta tensor, fake tensor, etc.).
TE kernels don’t support meta tensors, fake tensors, or unallocated tensors. This helper detects such cases for fallback handling.
- Parameters:
tensor – Tensor to check
- Returns:
True if tensor is unallocated or cannot be accessed
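A short sketch of the documented behaviour using a meta tensor versus a normal tensor:

```python
import torch

from nemo_automodel.components.models.common.utils import is_tensor_unallocated

meta_w = torch.empty(16, 16, device="meta")   # shape/dtype only, no storage
real_w = torch.zeros(16, 16)

assert is_tensor_unallocated(meta_w) is True   # skip TE kernel, take the fallback path
assert is_tensor_unallocated(real_w) is False  # safe to pass to TE kernels
```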
- class nemo_automodel.components.models.common.utils.TEFp8Config#
Configuration for Transformer Engine FP8 quantization.
When present (not None) in BackendConfig, FP8 is enabled. The recipe field accepts either a string shorthand ("current" or "block") or a pre-built TE recipe object (e.g. Float8CurrentScaling(fp8_dpa=True)).
- recipe: Literal[current, block] | Any#
‘current’
- build_recipe()#
Build and return the TE FP8 recipe object.
If recipe is already a TE recipe object (e.g. Float8CurrentScaling(...)), it is returned directly. String values "current" and "block" are mapped to the corresponding TE recipe class.
- maybe_te_autocast()#
Return te_autocast context manager for FP8.
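A sketch of how the two methods fit together, assuming Transformer Engine is installed, FP8 is supported on the device, and TEFp8Config is constructed via its documented recipe field:

```python
from nemo_automodel.components.models.common.utils import TEFp8Config

fp8_cfg = TEFp8Config(recipe="current")  # or recipe="block", or a pre-built TE recipe object
recipe = fp8_cfg.build_recipe()          # e.g. a Float8CurrentScaling instance

with fp8_cfg.maybe_te_autocast():
    # a forward pass here would run TE linear layers in FP8
    pass
```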
- class nemo_automodel.components.models.common.utils.BackendConfig#
Backend configuration for model components.
- Attributes:
attn – Attention backend ("te", "sdpa", or "flex").
linear – Linear layer backend ("torch" or "te").
rms_norm – RMSNorm backend ("torch", "torch_fp32", or "te").
rope_fusion – Whether to use fused RoPE (requires TE).
experts – MoE expert GEMM backend. "torch" uses a per-expert loop, "te" uses TE GroupedLinear, "gmm" uses grouped_gemm.ops.gmm, "torch_mm" uses torch._grouped_mm.
dispatcher – MoE token dispatcher. "torch" uses DTensor all-gather/reduce-scatter, "deepep" uses DeepEP for token dispatch.
enable_deepep – Deprecated. Use dispatcher="deepep" and experts="gmm" instead.
fake_balanced_gate – If True, replace the learned Gate with FakeBalancedGate, which assigns tokens to experts without learned routing weights.
fake_gate_noise – Noise level [0, 1] for FakeBalancedGate. When > 0, uses biased topk selection seeded from the input content so routing varies dynamically across training steps (like the real Gate) while remaining deterministic for activation checkpointing recompute (same input = same routing). Only used when fake_balanced_gate=True.
enable_hf_state_dict_adapter – Whether to enable the HuggingFace state dict adapter.
enable_fsdp_optimizations – Whether to enable FSDP2 optimizations.
gate_precision – Optional dtype override for the gate computation. Accepts torch.dtype or string (e.g., "torch.float32", "float32").
- attn: Literal[te, sdpa, flex]#
None
- linear: Literal[torch, te]#
None
- rms_norm: Literal[torch, torch_fp32, te]#
None
- rope_fusion: bool#
None
- experts: Literal[torch, te, gmm, torch_mm]#
None
- dispatcher: Literal[torch, deepep]#
None
- enable_deepep: bool | None#
None
- fake_balanced_gate: bool#
False
- fake_gate_noise: float#
0.0
- enable_hf_state_dict_adapter: bool#
True
- enable_fsdp_optimizations: bool#
False
- te_fp8: nemo_automodel.components.models.common.utils.TEFp8Config | None#
None
- gate_precision: str | torch.dtype | None#
None
- __post_init__()#
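An illustrative construction only: the field names follow the attribute list above, but defaults and any validation performed in __post_init__ may differ from what is shown here.

```python
from nemo_automodel.components.models.common.utils import BackendConfig, TEFp8Config

backend = BackendConfig(
    attn="te",
    linear="te",
    rms_norm="te",
    rope_fusion=True,
    experts="gmm",
    dispatcher="deepep",
    te_fp8=TEFp8Config(recipe="current"),  # not None, so FP8 is enabled
    gate_precision="float32",
)
```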
- nemo_automodel.components.models.common.utils.initialize_rms_norm_module(
    rms_norm_impl: str,
    dim: int,
    eps: float = 1e-05,
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.bfloat16,
  )#
Initialize RMSNorm module with the specified backend.
For TE backend, creates TE module directly on specified device. Call reset_parameters() to materialize weights if created on meta device.
- Parameters:
rms_norm_impl – Backend implementation ("te", "torch", or "torch_fp32"):
"te": Transformer Engine fused RMSNorm kernel
"torch": PyTorch native nn.RMSNorm (computes in input dtype)
"torch_fp32": RMSNorm with float32 input upcast
dim – Normalized dimension
eps – Epsilon for numerical stability
device – Device to create module on (None uses PyTorch default, typically CPU)
dtype – Parameter dtype
- Returns:
RMSNorm module
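A usage sketch following the meta-device pattern described above (the backend choice and dimensions are illustrative):

```python
import torch

from nemo_automodel.components.models.common.utils import initialize_rms_norm_module

# Build on the meta device first, then allocate and materialize.
norm = initialize_rms_norm_module(
    "torch",                 # or "te" / "torch_fp32"
    dim=4096,
    eps=1e-5,
    device="meta",
    dtype=torch.bfloat16,
)
norm = norm.to_empty(device="cpu")  # allocate real storage
norm.reset_parameters()             # materialize the weight
```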
- nemo_automodel.components.models.common.utils.initialize_linear_module(
    linear_impl: str,
    in_features: int,
    out_features: int,
    bias: bool = False,
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.bfloat16,
  )#
Initialize Linear module with the specified backend.
For TE backend, creates TE module directly on specified device. Call reset_parameters() to materialize weights if created on meta device.
- Parameters:
linear_impl – Backend implementation (“te” or “torch”)
in_features – Input features
out_features – Output features
bias – Whether to use bias
device – Device to create module on (None uses PyTorch default, typically CPU)
dtype – Parameter dtype
- Returns:
Linear module
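A usage sketch with the torch backend (the feature sizes are illustrative):

```python
import torch

from nemo_automodel.components.models.common.utils import initialize_linear_module

linear = initialize_linear_module(
    "torch",            # or "te" for Transformer Engine
    in_features=1024,
    out_features=4096,
    bias=False,
    dtype=torch.bfloat16,
)
y = linear(torch.randn(2, 1024, dtype=torch.bfloat16))  # -> shape (2, 4096)
```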
- nemo_automodel.components.models.common.utils._make_lazy_te_patcher()#
Return a callable that patches TE modules exactly once.
Uses a closure instead of module-level global state to track whether the patch has already been applied. The actual transformer_engine import is deferred until the first call so that importing this module never triggers heavy native-library loads (flash-attn, CUDA kernels, etc.).
Two patches are applied:
- Unallocated tensor handling: TE kernels don’t support meta/fake tensors, so we short-circuit with empty tensors for PP shape inference.
- is_first_microbatch injection: Reads the global IS_FIRST_MICROBATCH flag and passes it to TE Linear/GroupedLinear for FP8 weight caching during gradient accumulation (quantize on first microbatch, reuse cached on rest).
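A generic sketch of the closure pattern described above (not the actual patch bodies): the flag lives in the enclosing scope, so the patch is applied at most once, and the heavy import only happens on the first call.

```python
def _make_lazy_patcher():
    patched = False

    def patch_once():
        nonlocal patched
        if patched:
            return                   # already applied, no-op
        import transformer_engine    # deferred heavy import  # noqa: F401
        # ... apply the two patches described above ...
        patched = True

    return patch_once

patcher = _make_lazy_patcher()
# patcher()  # first call: imports transformer_engine and applies the patch
# patcher()  # later calls: no-op
```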
- nemo_automodel.components.models.common.utils._patch_te_modules#
‘_make_lazy_te_patcher(…)’
- nemo_automodel.components.models.common.utils.__all__#
[‘BackendConfig’, ‘TEFp8Config’, ‘get_is_first_microbatch’, ‘get_is_optim_step’, ‘initialize_linear_…