.. include:: /content/common.rsts

Release Notes |ndash| Release 2.2
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

Key Features and Enhancements
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

- [PyTorch] Added support for the per-tensor current scaling recipe (see *Usage Sketches* below).
- [PyTorch] Implemented cross-entropy loss with support for splitting the computation across multiple devices.
- [PyTorch] Added support for CPU offloading with Megatron-Core style distributed optimizers.
- [C/PyTorch] Improved performance of P2P-based Tensor Parallel (TP) communication overlap.
- [Jax] Added support for the THD format with ring attention.
- [C] Added multi-node support for NVIDIA\ :sup:`®` NVLink for TP overlap with userbuffers.
- [PyTorch] Added KV-cache support for the FusedAttention, FlashAttention, and UnfusedDotProductAttention backends.
- [PyTorch] Improved bulk TP communication overlap by launching GEMMs on lower-priority streams.
- [Jax] Improved performance and memory usage of the causal mask in the cuDNN attention backend.

Fixed Issues
@@@@@@@@@@@@

- [PyTorch] Fixed a convergence issue when using context parallelism with a fused attention backend.
- [PyTorch] Fixed a crash in ``GroupedLinear`` when the last input has no tokens.
- [PyTorch] Made miscellaneous fixes to improve overall performance of the MXFP8 recipe.
- [PyTorch] Reintroduced support for the ``return_bias`` argument in all modules; it was silently ignored in v2.0 and v2.1 (see *Usage Sketches* below).
- [PyTorch] Reintroduced support for FP8 communication when overlapping reduce-scatter and GEMM with TP overlap using userbuffers.
- [PyTorch] Fixed gradient accumulation fusion in the ``LayerNormMLP`` module.
- [C/PyTorch] Made miscellaneous numerical fixes in the fused attention backend.
- [C] Stopped creating a new ``cublasLtHandle`` for every GEMM call, which leaked memory.
- [Jax] Fixed shape and sharding inference in the fused-attention C++ extension.
- [Jax] Fixed an import error in the encoder example.

Known Issues in This Release
@@@@@@@@@@@@@@@@@@@@@@@@@@@@

- FP8 execution is currently unsupported on the RTX 5090. Support will be added in v2.3.0.
- Transformer Engine may crash when it is installed from the PyPI registry but run in an environment with a CUDA version older than 12.8. Until the issue is fixed, a temporary workaround is to install from source.

Breaking Changes in This Release
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

- [PyTorch] The deprecated ``interval`` argument of the ``DelayedScaling`` recipe has been removed.
- [PyTorch] There are multiple breaking changes in the ``InferenceParams`` class (see the migration sketch under *Usage Sketches* below):

  - The new arguments ``num_heads_kv``, ``head_dim_k``, and ``dtype`` are required during initialization.
  - The user must call the ``pre_step`` method to update the ``InferenceParams`` state.
  - The ``swap_key_value_dict`` method has been removed; the ``step`` method now automatically reorders the key/value sequences according to their batch indices.

Deprecated Features
@@@@@@@@@@@@@@@@@@@

There are no deprecated features in this release.

Miscellaneous
@@@@@@@@@@@@@

- [PyTorch] The minimum required PyTorch version is now 2.1.
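
Usage Sketches
@@@@@@@@@@@@@@

The sketches below illustrate items referenced earlier in these notes. They are minimal, hedged examples rather than definitive API documentation; any name not stated in the notes above is an assumption.

Per-tensor current scaling: a minimal sketch, assuming the recipe is exposed as ``Float8CurrentScaling`` in ``transformer_engine.common.recipe`` and enabled through the usual ``fp8_autocast`` context:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    # Assumption: the per-tensor current scaling recipe is exposed as
    # Float8CurrentScaling; check transformer_engine.common.recipe for the exact name.
    from transformer_engine.common.recipe import Float8CurrentScaling

    recipe = Float8CurrentScaling()  # scales from each tensor's current amax, no history
    layer = te.Linear(1024, 1024, bias=True).cuda()
    inp = torch.randn(16, 1024, device="cuda")

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = layer(inp)
    out.sum().backward()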
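
``return_bias``: when the flag is set, a module skips its own bias addition and returns the bias tensor alongside the output so the caller can fuse the addition into a later kernel. A minimal sketch for ``te.Linear`` (the flag applies to the other modules as well):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    # With return_bias=True the module returns the bias unapplied, leaving the
    # addition to the caller (typically so it can be fused with a later op).
    layer = te.Linear(1024, 1024, bias=True, return_bias=True).cuda()
    inp = torch.randn(16, 1024, device="cuda")

    out, bias = layer(inp)
    out = out + bias  # caller applies (or fuses) the bias addition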
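
``InferenceParams`` migration: a hypothetical sketch of the breaking changes listed above. Only ``num_heads_kv``, ``head_dim_k``, ``dtype``, and the existence of ``pre_step`` are confirmed by these notes; the remaining argument names and the ``pre_step`` payload are assumptions, so consult the API reference for the exact signatures.

.. code-block:: python

    from collections import OrderedDict

    import torch
    import transformer_engine.pytorch as te

    # v2.2: num_heads_kv, head_dim_k, and dtype are now required at construction.
    # The max_batch_size / max_sequence_length names are assumed from the prior API.
    inference_params = te.InferenceParams(
        max_batch_size=4,
        max_sequence_length=2048,
        num_heads_kv=8,
        head_dim_k=64,
        dtype=torch.bfloat16,
    )

    # v2.2: pre_step must be called before each decoding step to update the
    # KV-cache state. The payload sketched here (sequence index -> number of
    # new tokens this step) is an assumption.
    inference_params.pre_step(OrderedDict((seq, 1) for seq in range(4)))

    # v2.2: swap_key_value_dict is gone; the step method reorders the key/value
    # sequences by batch index automatically, so no manual swapping is needed.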