Release Notes: Release 2.1

Key Features and Enhancements

  • [PyTorch] Made the API for fused optimizers (Adam and SGD) consistent with the PyTorch equivalents.

  • [PyTorch] Implemented probability permutation and mask-based permutation in MoE.

  • [PyTorch] Added the store_param_remainders argument to TE optimizers, which saves memory when FP32 master weights are kept for BF16 model weights (see the sketch after this list).

  • [Jax] Added support for the THD attention input format in the Flax modules.
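
The sketch below illustrates the new optimizer option described above. It is a minimal example, assuming FusedAdam is available under transformer_engine.pytorch.optimizers and that master_weights is the flag enabling FP32 master weights; consult the optimizer documentation for the exact argument names.

    import torch
    from transformer_engine.pytorch.optimizers import FusedAdam  # assumed import path

    # BF16 model weights; the optimizer keeps the FP32 master copy internally.
    model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16)

    # With the fused optimizers now following the PyTorch API, FusedAdam can be
    # used as a drop-in replacement for torch.optim.Adam.
    # store_param_remainders=True stores only the 16 bits of each FP32 master
    # weight that the BF16 parameter does not already hold, instead of a full
    # FP32 copy. (master_weights is an assumed argument name; check the docs.)
    optimizer = FusedAdam(
        model.parameters(),
        lr=1e-4,
        master_weights=True,
        store_param_remainders=True,
    )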

Fixed Issues

  • [PyPI] Fixed an issue with installing TE from PyPI in an environment where TE had already been installed from source: the wheel was installed incorrectly, resulting in an application crash at runtime.

  • [PyTorch] Fixed an issue with QuantizedTensor types when executing operations such as chunk or split, which have different shapes for input and output.

  • [PyTorch] Made miscellaneous fixes to the attention backend for execution on Blackwell GPUs.

  • [PyTorch] Fixed a crash when using Context Parallelism with FP8 weights.

  • [PyTorch] Fixed a crash when using fused gradient accumulation with grouped GEMMs (MoE).

  • [Jax/Flax] Changed the Flax modules to use the dtype argument to initialize their parameters while inferring the compute type from the input data type (see the sketch after this list).
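
As a rough illustration of the Flax dtype behavior described above, here is a minimal sketch. It assumes DenseGeneral is available under transformer_engine.jax.flax and that dtype is the parameter-initialization argument; check the module reference for the exact names.

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import DenseGeneral  # assumed import path

    # dtype controls only the dtype used to initialize the parameters.
    layer = DenseGeneral(features=1024, dtype=jnp.float32)

    # The compute type is inferred from the input: a BF16 input runs the
    # forward pass in BF16 even though the parameters were initialized in FP32.
    x = jnp.ones((8, 1024), dtype=jnp.bfloat16)
    params = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(params, x)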

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

  • [Jax] The fused_attn_thd API call is deprecated in favor of fused_attn, which supports the THD format.

  • [Jax] The mask positional argument is deprecated in favor of sequence_descriptor (see the migration sketch below).
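
The following is a minimal migration sketch for the two deprecations above. The SequenceDescriptor class, its from_seqlens helper, and the import path are assumptions based on the new argument name; verify them against the transformer_engine.jax.attention API reference before migrating.

    import jax.numpy as jnp
    # Assumed import path and helper name; verify against the API reference.
    from transformer_engine.jax.attention import SequenceDescriptor

    batch, seqlen = 2, 128

    # Build a descriptor of the per-sequence lengths. When calling fused_attn,
    # this object is passed where the deprecated mask positional argument used
    # to go, and fused_attn with THD-format inputs replaces the deprecated
    # fused_attn_thd entry point.
    seq_desc = SequenceDescriptor.from_seqlens(
        jnp.full((batch,), seqlen, dtype=jnp.int32)
    )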