# Transformer Engine v2.14 Release Notes

## Key Features and Enhancements

- [PyTorch] Added multiple CPU overhead optimizations across the framework integration to reduce per-step Python/host overhead. ([#2559](https://github.com/NVIDIA/TransformerEngine/pull/2559)) ([#2724](https://github.com/NVIDIA/TransformerEngine/pull/2724))
- [C, PyTorch] Added BF16 and MXFP8 grouped GEMM support with on-device group sizes. ([#2748](https://github.com/NVIDIA/TransformerEngine/pull/2748)) ([#2669](https://github.com/NVIDIA/TransformerEngine/pull/2669))
- [PyTorch] Added a fused GEMM + SwiGLU grouped MLP for MXFP8 to accelerate MoE forward and backward passes. ([#2769](https://github.com/NVIDIA/TransformerEngine/pull/2769))
- [PyTorch] Added support for a single-parameter `GroupedLinear` configuration, in which the weights of all experts are stored in a single parameter to reduce CPU overhead. ([#2731](https://github.com/NVIDIA/TransformerEngine/pull/2731))
- [PyTorch] Added backward-compatible checkpoint support for the new single-parameter `GroupedLinear`. ([#2761](https://github.com/NVIDIA/TransformerEngine/pull/2761))
- [PyTorch] Extended the fused attention API to optionally return the softmax `Stats` in all cases, and `Max` when `return_max_logit=True`, exposing more cuDNN intermediates to users. ([#2677](https://github.com/NVIDIA/TransformerEngine/pull/2677))
- [PyTorch] Enabled SM120 support for the fused attention path when cuDNN 9.18.1 or later is available. ([#2693](https://github.com/NVIDIA/TransformerEngine/pull/2693))
- [PyTorch] Added support for `MXFP8BlockScaling` and `Float8BlockScaling` quantized weights in `FusedAdam`. ([#2753](https://github.com/NVIDIA/TransformerEngine/pull/2753))
- [PyTorch] Added a CUDA Graph-compatible `multi_tensor_scale_tensor` API in the optimizer. ([#2594](https://github.com/NVIDIA/TransformerEngine/pull/2594))
- [PyTorch] Enabled CUDA Graph capture of modules with CPU offloading. ([#2435](https://github.com/NVIDIA/TransformerEngine/pull/2435))
- [PyTorch] Added support for non-FP32 `params_dtype` when using QK normalization. ([#2718](https://github.com/NVIDIA/TransformerEngine/pull/2718))
- [PyTorch] Added precision debug tools support for quantized model parameters. ([#2141](https://github.com/NVIDIA/TransformerEngine/pull/2141))
- [JAX] Added a JAX-side API to invoke the fused MoE router kernels. ([#2711](https://github.com/NVIDIA/TransformerEngine/pull/2711))
- [JAX] Integrated BF16 grouped GEMM with on-device group sizes. ([#2680](https://github.com/NVIDIA/TransformerEngine/pull/2680))
- [JAX] Added a Collective GEMM (CGEMM) implementation with FP8 and MXFP8 support. ([#2740](https://github.com/NVIDIA/TransformerEngine/pull/2740))
- [JAX] Added Shardy support to the Collective GEMM (CGEMM) path. ([#2714](https://github.com/NVIDIA/TransformerEngine/pull/2714))
- [JAX] Improved the performance of the permutation kernels for JAX 0.8.0 and newer. ([#2741](https://github.com/NVIDIA/TransformerEngine/pull/2741))
- [C] Enabled the fused RMSNorm `dLN + add` backward path through cuDNN for faster fused-residual normalization. ([#2778](https://github.com/NVIDIA/TransformerEngine/pull/2778))
- [C] Added a grouped MXFP8 quantization kernel, including grouped dbias support. ([#2738](https://github.com/NVIDIA/TransformerEngine/pull/2738)) ([#2674](https://github.com/NVIDIA/TransformerEngine/pull/2674))
- [C] Enabled dequantization of an MXFP8 tensor that carries only column-wise data. ([#2712](https://github.com/NVIDIA/TransformerEngine/pull/2712))
- [C, PyTorch] Improved the performance of the NVFP4 recipe by fusing the row-wise cast, random Hadamard transform (RHT), transpose, and column-wise cast. ([#2555](https://github.com/NVIDIA/TransformerEngine/pull/2555))
- [C] Made the number of Philox rounds used for stochastic rounding configurable. ([#2751](https://github.com/NVIDIA/TransformerEngine/pull/2751))
- [Documentation] Added a documentation page describing CPU offloading in Transformer Engine. ([#2520](https://github.com/NVIDIA/TransformerEngine/pull/2520))
- [Documentation] Updated the documentation to describe the current cuDNN sliding-window attention support. ([#2624](https://github.com/NVIDIA/TransformerEngine/pull/2624))
- [Documentation] Improved error messages across the C, PyTorch, and JAX layers. ([#2705](https://github.com/NVIDIA/TransformerEngine/pull/2705))
- [Documentation] Added a custom-feature tutorial for the precision debug tools. ([#2216](https://github.com/NVIDIA/TransformerEngine/pull/2216))
- [Documentation] Added documentation for the operator fuser API. ([#2447](https://github.com/NVIDIA/TransformerEngine/pull/2447))
- [PyTorch, Documentation] Added end-to-end examples for `fused_adam`, `quantized_model_init`, and FSDP2 usage. ([#2698](https://github.com/NVIDIA/TransformerEngine/pull/2698)) ([#2662](https://github.com/NVIDIA/TransformerEngine/pull/2662))

## Fixed Issues

- [PyTorch] FSDP2 / Megatron-FSDP / DCP (distributed checkpointing): When model parameters are `DTensor`s, optimizer states are now also `DTensor`s, ensuring correct sharded checkpoints. ([#2795](https://github.com/NVIDIA/TransformerEngine/pull/2795))
- [PyTorch] Fixed async DCP checkpointing for `Float8Tensor` parameters. ([#2721](https://github.com/NVIDIA/TransformerEngine/pull/2721))
- [PyTorch] Fixed `cross_entropy_forward` producing incorrect results for non-contiguous logits. ([#2746](https://github.com/NVIDIA/TransformerEngine/pull/2746))
- [PyTorch] Fixed excessive memory usage when using the operator fuser. ([#2750](https://github.com/NVIDIA/TransformerEngine/pull/2750))
- [PyTorch] Fixed a precision debug tools crash when `tp_group=None`. ([#2733](https://github.com/NVIDIA/TransformerEngine/pull/2733))
- [PyTorch] Fixed Flash Attention 3 API compatibility for the window-size parameters. ([#2704](https://github.com/NVIDIA/TransformerEngine/pull/2704))
- [PyTorch] Fixed the learnable `softmax_offset` parameter in `DotProductAttention` so that it is zero-initialized. ([#2694](https://github.com/NVIDIA/TransformerEngine/pull/2694))
- [PyTorch] Fixed an error with FP8 block scaling when sequence parallelism is enabled and local tensor dimensions are not divisible by 128. ([#2637](https://github.com/NVIDIA/TransformerEngine/pull/2637))
- [PyTorch] Added a clear error when constructing `LayerNormLinear` with row-wise tensor parallelism, which is an unsupported configuration. Previously, this configuration failed with a CUDA error. ([#2688](https://github.com/NVIDIA/TransformerEngine/pull/2688))
- [JAX] Fixed a performance issue with THD/BSHD segment-position generation. ([#2823](https://github.com/NVIDIA/TransformerEngine/pull/2823))
- [JAX] Fixed an assertion error when using `from_segment_ids_and_pos()` with `vmap`. ([#2692](https://github.com/NVIDIA/TransformerEngine/pull/2692))
- [JAX] Fixed a performance issue for models using both FSDP and expert parallelism (EP). ([#2649](https://github.com/NVIDIA/TransformerEngine/pull/2649))
- [JAX] Changed the dtype of the intermediate-result aval in `fused_topk_and_score_function_fwd` to `fp32` to avoid precision loss. ([#2752](https://github.com/NVIDIA/TransformerEngine/pull/2752))
- [C] Fixed an incorrect MNNVL fabric-availability check that misreported support on some systems. ([#2626](https://github.com/NVIDIA/TransformerEngine/pull/2626))
- [C, PyTorch] Fixed score normalization in `fused_score_for_moe_aux_loss` when `topk == 1`. ([#2720](https://github.com/NVIDIA/TransformerEngine/pull/2720))
- [PyTorch] Fixed possible precision loss when copying from a quantized tensor to a high-precision tensor. ([#2120](https://github.com/NVIDIA/TransformerEngine/pull/2120), [#2673](https://github.com/NVIDIA/TransformerEngine/pull/2673))

## Breaking Changes in This Release

- [JAX] GSPMD partitioning rules are no longer tested and now emit a warning when used; JAX users relying on GSPMD should migrate to Shardy. ([#2702](https://github.com/NVIDIA/TransformerEngine/pull/2702))

## Deprecated Features

There are no deprecated features in this release.
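For readers unfamiliar with the grouped GEMM entries under Key Features (BF16/MXFP8 grouped GEMM with on-device group sizes), here is a minimal plain-Python sketch of the computation a grouped GEMM performs, written under stated assumptions. It is not Transformer Engine's API or kernel: `grouped_gemm` and `matmul` are hypothetical helpers, and the input rows are assumed to be pre-sorted by group (for example, tokens sorted by expert in an MoE layer). As we understand it, "on-device group sizes" means the per-group row counts can stay in GPU memory instead of being read back on the host. This sketch simply keeps them in a Python list.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive dense matrix multiply: (m x k) @ (k x n) -> (m x n)."""
    if a and len(a[0]) != len(b):
        raise ValueError("inner dimensions do not match")
    cols = list(zip(*b))  # the n columns of b, each of length k
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def grouped_gemm(x: Matrix, weights: Sequence[Matrix], group_sizes: Sequence[int]) -> Matrix:
    """Reference semantics of a grouped GEMM (hypothetical helper, not TE's API).

    Group g owns the next group_sizes[g] rows of x and multiplies them by
    weights[g]. Empty groups (size 0) are allowed and contribute no rows.
    """
    if len(weights) != len(group_sizes):
        raise ValueError("need exactly one weight matrix per group")
    if any(n < 0 for n in group_sizes) or sum(group_sizes) != len(x):
        raise ValueError("group sizes must be non-negative and sum to the row count of x")
    out: Matrix = []
    start = 0
    for w, n in zip(weights, group_sizes):
        # One independent GEMM per group, on a contiguous slice of rows.
        out.extend(matmul(x[start:start + n], w))
        start += n
    return out


# Usage: 3 tokens, 2 "experts"; expert 0 is the identity, expert 1 swaps features.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
w0 = [[1.0, 0.0], [0.0, 1.0]]
w1 = [[0.0, 1.0], [1.0, 0.0]]
print(grouped_gemm(x, [w0, w1], [1, 2]))  # [[1.0, 2.0], [4.0, 3.0], [6.0, 5.0]]
```

A single fused kernel exists because running one small GEMM per group from the host, as this loop does, pays launch and host-side overhead for every group. The sketch only pins down the expected output, not how a library computes it.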