Release Notes: Release 0.12.0 (BETA)

Key Features and Enhancements

  • [pyTorch] Added a device option to all modules (with cpu and cuda as possible values), enabling initialization of the model on the CPU (see the usage sketch after this list).

  • [pyTorch] Added MultiheadAttention module.

  • [pyTorch] The DotProductAttention module now exposes the attn_mask_type parameter in its forward method, enabling easy switching between causal and non-causal execution (e.g. when switching between training and inference); this is illustrated in the sketch after this list.

  • [JAX] Added support for FLAX 0.7.1.

  • [JAX] Added support for fused attention with sequence lengths longer than 512.

  • [JAX] Added support for FSDP in FLAX and Praxis.

  • [JAX] Added support for FP8 execution in Praxis.
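
The following minimal sketch ties the new PyTorch-facing features together: the device option at construction time, the new MultiheadAttention module, and the attn_mask_type argument of DotProductAttention's forward method. It is illustrative only; apart from the names called out in the notes above, the argument values, tensor shapes, and mask-type strings are assumptions rather than an exact reproduction of the API.

    import torch
    import transformer_engine.pytorch as te

    # New device option: build a module with its parameters on the CPU
    # ("cpu" and "cuda" are the documented values) and move it to the GPU later.
    linear = te.Linear(1024, 1024, device="cpu")
    linear = linear.to("cuda")

    # New MultiheadAttention module (the constructor arguments shown here follow
    # the usual Transformer Engine naming and are assumptions).
    mha = te.MultiheadAttention(hidden_size=1024, num_attention_heads=16).cuda()

    # DotProductAttention: attn_mask_type is now a forward() argument, so the same
    # module instance can run causally during training and non-causally at inference.
    dpa = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
    q = torch.randn(128, 2, 16, 64, dtype=torch.float16, device="cuda")  # [seq, batch, heads, head_dim]
    k, v = torch.randn_like(q), torch.randn_like(q)
    out_train = dpa(q, k, v, attn_mask_type="causal")   # training-style causal masking
    out_infer = dpa(q, k, v, attn_mask_type="no_mask")  # e.g. at inference (value assumed)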

Fixed Issues

  • [pyTorch] Fixed an issue with the reproducibility of results between runs with and without activation recomputation.

  • [pyTorch] Fixed an issue where in some cases memory would be allocated on the wrong device when loading from a checkpoint (https://github.com/NVIDIA/TransformerEngine/issues/342).

  • [pyTorch] Fixed a crash when sequence parallelism is used with frozen weights.

  • [pyTorch] Fixed the behavior of LayerNorm and RMSNorm modules when running under AMP.

  • [pyTorch] Fixed an issue where in some cases using the cuDNN fused attention backend would corrupt the random number generator state.

Known Issues in This Release

  • FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (https://github.com/Dao-AILab/flash-attention/issues/358). You can work around this issue either by setting the environment variable MAX_JOBS=1 during Transformer Engine installation, or by installing FlashAttention v1 (e.g. running pip install flash-attn==1.0.9) before attempting to install Transformer Engine.

Breaking Changes in This Release

  • There are no breaking changes in this release.

Deprecated Features

  • [pyTorch] The TransformerLayer arguments attention_softmax_in_fp32 and apply_query_key_layer_scaling are deprecated and will be removed in a future release. The default behavior is the same as if those arguments were set to True (see the migration sketch at the end of this section).

  • [pyTorch] The DotProductAttention argument attn_mask_type has been moved to the forward method; passing it to the constructor is deprecated and will be removed in a future release (see the migration sketch below).
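
The sketch below (illustrative only; the hidden sizes and head counts are placeholder values) shows how to migrate: drop the deprecated TransformerLayer arguments, whose former True behavior is now the default, and pass attn_mask_type to DotProductAttention's forward method instead of its constructor.

    import transformer_engine.pytorch as te

    # Before: the deprecated arguments are passed explicitly.
    layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                                num_attention_heads=16,
                                attention_softmax_in_fp32=True,        # deprecated
                                apply_query_key_layer_scaling=True)    # deprecated

    # After: simply omit them; the default behavior is the same as passing True.
    layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                                num_attention_heads=16)

    # Similarly, pass attn_mask_type to DotProductAttention's forward() call
    # rather than its constructor (see the sketch under Key Features and
    # Enhancements above).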