Transformer Engine v2.14 Release Notes
Key Features and Enhancements
[PyTorch] Added multiple CPU overhead optimizations across the framework integration to reduce per-step Python/host overhead. (#2559) (#2724)
[C, PyTorch] Added BF16 and MXFP8 grouped GEMM support with on-device group sizes. (#2748) (#2669)
[PyTorch] Added a fused GEMM + SwiGLU grouped MLP for MXFP8 to accelerate MoE forward/backward. (#2769)
[PyTorch] Added support for a single-parameter GroupedLinear configuration, in which the weights of all experts are stored in a single parameter, reducing CPU overhead (see the GroupedLinear sketch after this list). (#2731)
[PyTorch] Added backwards-compatible checkpoint support for the new single-parameter GroupedLinear. (#2761)
[PyTorch] Extended the fused attention API to optionally always return the softmax stats, and to return the max logit when return_max_logit=True, exposing more cuDNN intermediates to users (see the attention sketch after this list). (#2677)
[PyTorch] Enabled SM120 support for the fused attention path when cuDNN >= 9.18.1 is available. (#2693)
[PyTorch] Added support for MXFP8BlockScaling and Float8BlockScaling quantized weights in FusedAdam (see the optimizer sketch after this list). (#2753)
[PyTorch] Added a CUDA graph-compatible multi_tensor_scale_tensor API in the optimizer. (#2594)
[PyTorch] Enabled CUDA graph capture of modules with CPU offloading. (#2435)
[PyTorch] Added support for non-FP32 params_dtype when using QK normalization. (#2718)
[PyTorch] Added precision debug-tools support for quantized model parameters. (#2141)
[JAX] Added a JAX-side API to invoke the fused MoE router kernels. (#2711)
[JAX] Integrated BF16 grouped GEMM with on-device group sizes. (#2680)
[JAX] Added a Collective GEMM (CGEMM) implementation with FP8 and MXFP8 support. (#2740)
[JAX] Added Shardy support to the Collective GEMM (CGEMM) path. (#2714)
[JAX] Improved the performance of the permutation kernels for JAX 0.8.0 and newer. (#2741)
[C] Enabled the fused RMSNorm dLN + add backward path through cuDNN for faster fused-residual normalization. (#2778)
[C] Added a grouped MXFP8 quantization kernel, including grouped dbias support. (#2738) (#2674)
[C] Enabled dequantization from an MXFP8 tensor that only carries column-wise data. (#2712)
[C/PyTorch] Improved the performance of the NVFP4 recipe by fusing the row-cast, RHT (random Hadamard transform), transpose, and column-cast steps. (#2555)
[C] Made the number of Philox rounds for stochastic rounding configurable. (#2751)
[Documentation] Added a documentation page describing CPU offloading in Transformer Engine. (#2520)
[Documentation] Updated the documentation to describe the current cuDNN sliding-window attention support. (#2624)
[Documentation] Improved error messages across the C, PyTorch, and JAX layers. (#2705)
[Documentation] Added a custom-feature tutorial for the precision debug tools. (#2216)
[Documentation] Added documentation for the operator fuser API. (#2447)
[PyTorch, Documentation] Added end-to-end examples for fused_adam, quantized_model_init, and FSDP2 usage. (#2698) (#2662)
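
The single-parameter GroupedLinear item above is easiest to picture against the module's basic call pattern. The sketch below uses the documented GroupedLinear API (one GEMM per expert, with per-expert row counts passed as m_splits); the constructor flag selecting the new single-parameter weight layout is a hypothetical placeholder, so check the v2.14 API reference for the actual argument name.

```python
# Minimal sketch of GroupedLinear for an MoE expert block, assuming the
# documented Transformer Engine PyTorch API.
import torch
import transformer_engine.pytorch as te

num_experts, hidden, ffn = 4, 1024, 4096

grouped = te.GroupedLinear(
    num_experts,              # num_gemms: one GEMM per expert
    hidden,                   # in_features
    ffn,                      # out_features
    bias=False,
    params_dtype=torch.bfloat16,
    # single_param=True,      # HYPOTHETICAL flag name for the new
    #                         # single-parameter weight layout (#2731)
)

# Tokens routed to each expert, concatenated along dim 0; m_splits holds the
# per-expert row counts and must sum to inp.shape[0].
m_splits = [512, 256, 128, 128]
inp = torch.randn(sum(m_splits), hidden, dtype=torch.bfloat16, device="cuda")
out = grouped(inp, m_splits)  # shape: (1024, 4096)
```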
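For the extended fused attention output, a minimal sketch follows. The return_max_logit flag comes from the release note itself, but its placement as a forward() keyword and the two-tuple return shown here are assumptions; consult the DotProductAttention documentation for the exact contract.

```python
# Sketch: requesting the max logit from fused attention (#2677). The kwarg
# placement and return structure are assumptions, not the confirmed API.
import torch
import transformer_engine.pytorch as te

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

s, b, h, d = 512, 2, 16, 64  # sbhd layout, TE's default qkv_format
q = torch.randn(s, b, h, d, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

out, max_logit = attn(q, k, v, return_max_logit=True)  # assumed return shape
```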
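And for the FusedAdam item, here is a sketch of one training step under the MXFP8 block-scaling recipe, assuming the standard fp8_autocast/recipe API and FusedAdam's master_weights option; MXFP8 execution requires hardware that supports it.

```python
# Sketch: one training step with the MXFP8 block-scaling recipe and FusedAdam
# (#2753). Requires an MXFP8-capable GPU; dimensions kept GEMM-friendly.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling
from transformer_engine.pytorch.optimizers import FusedAdam

layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
optim = FusedAdam(layer.parameters(), lr=1e-4, master_weights=True)

inp = torch.randn(2048, 1024, dtype=torch.bfloat16, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
    out = layer(inp)
out.sum().backward()
optim.step()
```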
Fixed Issues
[PyTorch] FSDP2 / Megatron-FSDP / DCP (distributed checkpointing): when model parameters are DTensors, optimizer states are now ensured to also be DTensors, so that sharded checkpoints are correct. (#2795)
[PyTorch] Fixed async DCP checkpointing for Float8Tensor parameters. (#2721)
[PyTorch] Fixed cross_entropy_forward producing incorrect results for non-contiguous logits. (#2746)
[PyTorch] Fixed excessive memory usage when using the operator fuser. (#2750)
[PyTorch] Fixed a precision-debug-tools crash when tp_group=None. (#2733)
[PyTorch] Fixed Flash Attention 3 API compatibility for the window-size parameters. (#2704)
[PyTorch] Fixed the learnable softmax_offset parameter in DotProductAttention to be zero-initialized. (#2694)
[PyTorch] Fixed an error with FP8 block scaling when sequence parallelism is enabled and local tensor dimensions are not divisible by 128. (#2637)
[PyTorch] Added a clear error when constructing LayerNormLinear with row-wise tensor parallelism, an unsupported configuration that previously failed with a CUDA error (see the sketch after this list). (#2688)
[JAX] Fixed a performance issue with THD/BSHD segment-position generation. (#2823)
[JAX] Fixed an assertion error when using from_segment_ids_and_pos() with vmap. (#2692)
[JAX] Fixed a performance issue for models using both FSDP and expert parallelism (EP). (#2649)
[JAX] Changed the dtype of the intermediate-result aval in fused_topk_and_score_function_fwd to fp32 to avoid precision loss. (#2752)
[C] Fixed an incorrect MNNVL fabric-availability check that misreported support on some systems. (#2626)
[C/PyTorch] Fixed score normalization in fused_score_for_moe_aux_loss when topk == 1. (#2720)
[PyTorch] Fixed possible precision loss when copying from a quantized tensor to a high-precision tensor. (#2120, #2673)
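
To illustrate the LayerNormLinear fix above: row-wise tensor parallelism shards the GEMM input across ranks, so the fused LayerNorm would not see the full hidden dimension, which is presumably why the configuration is unsupported. A minimal sketch, with the tensor-parallel group setup elided and the exception type assumed:

```python
# Sketch (#2688): LayerNormLinear now rejects row-wise tensor parallelism at
# construction time instead of failing later with a CUDA error. The exception
# type is an assumption; tensor-parallel group setup is elided.
import transformer_engine.pytorch as te

try:
    layer = te.LayerNormLinear(
        1024, 1024,
        parallel_mode="row",  # unsupported for LayerNormLinear
    )
except (ValueError, RuntimeError) as err:
    print(f"Rejected at construction: {err}")
```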
Breaking Changes in This Release
[JAX] GSPMD partitioning rules are no longer tested and now warn on use; JAX users relying on GSPMD should migrate to Shardy. (#2702)
Deprecated Features
There are no deprecated features in this release.