Transformer Engine v2.15 Release Notes

Key Features and Enhancements

  • [PyTorch] Added support for Flash Attention 4. (#2432)

  • [PyTorch] Added support for MXFP8 attention. (#2719)

  • [PyTorch] Added support for the QGeGLU activation, both in te.ops and in the fused grouped MLP path that uses GEMM + activation fusion. (#2855)

  • [PyTorch] Added support for per-token bias probability scaling, both in te.ops and in the fused grouped MLP path that uses GEMM + activation fusion. (#2864)

  • [PyTorch] Added support for NVFP4 weight quantization in the fused Adam optimizer. (#2797)

  • [PyTorch, Common] Added Triton kernels to support mHC (Manifold-Constrained Hyper-Connections). (#2790)

  • [PyTorch, Common] Added support for dequantizing MXFP8 grouped tensors. (#2722)

  • [Common] Added support for unswizzling scaling factors. (#2837,#2732)

  • [PyTorch] Added Newton–Schulz orthogonalization via cuSOLVERMp for distributed orthogonalization workloads. (#2706)

  • [PyTorch] Added an NVTE_BACKWARD_OVERRIDE=high_precision|dequantized environment variable to control backward precision behavior (see the usage sketch after this list). (#2644)

  • [PyTorch] Added a debug-tools feature that dumps tensors before and after quantization for numerical debugging. (#2645)

  • [PyTorch] Optimized FP8 block-scaling AllGather for FSDP2 to reduce communication overhead. (#2789)

  • [PyTorch] Added an example demonstrating high-precision weight initialization with fully_shard. (#2785)

  • [PyTorch] Expanded fused grouped MLP support via te.ops by relaxing the weight dimension requirement from divisibility by 256 to divisibility by 64. (#2856)

  • [PyTorch] Added torch.compile support for the MoE permute utility functions (see the sketch after this list). (#2686)

  • [Common, PyTorch] Improved the performance of NVFP4 quantization by refactoring the amax compute kernel. (#2820)

  • [JAX] Reduced the memory cost of THD seqlen and offset computation from O(T²) to O(T) for long sequences. (#2522)

  • [JAX] Added MXFP8 grouped quantize + GEMM support. (#2763)
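
A minimal usage sketch for the NVTE_BACKWARD_OVERRIDE variable introduced in #2644. The te.Linear layer and toy tensor below are illustrative placeholders; only the variable name and its two accepted values come from this release:

import os

# Set before Transformer Engine is imported/used; accepted values in this
# release are "high_precision" and "dequantized".
os.environ["NVTE_BACKWARD_OVERRIDE"] = "high_precision"

import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024)                     # any TE module; te.Linear is just an example
out = layer(torch.randn(8, 1024, device="cuda"))
out.sum().backward()                              # backward pass now follows the override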
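
A rough sketch of the new torch.compile support for the MoE permute utilities (#2686). The exact moe_permute/moe_unpermute signatures vary between Transformer Engine versions, so this only wraps the functions; check the installed API before calling them:

import torch
import transformer_engine.pytorch as te

# torch.compile can now trace through the MoE permute/unpermute utilities.
# Consult help(te.moe_permute) for the exact argument names in your version.
permute_compiled = torch.compile(te.moe_permute)
unpermute_compiled = torch.compile(te.moe_unpermute)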

Fixed Issues

  • [PyTorch] Fixed a numerical bug where stale columnwise weight data would be used for post-validation training steps. (#2929)

  • [PyTorch] Fixed redundant memory usage when using NVFP4 parameters. (#2834)

  • [JAX] Fixed the JAX extension build with NVTE_UB_WITH_MPI=1. (#2835)

  • [Common] Fixed a numerical bug in the fused MoE router for large top-K values and expert counts. (#2821)

  • [Common] Fixed an illegal memory access in register_user_buffer_collective on Ampere (and older) GPUs when using user buffers for COMM-GEMM overlap. (#2859)

  • [Build] Fixed a build crash when compiling from source with NVTE_CUDA_ARCHS=120. (#2832)

Known Issues

  • [PyTorch] When building a grouped MLP module via te.ops.Sequential to use the GEMM + activation fusion, the kernel may produce non-deterministic results in the single grouped-weight case (i.e., when the NVTE_GROUPED_LINEAR_SINGLE_PARAM environment variable and the corresponding single_grouped_weight module argument are set).

  • [PyTorch] Enabling the fused grouped MLP via te.ops requires cudnn-frontend version 1.23.0. If you run into issues, ensure that the correct CuTeDSL packages are installed:

# Remove any previously installed CUTLASS / CuTeDSL / cudnn-frontend packages
# that could conflict with the pinned versions below.
python -m pip uninstall -y \
  cutlass \
  nvidia-cutlass \
  nvidia-cutlass-dsl \
  nvidia-cutlass-dsl-libs-base \
  nvidia-cutlass-dsl-libs-cu13 \
  nvidia-cudnn-frontend

# Refresh the packaging tools, then install the pinned CuTeDSL and cudnn-frontend builds.
python -m pip install -U pip setuptools wheel
python -m pip install --no-cache-dir "nvidia-cutlass-dsl[cu13]==4.4.1"
python -m pip install --no-cache-dir "nvidia-cudnn-frontend[cutedsl]==1.23.0"

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.