Release Notes - Release 0.7.0 (BETA)

Key Features and Enhancements

Added initial support for the JAX framework.

Added Flax modules for LayerNorm, DenseGeneral, LayerNormDenseGeneral, LayerNormMLP, RelativePositionalEmbeddings, and TransformerLayer.
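
A minimal sketch of using one of the new Flax modules. The import path transformer_engine.jax.flax is an assumption and may differ in this release; the shapes are illustrative only:

    import jax
    import jax.numpy as jnp
    import transformer_engine.jax.flax as te_flax  # assumed import path

    # DenseGeneral is one of the Flax modules added in this release.
    layer = te_flax.DenseGeneral(features=1024)
    x = jnp.ones((8, 128, 512), dtype=jnp.float32)    # (batch, sequence, hidden)
    params = layer.init(jax.random.PRNGKey(0), x)     # standard Flax initialization
    y = layer.apply(params, x)                        # forward pass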

Added support for 1D tensor parallelism.
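
A minimal sketch, assuming the PyTorch API, of sharding a layer across a 1D tensor-parallel process group. The process-group setup is ordinary torch.distributed boilerplate and the sizes are illustrative:

    import torch
    import torch.distributed as dist
    import transformer_engine.pytorch as te

    dist.init_process_group(backend="nccl")   # one GPU per rank assumed
    torch.cuda.set_device(dist.get_rank())
    world_size = dist.get_world_size()
    tp_group = dist.new_group(ranks=list(range(world_size)))

    # Column-parallel Linear: the output dimension is split across the group.
    layer = te.Linear(1024, 4096, tp_group=tp_group, tp_size=world_size,
                      parallel_mode="column").cuda()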

Optimized the performance of FP8 training when using spatial parallelism in PyTorch.

Optimized the performance of casting small matrices to FP8.

Added support for FP8 execution on the Ada architecture. This support requires CUDA 12.1 and cuBLAS 12.1.3 or later.
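
A minimal sketch of enabling FP8 execution in PyTorch with the fp8_autocast context manager; on Ada-generation GPUs this path requires the CUDA and cuBLAS versions noted above, and the layer and tensor sizes shown are illustrative:

    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(1024, 1024, bias=True).cuda()
    x = torch.randn(16, 1024, device="cuda")

    with te.fp8_autocast(enabled=True):   # FP8 forward pass
        y = layer(x)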

Changed the length of the amax history window in the default FP8 recipe to 1024, the value recommended for very large model training.
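
For reference, a sketch of constructing the FP8 recipe explicitly rather than relying on the new default; model and inp stand in for an arbitrary module and input:

    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Matches the new default history length of 1024 amax values.
    fp8_recipe = recipe.DelayedScaling(amax_history_len=1024)

    # with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    #     out = model(inp)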

Fixed Issues in This Release

Fixed an issue where training with Transformer Engine left a zombie process running.

Fixed an issue where passing the return_bias argument overrode the value of the bias argument.
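
For context, a minimal sketch of how the two arguments are meant to interact: bias controls whether a bias parameter exists, while return_bias=True returns the bias separately instead of adding it to the output. The sizes are illustrative:

    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(1024, 1024, bias=True, return_bias=True).cuda()
    x = torch.randn(16, 1024, device="cuda")
    out, bias = layer(x)   # bias is returned, not fused into out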

Fixed an issue where, in some cases, flash attention was erroneously launched with an unsupported configuration on the Ada architecture.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

The TransformerLayer arguments attention_softmax_in_fp32 and apply_query_key_layer_scaling are deprecated and will be removed in a future release. The default behavior is as if those arguments were set to True.
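
A minimal sketch of constructing a TransformerLayer without the deprecated arguments; the sizes are illustrative, and the layer behaves as if both arguments were set to True:

    import transformer_engine.pytorch as te

    # attention_softmax_in_fp32 and apply_query_key_layer_scaling no longer need to be passed.
    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
    ).cuda()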