Release Notes: Release 2.2

Key Features and Enhancements

  • [PyTorch] Added support for per-tensor current scaling recipe.

  • [PyTorch] Implemented cross-entropy loss with support for splitting computation across multiple devices.

  • [PyTorch] Added support for CPU offloading with Megatron-Core style distributed optimizers.

  • [C/PyTorch] Improved performance for P2P-based Tensor Parallel (TP) communication overlap.

  • [JAX] Added support for the THD format with ring attention.

  • [C] Added multi-node support for NVIDIA® NVLink for TP overlap with userbuffers.

  • [PyTorch] Added support for KV cache for FusedAttention, FlashAttention, and UnfusedDotProductAttention backends.

  • [PyTorch] Improved bulk TP communication overlap by launching GEMMs on lower priority streams.

  • [JAX] Improved performance and memory usage of the causal mask in the cuDNN attention backend.

Fixed Issues

  • [PyTorch] Fixed a convergence issue when using context parallelism with a fused attention backend.

  • [PyTorch] Fixed a crash in GroupedLinear when the last input has no tokens.

  • [PyTorch] Made miscellaneous fixes to improve overall performance of the MXFP8 recipe.

  • [PyTorch] Reintroduced support for the return_bias argument in all modules; it was silently ignored in v2.0 and v2.1.

  • [PyTorch] Reintroduced support for FP8 communication for overlapping reduce-scatter and GEMM when using TP overlap with userbuffers.

  • [PyTorch] Fixed gradient accumulation fusion in the LayerNormMLP module.

  • [C/PyTorch] Made miscellaneous numerical fixes to the fused attention backend.

  • [C] No longer creates a new cublasLtHandle for every GEMM call, eliminating a memory leak.

  • [JAX] Fixed shape and sharding inference in the fused-attention C++ extension.

  • [JAX] Fixed an import error in the encoder example.

Known Issues in This Release

  • FP8 execution is currently unsupported on the RTX 5090. Support will be added in v2.3.0.

  • Transformer Engine may crash when it is installed from the PyPI registry but run in an environment with a CUDA version older than 12.8. Until the issue is fixed, a temporary workaround is to install from source.
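
    The workaround above can be sketched as follows. The GitHub repository URL is the project's usual one, but the exact branch/tag name shown is an assumption, not taken from these notes:

    ```shell
    # Assumed workaround sketch: replace the PyPI wheel with a source build.
    # The branch/tag name "release_v2.2" is an assumption; pick the ref that
    # matches your target release.
    pip uninstall -y transformer-engine
    pip install git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.2
    ```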

Breaking Changes in This Release

  • [PyTorch] The deprecated interval argument for the DelayedScaling recipe has been removed.

  • [PyTorch] There are multiple breaking changes in the InferenceParams class.

    • New arguments num_heads_kv, head_dim_k, and dtype are required during initialization.

    • The user must call the pre_step method to update the InferenceParams state.

    • The swap_key_value_dict method has been removed, as the step method now automatically reorders the key/value sequences according to their batch indices.
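
    The new contract can be sketched with a stand-in class. Only the names num_heads_kv, head_dim_k, dtype, pre_step, step, and swap_key_value_dict come from the notes above; every other argument and attribute here is a hypothetical illustration, and the real class ships in transformer_engine.pytorch:

    ```python
    class InferenceParamsSketch:
        """Stand-in illustrating the v2.2 InferenceParams contract.

        Only num_heads_kv, head_dim_k, dtype, and pre_step() are named in
        these release notes; all other names here are hypothetical.
        """

        def __init__(self, max_batch_size, max_seqlen_kv,
                     num_heads_kv, head_dim_k, dtype):
            # v2.2: num_heads_kv, head_dim_k, and dtype are now required
            # at initialization.
            self.num_heads_kv = num_heads_kv
            self.head_dim_k = head_dim_k
            self.dtype = dtype
            # Shape a KV cache would take for these settings (illustrative).
            self.cache_shape = (max_batch_size, max_seqlen_kv,
                                num_heads_kv, head_dim_k)
            self.step_lens = {}

        def pre_step(self, new_tokens_per_batch_index):
            # v2.2: must be called to update the state before each
            # generation step; the batch reordering that previously needed
            # swap_key_value_dict() now happens automatically in step().
            self.step_lens = dict(new_tokens_per_batch_index)


    params = InferenceParamsSketch(max_batch_size=2, max_seqlen_kv=128,
                                   num_heads_kv=8, head_dim_k=64,
                                   dtype="bfloat16")
    params.pre_step({0: 1, 1: 1})
    ```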

Deprecated Features

There are no deprecated features in this release.

Miscellaneous

  • [PyTorch] The minimum required PyTorch version is now 2.1.
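
    A quick environment check against the new floor can be sketched with the standard library alone; real code might instead compare packaging.version.Version(torch.__version__), and the helper name here is hypothetical:

    ```python
    def meets_min_torch(version: str, minimum: str = "2.1") -> bool:
        """Return True if a dotted version string satisfies the 2.1 floor.

        Sketch only: compares major/minor numerically and ignores local
        build suffixes such as "+cu121".
        """
        def major_minor(v: str):
            return [int(p) for p in v.split("+")[0].split(".")[:2]]
        return major_minor(version) >= major_minor(minimum)


    # e.g. meets_min_torch("2.1.0+cu121") -> True
    #      meets_min_torch("2.0.1")       -> False
    ```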