Transformer Engine v2.15 Release Notes¶
Key Features and Enhancements¶
[PyTorch] Added support for Flash Attention 4. (#2432)
[PyTorch] Added support for MXFP8 attention. (#2719)
[PyTorch] Added support for QGeGLU activation both in te.ops and the fused grouped MLP path using GEMM + activation fusion. (#2855)
[PyTorch] Added support for per-token bias probability scaling both in te.ops and the fused grouped MLP path using GEMM + activation fusion. (#2864)
[PyTorch] Added support for NVFP4 weight quantization in the fused Adam optimizer. (#2797)
[PyTorch, Common] Added Triton kernels to support mHC (Manifold-Constrained Hyper-Connections). (#2790)
[PyTorch, Common] Added support for dequantizing MXFP8 grouped tensors. (#2722)
[Common] Added support for unswizzling scaling factors. (#2837, #2732)
[PyTorch] Added Newton–Schulz orthogonalization via cuSOLVERMp for distributed orthogonalization workloads. (#2706)
[PyTorch] Added an NVTE_BACKWARD_OVERRIDE=high_precision|dequantized environment variable to control backward precision behavior; a usage sketch follows this list. (#2644)
[PyTorch] Added a feature to debug tools to allow tensor dumps before and after quantization for numerical debugging. (#2645)
[PyTorch] Optimized FP8 block-scaling AllGather for FSDP2 to reduce communication overhead. (#2789)
[PyTorch] Added an example demonstrating high-precision weight initialization with fully_shard. (#2785)
[PyTorch] Expanded fused grouped MLP support via te.ops by lowering the weight dimension requirement to divisibility by 64 (previously 256). (#2856)
[PyTorch] Added torch.compile support for the MoE permute utility functions. (#2686)
[Common, PyTorch] Improved the performance of NVFP4 quantization by refactoring the amax compute kernel. (#2820)
[JAX] Reduced the memory used for THD seqlen and offset computation from O(T²) to O(T) for long sequences. (#2522)
[JAX] Added MXFP8 grouped quantize + GEMM support. (#2763)
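As a usage sketch for the NVTE_BACKWARD_OVERRIDE feature noted above (#2644): the variable is read from the environment, so it can be exported before launching training. The launch command below is illustrative; only the variable name and its two values come from this release note, and the exact semantics of each mode are not described here.

# Select the backward precision behavior; valid values per #2644 are
# high_precision or dequantized.
export NVTE_BACKWARD_OVERRIDE=high_precision
python train.py  # illustrative training launch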
Fixed Issues¶
[PyTorch] Fixed a numerical bug where stale columnwise weight data would be used for post-validation training steps. (#2929)
[PyTorch] Fixed redundant memory usage when using NVFP4 parameters. (#2834)
[JAX] Fixed the JAX extension build with NVTE_UB_WITH_MPI=1. (#2835)
[Common] Fixed a numerical bug in the MoE fused router for large top-K and expert counts. (#2821)
[Common] Fixed an illegal memory access in register_user_buffer_collective on Ampere (and older) GPUs when using user buffers for COMM-GEMM overlap. (#2859)
[Build] Fixed a build crash when compiling from source with NVTE_CUDA_ARCHS=120; a sketch of passing these build flags when compiling from source follows this list. (#2832)
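For context on the two build-related fixes above, the following is a minimal sketch of how such flags are typically passed when compiling Transformer Engine from source. Only NVTE_UB_WITH_MPI and NVTE_CUDA_ARCHS are taken from the items above; the clone command and the NVTE_FRAMEWORK value are assumptions for illustration.

# Illustrative source build passing the flags referenced in the fixes above.
git clone --recursive https://github.com/NVIDIA/TransformerEngine.git
cd TransformerEngine
# NVTE_FRAMEWORK=jax is an assumed example; pick the framework you build for.
NVTE_FRAMEWORK=jax NVTE_UB_WITH_MPI=1 NVTE_CUDA_ARCHS=120 python -m pip install -v .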
Known Issues¶
[PyTorch] When building a grouped MLP module via te.ops.Sequential in order to use the GEMM + activation fusion, the kernel may produce non-deterministic results in the single grouped-weight case (i.e., when the environment variable NVTE_GROUPED_LINEAR_SINGLE_PARAM and the corresponding module argument single_grouped_weight are set).
[PyTorch] Enabling fused grouped MLP via te.ops requires cudnn-frontend library version 1.23.0. If you encounter issues, ensure that the correct version of CuTeDSL is installed (a version check follows the commands below):
python -m pip uninstall -y \
cutlass \
nvidia-cutlass \
nvidia-cutlass-dsl \
nvidia-cutlass-dsl-libs-base \
nvidia-cutlass-dsl-libs-cu13 \
nvidia-cudnn-frontend
python -m pip install -U pip setuptools wheel
python -m pip install --no-cache-dir "nvidia-cutlass-dsl[cu13]==4.4.1"
python -m pip install --no-cache-dir "nvidia-cudnn-frontend[cutedsl]==1.23.0"
Breaking Changes in This Release¶
There are no breaking changes in this release.
Deprecated Features¶
There are no deprecated features in this release.