Release Notes – Release 2.2
Key Features and Enhancements
[PyTorch] Added support for the per-tensor current scaling recipe (see the first sketch after this list).
[PyTorch] Implemented cross-entropy loss with support for splitting the computation across multiple devices (see the second sketch after this list).
[PyTorch] Added support for CPU offloading with Megatron-Core style distributed optimizers.
[C/PyTorch] Improved performance for P2P-based Tensor Parallel (TP) communication overlap.
[Jax] Added support for THD format with ring attention.
[C] Added multi-node support for NVIDIA® NVLink for TP overlap with userbuffers.
[PyTorch] Added support for KV cache for FusedAttention, FlashAttention, and UnfusedDotProductAttention backends.
[PyTorch] Improved bulk TP communication overlap by launching GEMMs on lower priority streams.
[Jax] Improved performance and memory usage for causal mask in the cuDNN attention backend.
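A minimal sketch of the per-tensor current scaling recipe, assuming the recipe class is exposed as Float8CurrentScaling in transformer_engine.common.recipe (the class name and defaults are assumptions; check the API reference for your build):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling  # assumed class name

# Current scaling derives each tensor's FP8 scale from its present amax,
# rather than from an amax history as DelayedScaling does.
recipe = Float8CurrentScaling()

layer = te.Linear(1024, 1024, bias=True, device="cuda")
inp = torch.randn(512, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = layer(inp)
out.sum().backward()
```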
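For the multi-device cross-entropy, one common way to split the computation (as in Megatron-style tensor parallelism) is to shard the vocabulary dimension across ranks. The sketch below illustrates that technique with plain torch.distributed; it is not Transformer Engine's actual API, and the function name and arguments are hypothetical:

```python
import torch
import torch.distributed as dist

def vocab_parallel_cross_entropy(local_logits, target, vocab_start, group=None):
    """Hypothetical sketch: per-token cross-entropy over a vocabulary
    sharded across ranks (group=None uses the default process group).

    local_logits: [tokens, local_vocab] shard of the logits.
    target: [tokens] global vocabulary indices.
    """
    # Numerically stable log-sum-exp across the full, sharded vocabulary.
    local_max = local_logits.max(dim=-1).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group)
    sum_exp = (local_logits - local_max.unsqueeze(-1)).exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    # Each rank contributes the target logit only if it owns that vocab slice.
    local_vocab = local_logits.size(-1)
    in_shard = (target >= vocab_start) & (target < vocab_start + local_vocab)
    idx = (target - vocab_start).clamp(0, local_vocab - 1).unsqueeze(-1)
    target_logit = local_logits.gather(-1, idx).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=group)

    # loss = logsumexp(logits) - logit[target]
    return sum_exp.log() + local_max - target_logit
```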
Fixed Issues
[PyTorch] Fixed a convergence issue when using context parallelism with the fused attention backend.
[PyTorch] Fixed a crash in GroupedLinear when the last input has no tokens.
[PyTorch] Made miscellaneous fixes to improve overall performance of the MXFP8 recipe.
[PyTorch] Reintroduced support for the return_bias argument in all modules; the argument was silently ignored in v2.0 and v2.1.
[PyTorch] Reintroduced support for FP8 communication in the overlapped reduce-scatter and GEMM path when using TP overlap with userbuffers.
[PyTorch] Fixed gradient accumulation fusion in the LayerNormMLP module.
[C/PyTorch] Made miscellaneous numerical fixes to the fused attention backend.
[C] Avoided creating a new cublasLtHandle for every GEMM call, preventing a memory leak.
[Jax] Fixed shape and sharding inference in the fused-attention C++ extension.
[Jax] Fixed an import error in the encoder example.
Known Issues in This Release
RTX 5090 is currently unsupported for FP8 execution. Support will be added in v2.3.0.
Transformer Engine may crash when it is installed via the PyPI registry but is run in an environment with CUDA version < 12.8. A temporary workaround is to install from source until the issue is fixed.
Breaking Changes in This Release
[PyTorch] The deprecated interval argument for the DelayedScaling recipe has been removed (see the first sketch after this list).
[PyTorch] There are multiple breaking changes in the InferenceParams class (see the second sketch after this list):
New arguments num_heads_kv, head_dim_k, and dtype are required during initialization.
The user must call a pre_step method to update the InferenceParams state.
The swap_key_value_dict method has been removed, as the step method now automatically reorders the key/value sequences according to their batch indices.
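A minimal sketch of constructing a DelayedScaling recipe without the removed interval argument; the argument values below are illustrative, not recommendations:

```python
from transformer_engine.common.recipe import DelayedScaling, Format

# `interval` is no longer accepted; scale updates are governed by the
# amax history instead.
recipe = DelayedScaling(
    margin=0,
    fp8_format=Format.HYBRID,
    amax_history_len=1024,
    amax_compute_algo="max",
)
```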
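And a sketch of the new InferenceParams flow, which also exercises the KV cache support added in this release. The required argument names (num_heads_kv, head_dim_k, dtype) come from the notes above; the import path, the max_batch_size/max_sequence_length keyword names, and the exact pre_step signature (here assumed to take an OrderedDict mapping sequence ids to the number of new tokens) are assumptions:

```python
from collections import OrderedDict
import torch
import transformer_engine.pytorch as te  # import path assumed

# Initialization now requires num_heads_kv, head_dim_k, and dtype.
params = te.InferenceParams(
    max_batch_size=4,
    max_sequence_length=2048,
    num_heads_kv=8,
    head_dim_k=64,
    dtype=torch.bfloat16,
)

# Before each generation step, pre_step must be called to update the
# cache state (assumed signature: sequence id -> tokens in this step).
params.pre_step(OrderedDict((seq_id, 1) for seq_id in range(4)))

# swap_key_value_dict is gone: step() now reorders the cached key/value
# sequences by batch index automatically.
```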
Deprecated Features
There are no deprecated features in this release.
Miscellaneous
[PyTorch] The minimum required PyTorch version is now 2.1.