Transformer Engine v2.14 Release Notes

Key Features and Enhancements

  • [PyTorch] Added multiple CPU overhead optimizations across the framework integration to reduce per-step Python/host overhead. (#2559) (#2724)

  • [C, PyTorch] Added BF16 and MXFP8 grouped GEMM support with on-device group sizes. (#2748) (#2669)

  • [PyTorch] Added a fused GEMM + SwiGLU grouped MLP for MXFP8 to accelerate MoE forward/backward. (#2769)

  • [PyTorch] Added support for a single-parameter GroupedLinear configuration, where the weights of all experts are stored in a single parameter, which reduces CPU overhead. (#2731)

  • [PyTorch] Added backwards-compatible checkpoint support for the new single-parameter GroupedLinear. (#2761)

  • [PyTorch] Extended the fused attention API to always return the softmax Stats and to additionally return Max when return_max_logit=True, exposing more cuDNN intermediates to users. (#2677)

  • [PyTorch] Enabled SM120 support for the fused attention path when cuDNN >= 9.18.1 is available. (#2693)

  • [PyTorch] Added support for MXFP8BlockScaling and Float8BlockScaling quantized weights in FusedAdam. (#2753)

  • [PyTorch] Added a CUDA Graph-compatible multi_tensor_scale_tensor API in the optimizer. (#2594)

  • [PyTorch] Enabled CUDA Graph capture of modules with CPU offloading. (#2435)

  • [PyTorch] Added support for non-FP32 params_dtype when using QK-normalization. (#2718)

  • [PyTorch] Added precision debug-tools support for quantized model parameters. (#2141)

  • [JAX] Added a JAX-side API to invoke the fused MoE router kernels. (#2711)

  • [JAX] Integrated BF16 grouped GEMM with on-device group sizes. (#2680)

  • [JAX] Added a Collective GEMM (CGEMM) implementation with FP8 and MXFP8 support. (#2740)

  • [JAX] Added Shardy support to the Collective GEMM (CGEMM) path. (#2714)

  • [JAX] Improved the performance of the permutation kernels for JAX 0.8.0 and newer. (#2741)

  • [C] Enabled the fused RMSNorm dLN + add backward path through cuDNN for faster fused-residual normalization. (#2778)

  • [C] Added a grouped MXFP8 quantization kernel, including grouped dbias support. (#2738) (#2674)

  • [C] Enabled dequantization from an MXFP8 tensor that only carries column-wise data. (#2712)

  • [C, PyTorch] Improved the performance of the NVFP4 recipe by fusing row-cast / RHT / transpose / column-cast. (#2555)

  • [C] Made the number of Philox rounds for stochastic rounding configurable. (#2751)

  • [Documentation] Added a documentation page describing CPU offloading in Transformer Engine. (#2520)

  • [Documentation] Updated the documentation to describe the current cuDNN sliding-window attention support. (#2624)

  • [Documentation] Improved error messages across the C, PyTorch, and JAX layers. (#2705)

  • [Documentation] Added a custom-feature tutorial for the precision debug tools. (#2216)

  • [Documentation] Added documentation for the operator fuser API. (#2447)

  • [PyTorch, Documentation] Added end-to-end examples for fused_adam, quantized_model_init, and FSDP2 usage. (#2698) (#2662)
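
Several items above refer to grouped GEMM. As a rough illustration of what such an operation computes, the sketch below applies one weight matrix per group to contiguous row blocks of the input. It is a plain-Python reference of the semantics only, not Transformer Engine code: the names matmul and grouped_gemm_reference are hypothetical, the group sizes are an ordinary host-side list rather than on-device data, and no quantization is involved.

```python
# Reference semantics only: nested Python lists, no GPU, no quantization.
# `grouped_gemm_reference` is a hypothetical name, not a Transformer Engine API.

def matmul(a, b):
    """Multiply an (m x k) matrix by a (k x n) matrix given as nested lists."""
    if not a:
        return []  # an empty group contributes no output rows
    k = len(b)
    n = len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def grouped_gemm_reference(x, weights, group_sizes):
    """Apply weights[g] to the g-th contiguous block of rows of x."""
    if len(weights) != len(group_sizes):
        raise ValueError("need one weight matrix per group")
    if sum(group_sizes) != len(x):
        raise ValueError("group sizes must sum to the number of rows in x")
    out, start = [], 0
    for w, size in zip(weights, group_sizes):
        out.extend(matmul(x[start:start + size], w))
        start += size
    return out


# Three tokens routed to two experts: the first two rows go to expert 0,
# the last row goes to expert 1.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
w0 = [[1.0, 0.0], [0.0, 1.0]]  # identity
w1 = [[2.0, 0.0], [0.0, 2.0]]  # scale by 2
y = grouped_gemm_reference(x, [w0, w1], group_sizes=[2, 1])
```

Here y keeps the first two rows unchanged and doubles the last row. A real grouped kernel produces the same result without the Python loop over groups.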

Fixed Issues

  • [PyTorch] Fixed sharded checkpoints with FSDP2 / Megatron-FSDP / DCP (distributed checkpointing): when model parameters are DTensors, optimizer states are now also DTensors. (#2795)

  • [PyTorch] Fixed async DCP checkpointing for Float8Tensor parameters. (#2721)

  • [PyTorch] Fixed cross_entropy_forward producing incorrect results for non-contiguous logits. (#2746)

  • [PyTorch] Fixed excessive memory usage when using the operator fuser. (#2750)

  • [PyTorch] Fixed a precision-debug-tools crash when tp_group=None. (#2733)

  • [PyTorch] Fixed Flash Attention 3 API compatibility for the window-size parameters. (#2704)

  • [PyTorch] Fixed the initialization of the learnable softmax_offset parameter in DotProductAttention; it is now zero-initialized. (#2694)

  • [PyTorch] Fixed an error with FP8 block scaling when sequence parallelism is enabled and local tensor dimensions are not divisible by 128. (#2637)

  • [PyTorch] Added a clear error when constructing LayerNormLinear with row-wise tensor parallelism (an unsupported configuration). Previously, this configuration would fail with a CUDA error. (#2688)

  • [JAX] Fixed a performance issue with THD/BSHD segment-position generation. (#2823)

  • [JAX] Fixed an assertion error when using from_segment_ids_and_pos() with vmap. (#2692)

  • [JAX] Fixed a performance issue for models using both FSDP and EP. (#2649)

  • [JAX] Changed the dtype of the intermediate-result aval in fused_topk_and_score_function_fwd to fp32 to avoid precision loss. (#2752)

  • [C] Fixed an incorrect MNNVL fabric-availability check that misreported support on some systems. (#2626)

  • [C, PyTorch] Fixed score normalization in fused_score_for_moe_aux_loss when topk == 1. (#2720)

  • [PyTorch] Fixed possible precision loss when copying from a quantized tensor to a high-precision tensor. (#2120) (#2673)
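
To give a sense of why the initial value of softmax_offset matters, the sketch below assumes the learnable-offset softmax adds an extra exp(offset) term to the denominator. That form is an assumption made for illustration and is not stated in these notes, and softmax_with_offset is a hypothetical helper, not a Transformer Engine function. Under that assumption, a zero-initialized offset contributes exactly exp(0) = 1 to the denominator at the start of training.

```python
import math


def softmax_with_offset(scores, offset):
    """Softmax with an extra exp(offset) term in the denominator (assumed form)."""
    m = max(max(scores), offset)  # shift by the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = math.exp(offset - m) + sum(exps)
    return [e / denom for e in exps]


# With three equal scores and a zero offset, each weight is 1 / (1 + 3).
scores = [0.0, 0.0, 0.0]
probs = softmax_with_offset(scores, offset=0.0)
```

In this sketch the weights sum to 0.75 rather than 1, because the offset term absorbs the remaining mass. A very negative offset would recover the ordinary softmax.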

Breaking Changes in This Release

  • [JAX] GSPMD partitioning rules are no longer tested and will now warn on use; users on JAX with GSPMD should migrate to Shardy. (#2702)

Deprecated Features

There are no deprecated features in this release.