Transformer Engine Release Notes: Release 1.5

Key Features and Enhancements

  • [PyTorch] Added support for non-reentrant mode for activation recompute in the checkpoint API (see the first sketch after this list).

  • [PyTorch] Added support for rectangular matrices in the unfused softmax backend to support speculative decoding.

  • [PyTorch] Added the inference_params argument to the DotProductAttention API to support KV caching (see the KV-caching sketch after this list).

  • [JAX] Added the DotProductAttention API.

  • [JAX] Expanded RoPE support using the rotary_pos_emb_group_method argument.

  • [Paddle] Added support for RMSNorm (see the Paddle sketch after this list).

  • [Paddle] Added support for RoPE.

  • [Paddle] Added support for SwiGLU (also shown in the Paddle sketch after this list).
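
The sketches below illustrate a few of the additions above. First, activation recompute in non-reentrant mode. This is a minimal sketch that assumes transformer_engine.pytorch.checkpoint accepts the wrapped module, its inputs, and a use_reentrant keyword analogous to torch.utils.checkpoint; confirm the exact signature in the API reference.

    # Minimal sketch: activation recompute with non-reentrant checkpointing.
    # The use_reentrant flag name is assumed by analogy with torch.utils.checkpoint.
    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16
    ).cuda()
    inp = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)

    # use_reentrant=False selects the non-reentrant recompute path.
    out = te.checkpoint(layer, inp, use_reentrant=False)
    out.sum().backward()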
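
KV caching through the new inference_params argument might look like the following. This is a sketch only: the InferenceParams helper and the default sbhd tensor layout are assumed here; check the DotProductAttention documentation for the exact calling convention.

    # Minimal sketch: one decode step with a KV cache.
    # InferenceParams(max_batch_size, max_sequence_length) is assumed to be
    # exposed by transformer_engine.pytorch.
    import torch
    import transformer_engine.pytorch as te

    dpa = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
    cache = te.InferenceParams(max_batch_size=2, max_sequence_length=2048)

    # A single new query token attends against the cached keys and values.
    q = torch.randn(1, 2, 16, 64, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(1, 2, 16, 64, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(1, 2, 16, 64, device="cuda", dtype=torch.bfloat16)
    out = dpa(q, k, v, inference_params=cache)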
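
The Paddle additions can be sketched as follows. The RMSNorm constructor and the activation="swiglu" option on LayerNormMLP are assumptions made by analogy with the PyTorch modules and may not match the shipped Paddle API exactly.

    # Rough sketch of the new Paddle features; constructor arguments are assumed.
    import paddle
    import transformer_engine.paddle as te

    norm = te.RMSNorm(hidden_size=1024, eps=1e-5)
    mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, activation="swiglu")

    x = paddle.randn([32, 1024])
    y_norm = norm(x)  # RMSNorm
    y_mlp = mlp(x)    # MLP with SwiGLU activation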

Fixed Issues

  • [PyTorch] Fixed a numerical issue with storing weights in FP8 via the fp8_model_init API.
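
The fix affects models whose parameters are created directly in FP8 through this context manager, for example (a minimal sketch of the documented usage pattern):

    # Minimal sketch: create a module whose weights are stored in FP8.
    import transformer_engine.pytorch as te

    with te.fp8_model_init(enabled=True):
        linear = te.Linear(1024, 1024)  # parameters are kept in FP8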

Known Issues in This Release

  • FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (https://github.com/Dao-AILab/flash-attention/issues/358). You can work around this issue either by setting the environment variable MAX_JOBS=1 during Transformer Engine installation, or by installing FlashAttention v1 (e.g. by executing pip install flash-attn==1.0.9) before attempting to install Transformer Engine.

  • [PyTorch] FlashAttention v2.1 changed the behavior of the causal mask when performing cross-attention (see https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag for reference). To keep behavior consistent across versions and backends, Transformer Engine disables FlashAttention for this use case (cross-attention with causal masking) when FlashAttention 2.1 or later is installed.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

  • [JAX] The arguments num_heads, dropout_rate, output_layernorm, apply_residual_connection_post_layernorm, and fuse_qkv are deprecated in the MultiHeadAttention API. They are replaced, respectively, by num_attention_heads, attention_dropout, input_layernorm, return_layernorm_output, and fused_qkv_params (a migration sketch follows this list).

  • FlashAttention v1 is no longer supported in Transformer Engine. The minimum required version is v2.0.6.
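
The renames map one-to-one onto the new argument names. A small, hypothetical migration helper (not part of Transformer Engine) is sketched below for call sites that still use the deprecated spellings.

    # Hypothetical helper (not part of Transformer Engine): rewrites the deprecated
    # MultiHeadAttention keyword arguments to their replacements listed above.
    DEPRECATED_TO_NEW = {
        "num_heads": "num_attention_heads",
        "dropout_rate": "attention_dropout",
        "output_layernorm": "input_layernorm",
        "apply_residual_connection_post_layernorm": "return_layernorm_output",
        "fuse_qkv": "fused_qkv_params",
    }

    def migrate_mha_kwargs(kwargs):
        """Return a copy of kwargs with deprecated argument names replaced."""
        return {DEPRECATED_TO_NEW.get(name, name): value for name, value in kwargs.items()}

    # migrate_mha_kwargs({"num_heads": 16, "dropout_rate": 0.1})
    # -> {"num_attention_heads": 16, "attention_dropout": 0.1}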

Miscellaneous Changes

There are no miscellaneous changes in this release.