nemo_automodel.shared.transformers_patches


Transformers compatibility patches.

Runtime monkey-patch that replaces apex’s FusedRMSNorm, which does not support bfloat16, in the T5 module. Call patch_t5_layer_norm() before loading any T5 models when running in bf16.
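The sketch below shows why the call order matters for a runtime monkey-patch of this kind. It does not reproduce the actual implementation of patch_t5_layer_norm(). The module object, the classes FusedRMSNormStandIn and NativeLayerNormStandIn, FakeBlock, and patch_layer_norm() are all hypothetical stand-ins, and it uses no third-party packages. It assumes that model code resolves the norm class from a module-level name at the moment layers are constructed. Under that assumption, layers built before the patch keep the old class, while layers built afterwards pick up the replacement.

```python
import types


# Hypothetical stand-ins: no apex, torch, or transformers code is used here.
class FusedRMSNormStandIn:
    """Plays the role of apex's FusedRMSNorm (pretend it lacks bf16 support)."""

    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size


class NativeLayerNormStandIn:
    """Plays the role of a native, bf16-safe T5LayerNorm."""

    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size


# A fake "modeling" module whose module-level name is bound to the fused class,
# mimicking a library that swaps in an optional fused kernel at import time.
fake_modeling = types.ModuleType("fake_modeling_t5")
fake_modeling.T5LayerNorm = FusedRMSNormStandIn


class FakeBlock:
    """Looks up the norm class in the module namespace at construction time."""

    def __init__(self, hidden_size: int) -> None:
        self.norm = fake_modeling.T5LayerNorm(hidden_size)


def patch_layer_norm(module: types.ModuleType) -> None:
    """Hypothetical patch: rebind module.T5LayerNorm to the native class."""
    if module.T5LayerNorm is NativeLayerNormStandIn:
        return  # already patched, so repeated calls change nothing
    module.T5LayerNorm = NativeLayerNormStandIn


built_too_early = FakeBlock(8)   # constructed before the patch
patch_layer_norm(fake_modeling)
patch_layer_norm(fake_modeling)  # second call is a no-op
built_after = FakeBlock(8)       # constructed after the patch

print(type(built_too_early.norm).__name__)  # FusedRMSNormStandIn
print(type(built_after.norm).__name__)      # NativeLayerNormStandIn
```

Only the instance created after the patch uses the replacement class. This is the reason such a patch has to run before any T5 model is instantiated.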

Module Contents

Functions

Name | Description
patch_t5_layer_norm | Replace apex’s FusedRMSNorm with a native T5LayerNorm in the T5 module.

Data

_logger

API

nemo_automodel.shared.transformers_patches.patch_t5_layer_norm() -> None

Replace apex’s FusedRMSNorm with a native T5LayerNorm in the T5 module.

Apex’s FusedRMSNorm doesn’t support bfloat16, but the native T5LayerNorm handles it correctly by upcasting to fp32 internally for numerical stability. This must be called before loading any T5 models.

This function is idempotent and safe to call multiple times.
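The pure-Python illustration below shows why computing the RMS statistic in fp32, rather than in bf16, matters for numerical stability. It concerns only the upcasting behavior of the native T5LayerNorm. It says nothing about how apex’s FusedRMSNorm fails with bf16.

Assumptions to keep in mind:

- bf16 and fp32 rounding are emulated with struct.
- The sum of squares is accumulated in a naive sequential loop.
- The helpers round_to_fp32(), round_to_bf16(), and rms_scale() are hypothetical.

Real PyTorch reduction kernels may use higher-precision accumulators internally, so this sketch exaggerates the effect.

```python
import math
import struct


def round_to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack(">f", struct.pack(">f", x))[0]


def round_to_bf16(x: float) -> float:
    """Emulate bfloat16: keep the top 16 bits of the float32 pattern (finite inputs only)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-half-to-even on the 16 dropped bits
    bits &= 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def rms_scale(xs, round_acc, eps=1e-6):
    """Return 1/sqrt(mean(x^2) + eps), rounding every partial sum with round_acc."""
    acc = 0.0
    for v in xs:
        acc = round_acc(acc + round_acc(v * v))
    mean_sq = round_acc(acc / len(xs))
    return 1.0 / math.sqrt(mean_sq + eps)


hidden = [round_to_bf16(1.0)] * 1024  # bf16 activations, all equal to 1.0

scale_bf16 = rms_scale(hidden, round_to_bf16)  # statistic kept in bf16
scale_fp32 = rms_scale(hidden, round_to_fp32)  # statistic upcast to fp32

print(hidden[0] * scale_fp32)  # close to 1.0, the correctly normalized value
print(hidden[0] * scale_bf16)  # close to 2.0, because the bf16 sum stalls at 256
```

bf16 carries only 8 significant bits. Once the running sum reaches 256, adding 1.0 rounds back to 256. As a result, the mean square comes out as 0.25 instead of 1.0, and the normalized output is off by a factor of two. The fp32 path represents every partial sum exactly.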

nemo_automodel.shared.transformers_patches._logger = logging.getLogger(__name__)