# Transformer Engine v2.13 Release Notes

## Key Features and Enhancements

- Added detailed documentation for low-precision training with Transformer Engine, covering FP8, MXFP8, NVFP4, and other quantization recipes with examples for both PyTorch and JAX. ([#2343](https://github.com/NVIDIA/TransformerEngine/pull/2343))
- [Build] Added the `NVTE_BUILD_USE_NVIDIA_WHEELS` environment variable to allow building TE using CUDA headers from PyPI NVIDIA wheels instead of a system CUDA installation. ([#2623](https://github.com/NVIDIA/TransformerEngine/pull/2623))
- [C] Enabled deterministic FP8 fused attention on Blackwell (SM100) GPUs. ([#2621](https://github.com/NVIDIA/TransformerEngine/pull/2621))
- [C] Updated cuBLASMp integration to version 0.8.0, replacing the NVSHMEM dependency with NCCL-based symmetric memory. ([#2661](https://github.com/NVIDIA/TransformerEngine/pull/2661))
- [C] Added MXFP8 quantization kernels for grouped tensors used in MoE, with fused scale-factor swizzling for improved performance. ([#2586](https://github.com/NVIDIA/TransformerEngine/pull/2586), [#2630](https://github.com/NVIDIA/TransformerEngine/pull/2630))
- [C] Added NVFP4 quantization kernels for grouped tensors used in MoE models. ([#2655](https://github.com/NVIDIA/TransformerEngine/pull/2655))
- [C] Reduced cuDNN graph recompilations in THD fused attention by rounding large batch sizes to 512-element increments. ([#2653](https://github.com/NVIDIA/TransformerEngine/pull/2653))
- [C] Added the `sqrtsoftplus` scoring function to the fused MoE router and improved router kernel performance on Blackwell GPUs. ([#2633](https://github.com/NVIDIA/TransformerEngine/pull/2633), [#2683](https://github.com/NVIDIA/TransformerEngine/pull/2683))
- [PyTorch] Introduced `GroupedTensor`, enabling MoE expert weights to be stored as a single contiguous allocation while remaining individually addressable. ([#2654](https://github.com/NVIDIA/TransformerEngine/pull/2654))
- [PyTorch] Added the fusible `GroupedLinear` and `ScaledSwiGLU` ops for building fully fused MoE grouped MLP pipelines. ([#2664](https://github.com/NVIDIA/TransformerEngine/pull/2664))
- [PyTorch] Added the `register_forward_fusion` and `register_backward_fusion` APIs, allowing users to define and register custom operator fusion patterns. ([#2597](https://github.com/NVIDIA/TransformerEngine/pull/2597))
- [PyTorch] Added the `get_backward_dw_params` API to TE modules, fixing weight gradient hook management when using wgrad CUDA Graphs with Megatron-LM. ([#2614](https://github.com/NVIDIA/TransformerEngine/pull/2614))
- [PyTorch] Fixed fused attention bias dimension handling and extended `dbias` support to additional bias shapes (`b1ss`, `bhss`, `11ss`, `111s`). ([#2537](https://github.com/NVIDIA/TransformerEngine/pull/2537))
- [PyTorch] Reduced peak memory usage in the fused Adam optimizer by fusing BF16 momentum scaling directly into CUDA kernels, also enabling CUDA Graph capture for this path. ([#2632](https://github.com/NVIDIA/TransformerEngine/pull/2632))
- [PyTorch] Added the sigmoid-gated GLU activation (`activation="glu"`) to `LayerNormMLP` and `TransformerLayer`. ([#2656](https://github.com/NVIDIA/TransformerEngine/pull/2656))
- [PyTorch] Extended debug statistics tracking to NVFP4 quantization (underflow and MSE metrics), and gracefully skipped stat logging for layers not using quantization. ([#2296](https://github.com/NVIDIA/TransformerEngine/pull/2296), [#2652](https://github.com/NVIDIA/TransformerEngine/pull/2652))
- [PyTorch] Fixed CUDA Graph capture for Megatron-Core vision encoder models. ([#2657](https://github.com/NVIDIA/TransformerEngine/pull/2657))
- [JAX] Added an experimental `inspect_array` debugging utility for dumping tensor snapshots during multi-GPU execution. ([#2651](https://github.com/NVIDIA/TransformerEngine/pull/2651))
- [JAX] Fixed MoE permutation to mask padding tokens correctly and handle tensor sizes under expert parallelism. ([#2672](https://github.com/NVIDIA/TransformerEngine/pull/2672))
- [JAX] MoE permutation now always returns `tokens_per_expert`, required for ragged all-to-all communication in expert parallelism. ([#2613](https://github.com/NVIDIA/TransformerEngine/pull/2613))

## Fixed Issues

- [C] Fixed incorrect results from the `exp2f_rcp` fast-math helper when inputs are NaN or have biased exponent 254. ([#2647](https://github.com/NVIDIA/TransformerEngine/pull/2647))
- [C] Fixed a race condition in Randomized Hadamard Transform amax kernels where a missing memory fence could cause incorrect amax values. ([#2695](https://github.com/NVIDIA/TransformerEngine/pull/2695))
- [PyTorch] Fixed the TE Llama example to work with HuggingFace Transformers 4.57+, which changed decoder layer output conventions. ([#2572](https://github.com/NVIDIA/TransformerEngine/pull/2572))
- [Build] Fixed `TypeError` during build when NCCL is installed from PyPI as a namespace package without a `__file__` attribute. ([#2580](https://github.com/NVIDIA/TransformerEngine/pull/2580))
- [Build] Fixed `ModuleNotFoundError` when installing from cached source distributions (e.g., via `uv`) by including `build_tools` in `MANIFEST.in`. ([#2684](https://github.com/NVIDIA/TransformerEngine/pull/2684))

## Breaking Changes in This Release

- [C] Removed the deprecated packed fused attention C APIs (`nvte_fused_attn_{fwd,bwd}_{qkvpacked,kvpacked}`); users must migrate to the non-packed API variants. ([#2696](https://github.com/NVIDIA/TransformerEngine/pull/2696))
- Versions of cuBLASMp prior to 0.8.0 are no longer supported.

## Deprecated Features

No features are deprecated in this release.
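As background for the low-precision training documentation added in this release ([#2343](https://github.com/NVIDIA/TransformerEngine/pull/2343)): the FP8 recipes it covers are built around amax-based per-tensor scaling. Below is a rough, self-contained sketch of the delayed-scaling idea in plain Python; the class name and structure are hypothetical illustrations, not a Transformer Engine API.

```python
# Toy sketch (NOT Transformer Engine's implementation) of amax-based
# delayed scaling: keep a short history of per-tensor absolute maxima
# and derive the quantization scale from the largest recent amax.

E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


class DelayedScalingSketch:
    """Illustrative per-tensor scale tracker (hypothetical helper)."""

    def __init__(self, history_len=16):
        # History starts at zero; until real amax values arrive,
        # scale() falls back to 1.0.
        self.amax_history = [0.0] * history_len

    def update(self, tensor):
        """Record the current tensor's amax, dropping the oldest entry."""
        amax = max(abs(x) for x in tensor)
        self.amax_history = self.amax_history[1:] + [amax]

    def scale(self):
        """Scale that maps the largest recent amax onto the E4M3 range."""
        amax = max(self.amax_history)
        return E4M3_MAX / amax if amax > 0.0 else 1.0


tracker = DelayedScalingSketch(history_len=4)
tracker.update([0.5, -2.0, 1.5])   # amax = 2.0
tracker.update([0.25, -0.75])      # amax = 0.75; history max is still 2.0
print(tracker.scale())             # 448.0 / 2.0 = 224.0
```

The point of deriving the scale from a history rather than from the current tensor is that quantization can proceed without first synchronizing on the current tensor's amax; the real recipes in the new documentation add details (margin, history reduction, format choices) beyond this sketch.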