nemo_automodel.components.models.common.utils#

Module Contents#

Classes#

TEFp8Config

Configuration for Transformer Engine FP8 quantization.

BackendConfig

Backend configuration for model components.

Functions#

set_is_optim_step

Set the global IS_OPTIM_STEP flag.

get_is_optim_step

Get the global IS_OPTIM_STEP flag.

set_is_first_microbatch

Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching.

get_is_first_microbatch

Get the global IS_FIRST_MICROBATCH flag.

is_tensor_unallocated

Check if tensor is unallocated (meta tensor, fake tensor, etc.).

initialize_rms_norm_module

Initialize RMSNorm module with the specified backend.

initialize_linear_module

Initialize Linear module with the specified backend.

_make_lazy_te_patcher

Return a callable that patches TE modules exactly once.

Data#

API#

nemo_automodel.components.models.common.utils.logger#

‘getLogger(…)’

nemo_automodel.components.models.common.utils.HAVE_TE#

None

nemo_automodel.components.models.common.utils.HAVE_DEEP_EP#

None

nemo_automodel.components.models.common.utils.HAVE_GMM#

None

nemo_automodel.components.models.common.utils.IS_OPTIM_STEP#

False

nemo_automodel.components.models.common.utils.IS_FIRST_MICROBATCH: bool | None#

None

nemo_automodel.components.models.common.utils.set_is_optim_step(value: bool) → None#

Set the global IS_OPTIM_STEP flag.

Parameters:

value – Whether we are in an optimization step.

nemo_automodel.components.models.common.utils.get_is_optim_step() → bool#

Get the global IS_OPTIM_STEP flag.

Returns:

Whether we are in an optimization step.
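
A minimal usage sketch, assuming the flag is toggled from user code around the optimizer update (how the flag is consumed downstream is framework-specific; `optimizer` is a placeholder):

from nemo_automodel.components.models.common.utils import (
    get_is_optim_step,
    set_is_optim_step,
)

set_is_optim_step(True)   # mark that the optimizer update is in progress
optimizer.step()
optimizer.zero_grad()
set_is_optim_step(False)  # clear the flag after the update
assert get_is_optim_step() is False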

nemo_automodel.components.models.common.utils.set_is_first_microbatch(value: bool | None) → None#

Set the global IS_FIRST_MICROBATCH flag for FP8 weight caching.

Parameters:

value – True for first microbatch (quantize+cache), False for subsequent (use cached), None to disable caching.

nemo_automodel.components.models.common.utils.get_is_first_microbatch() → bool | None#

Get the global IS_FIRST_MICROBATCH flag.

Returns:

True/False/None indicating microbatch position for FP8 weight caching.
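
A hedged sketch of driving the flag from a gradient-accumulation loop (model, optimizer, and microbatches are placeholders; the exact integration point depends on the training loop):

from nemo_automodel.components.models.common.utils import set_is_first_microbatch

for i, microbatch in enumerate(microbatches):
    # First microbatch: TE quantizes FP8 weights and caches them;
    # subsequent microbatches reuse the cached quantized weights.
    set_is_first_microbatch(i == 0)
    loss = model(**microbatch).loss / len(microbatches)
    loss.backward()

optimizer.step()
optimizer.zero_grad()
set_is_first_microbatch(None)  # disable caching outside gradient accumulation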

nemo_automodel.components.models.common.utils.is_tensor_unallocated(tensor: torch.Tensor) → bool#

Check if tensor is unallocated (meta tensor, fake tensor, etc.).

TE kernels don’t support meta tensors, fake tensors, or unallocated tensors. This helper detects such cases for fallback handling.

Parameters:

tensor – Tensor to check

Returns:

True if tensor is unallocated or cannot be accessed
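
For example (a small sketch):

import torch
from nemo_automodel.components.models.common.utils import is_tensor_unallocated

meta_w = torch.empty(4096, 4096, device="meta")  # shape/dtype only, no storage
real_w = torch.ones(2, 2)

is_tensor_unallocated(meta_w)  # True  -> fall back instead of calling a TE kernel
is_tensor_unallocated(real_w)  # False -> safe to pass to TE kernels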

class nemo_automodel.components.models.common.utils.TEFp8Config#

Configuration for Transformer Engine FP8 quantization.

When present (not None) in BackendConfig, FP8 is enabled. The recipe field accepts either a string shorthand ("current" or "block") or a pre-built TE recipe object (e.g. Float8CurrentScaling(fp8_dpa=True)).

recipe: Literal[current, block] | Any#

‘current’

build_recipe()#

Build and return the TE FP8 recipe object.

If recipe is already a TE recipe object (e.g. Float8CurrentScaling(...)), it is returned directly. String values "current" and "block" are mapped to the corresponding TE recipe class.

maybe_te_autocast()#

Return te_autocast context manager for FP8.
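
A usage sketch, assuming TEFp8Config is constructed with the recipe field and Transformer Engine is installed (model and inputs are placeholders):

from nemo_automodel.components.models.common.utils import TEFp8Config

# String shorthand; build_recipe() maps it to the matching TE recipe class.
fp8_cfg = TEFp8Config(recipe="current")
recipe = fp8_cfg.build_recipe()

# Or pass a pre-built TE recipe object directly, e.g.:
# from transformer_engine.common.recipe import Float8CurrentScaling
# fp8_cfg = TEFp8Config(recipe=Float8CurrentScaling(fp8_dpa=True))

# Run the forward pass under FP8 autocast.
with fp8_cfg.maybe_te_autocast():
    out = model(inputs)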

class nemo_automodel.components.models.common.utils.BackendConfig#

Backend configuration for model components.

Attributes:

  • attn – Attention backend (“te”, “sdpa”, or “flex”).

  • linear – Linear layer backend (“torch” or “te”).

  • rms_norm – RMSNorm backend (“torch” or “te”).

  • rope_fusion – Whether to use fused RoPE (requires TE).

  • experts – MoE expert GEMM backend. “torch” uses a per-expert loop, “te” uses TE GroupedLinear, “gmm” uses grouped_gemm.ops.gmm, “torch_mm” uses torch._grouped_mm.

  • dispatcher – MoE token dispatcher. “torch” uses DTensor all-gather/reduce-scatter, “deepep” uses DeepEP for token dispatch.

  • enable_deepep – Deprecated. Use dispatcher=“deepep” and experts=“gmm” instead.

  • fake_balanced_gate – If True, replace the learned Gate with FakeBalancedGate, which assigns tokens to experts without learned routing weights.

  • fake_gate_noise – Noise level in [0, 1] for FakeBalancedGate. When > 0, uses biased top-k selection seeded from the input content, so routing varies dynamically across training steps (like the real Gate) while remaining deterministic for activation-checkpointing recompute (same input, same routing). Only used when fake_balanced_gate=True.

  • enable_hf_state_dict_adapter – Whether to enable the HuggingFace state dict adapter.

  • enable_fsdp_optimizations – Whether to enable FSDP2 optimizations.

  • gate_precision – Optional dtype override for the gate computation. Accepts torch.dtype or a string (e.g., “torch.float32”, “float32”).

attn: Literal[te, sdpa, flex]#

None

linear: Literal[torch, te]#

None

rms_norm: Literal[torch, torch_fp32, te]#

None

rope_fusion: bool#

None

experts: Literal[torch, te, gmm, torch_mm]#

None

dispatcher: Literal[torch, deepep]#

None

enable_deepep: bool | None#

None

fake_balanced_gate: bool#

False

fake_gate_noise: float#

0.0

enable_hf_state_dict_adapter: bool#

True

enable_fsdp_optimizations: bool#

False

te_fp8: nemo_automodel.components.models.common.utils.TEFp8Config | None#

None

gate_precision: str | torch.dtype | None#

None

__post_init__()#
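
A construction sketch, assuming BackendConfig is a dataclass that accepts these fields as keyword arguments (the values shown are illustrative, not defaults):

from nemo_automodel.components.models.common.utils import BackendConfig, TEFp8Config

backend = BackendConfig(
    attn="te",
    linear="te",
    rms_norm="te",
    rope_fusion=True,
    experts="gmm",                         # grouped_gemm expert GEMMs
    dispatcher="deepep",                   # DeepEP token dispatch
    te_fp8=TEFp8Config(recipe="current"),  # non-None enables FP8
    gate_precision="float32",              # run the MoE gate in fp32
)
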
nemo_automodel.components.models.common.utils.initialize_rms_norm_module(
rms_norm_impl: str,
dim: int,
eps: float = 1e-05,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.bfloat16,
) → torch.nn.Module#

Initialize RMSNorm module with the specified backend.

For the TE backend, the module is created directly on the specified device. If it is created on the meta device, call reset_parameters() to materialize the weights.

Parameters:
  • rms_norm_impl

    Backend implementation (“te”, “torch”, or “torch_fp32”)

    • ”te”: Transformer Engine fused RMSNorm kernel

    • ”torch”: PyTorch native nn.RMSNorm (computes in input dtype)

    • ”torch_fp32”: Float32 input upcast RMSNorm

  • dim – Normalized dimension

  • eps – Epsilon for numerical stability

  • device – Device to create module on (None uses PyTorch default, typically CPU)

  • dtype – Parameter dtype

Returns:

RMSNorm module
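
A brief usage sketch (the dimension values are illustrative):

import torch
from nemo_automodel.components.models.common.utils import initialize_rms_norm_module

# Native PyTorch RMSNorm in bf16 on the default device.
norm = initialize_rms_norm_module("torch", dim=4096)

# TE fused RMSNorm created on the meta device; per the note above,
# call reset_parameters() to materialize the weights.
te_norm = initialize_rms_norm_module("te", dim=4096, device="meta", dtype=torch.bfloat16)
te_norm.reset_parameters()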

nemo_automodel.components.models.common.utils.initialize_linear_module(
linear_impl: str,
in_features: int,
out_features: int,
bias: bool = False,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.bfloat16,
) → torch.nn.Module#

Initialize Linear module with the specified backend.

For the TE backend, the module is created directly on the specified device. If it is created on the meta device, call reset_parameters() to materialize the weights.

Parameters:
  • linear_impl – Backend implementation (“te” or “torch”)

  • in_features – Input features

  • out_features – Output features

  • bias – Whether to use bias

  • device – Device to create module on (None uses PyTorch default, typically CPU)

  • dtype – Parameter dtype

Returns:

Linear module
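
A brief usage sketch (the feature sizes are illustrative):

import torch
from nemo_automodel.components.models.common.utils import initialize_linear_module

# Plain torch.nn.Linear in bf16, no bias.
proj = initialize_linear_module("torch", in_features=4096, out_features=11008)

# TE Linear (needed for the FP8 path) created directly on the current CUDA device.
te_proj = initialize_linear_module("te", 4096, 11008, device="cuda", dtype=torch.bfloat16)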

nemo_automodel.components.models.common.utils._make_lazy_te_patcher()#

Return a callable that patches TE modules exactly once.

Uses a closure instead of module-level global state to track whether the patch has already been applied. The actual transformer_engine import is deferred until the first call so that importing this module never triggers heavy native-library loads (flash-attn, CUDA kernels, etc.).

Two patches are applied:

  1. Unallocated tensor handling: TE kernels don’t support meta/fake tensors, so we short-circuit with empty tensors for pipeline-parallel (PP) shape inference.

  2. is_first_microbatch injection: Reads the global IS_FIRST_MICROBATCH flag and passes it to TE Linear/GroupedLinear for FP8 weight caching during gradient accumulation (quantize on first microbatch, reuse cached on rest).
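
A minimal sketch of the run-once, lazy-import closure pattern described above (not the actual patch logic; the real patcher wraps TE Linear/GroupedLinear, and the names below are illustrative):

def _make_lazy_patcher():
    patched = False  # closure state instead of a module-level global

    def patch_once():
        nonlocal patched
        if patched:
            return  # patch already applied; do nothing
        import transformer_engine  # noqa: F401  heavy import deferred to first call
        # ... apply monkey-patches to TE modules here ...
        patched = True

    return patch_once

_patch_te_modules_sketch = _make_lazy_patcher()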

nemo_automodel.components.models.common.utils._patch_te_modules#

‘_make_lazy_te_patcher(…)’

nemo_automodel.components.models.common.utils.__all__#

[‘BackendConfig’, ‘TEFp8Config’, ‘get_is_first_microbatch’, ‘get_is_optim_step’, ‘initialize_linear_…