nemo_automodel.shared.te_patches


Transformer Engine compatibility patches.

These are runtime monkey-patches, applied directly to TE classes in memory, so they take effect immediately in the current process.

Call apply_te_patches() early in the process, before TE optimizers are instantiated.

Module Contents

Functions

| Name | Description |
|---|---|
| _apply_fused_adam_quantized_tensor_patch | Patch FusedAdam._initialize_state to handle QuantizedTensor params. |
| apply_te_patches | Apply all Transformer Engine runtime patches. |

Data

_TE_PATCHES_APPLIED

_logger

API

nemo_automodel.shared.te_patches._apply_fused_adam_quantized_tensor_patch() -> None

Patch FusedAdam._initialize_state to handle QuantizedTensor params.

TE’s FusedAdam uses torch.zeros(param.shape, ...) / torch.empty(param.shape, ...) in _initialize_state, which fails for QuantizedTensor parameters because their .shape does not carry the correct metadata for allocation. The fix dequantizes the param first and uses torch.zeros_like / torch.empty_like instead.

The fix was merged upstream in TE 2.12 via https://github.com/NVIDIA/TransformerEngine/pull/2535.

nemo_automodel.shared.te_patches.apply_te_patches() -> None

Apply all Transformer Engine runtime patches.

This function is idempotent and safe to call multiple times.

nemo_automodel.shared.te_patches._TE_PATCHES_APPLIED = False
nemo_automodel.shared.te_patches._logger = logging.getLogger(__name__)