Release Notes – Release 2.8
Key Features and Enhancements
[C][PyTorch] Added support for the NVFP4 training recipe.
[C][PyTorch] Added support for FP8 attention with the current scaling recipe.
[PyTorch] Added support for mixing recipes for different modules when using the make_graphed_callables function.
[C] Added 8-bit RNG support to the dropout kernel.
[C][PyTorch] Added the nvte_rmsnorm_bwd_add function to the C API and added support for fusing the RMSNorm and add operations in the Transformer Engine sequential operations API.
[C] Added more robust error checking and handling when calling CUDA and driver APIs.
[C][PyTorch] Added support for using FP8 and non-FP8 quantization modes in the same model when overlapping tensor parallel communication and GEMM using userbuffers.
[PyTorch] Added support for the qgeglu and sreglu activations in the Transformer Engine fused operations API and the LayerNormMLP module.
[C][PyTorch] Added support for FP8 GEMM output for the MXFP8 and current scaling recipes.
[PyTorch] Added support for FP8 all-gather when using tensor parallelism with the GroupedLinear module.
[PyTorch] Added support for the current scaling FP8 recipe when exporting modules via ONNX.
[PyTorch] Made miscellaneous improvements to MoE workloads to reduce CPU overhead.
[PyTorch] Improved performance of CUDA graphs using FP8 weight cache in quantization kernels.
[PyTorch] Added support for FlashAttention v3 for MLA with context parallelism.
[PyTorch] Added support for activation CPU offloading in the Transformer Engine sequential operations API.
[PyTorch] Made miscellaneous performance improvements when using RoPE (rotary positional embeddings).
[C] Added support for BF16 and FP32 inputs to the kernel that calculates auxiliary loss for MoE.
[C] Added support for sink attention via cuDNN.
[Jax] Fused the swizzling operation for the inverse scaling factors with the transpose calculation of the data.
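The qgeglu and sreglu activations listed above are gated activations. The following plain-Python sketch shows the reference math under stated assumptions: qgelu is taken to be the "quick GELU" approximation x * sigmoid(1.702 * x), srelu to be squared ReLU, and the gated form to split the last dimension into two halves, apply the activation to the first half, and multiply by the second. Which half receives the activation inside Transformer Engine's kernels is an assumption here, and all function names are hypothetical, not Transformer Engine APIs.

```python
import math
from typing import Callable, List


def _sigmoid(z: float) -> float:
    """Numerically stable logistic sigmoid."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def qgelu(x: float) -> float:
    """Quick GELU approximation (assumed definition): x * sigmoid(1.702 * x)."""
    return x * _sigmoid(1.702 * x)


def srelu(x: float) -> float:
    """Squared ReLU (assumed definition): max(x, 0) ** 2."""
    return max(x, 0.0) ** 2


def gated(act: Callable[[float], float], row: List[float]) -> List[float]:
    """Gated activation over one row of length 2n.

    Assumed convention: act(first half) * second half, elementwise.
    """
    if len(row) % 2 != 0:
        raise ValueError("gated activations need an even-sized last dimension")
    n = len(row) // 2
    return [act(a) * b for a, b in zip(row[:n], row[n:])]


def qgeglu(row: List[float]) -> List[float]:
    return gated(qgelu, row)


def sreglu(row: List[float]) -> List[float]:
    return gated(srelu, row)


x = [1.0, -2.0, 3.0, 0.5]
print(sreglu(x))  # [3.0, 0.0]: srelu(1) * 3 and srelu(-2) * 0.5
```

Keeping the gate in one helper means the two variants differ only in the activation function passed in, which mirrors how the gated family is usually described.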
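The nvte_rmsnorm_bwd_add entry pairs the RMSNorm backward pass with an addition. The sketch below is a plain-Python reference for that math under an explicit assumption: that the fused function returns the RMSNorm input gradient plus an extra gradient tensor added elementwise (for example, a gradient arriving through a residual branch). It is not Transformer Engine's kernel or signature; rmsnorm_fwd and rmsnorm_bwd_add are hypothetical names. For y_i = g_i * x_i * r with r = 1 / sqrt(mean(x^2) + eps), the input gradient is dx_j = r * g_j * dy_j - (r^3 / N) * x_j * sum_i(dy_i * g_i * x_i).

```python
import math
from typing import List, Tuple


def rmsnorm_fwd(x: List[float], gamma: List[float],
                eps: float = 1e-5) -> Tuple[List[float], float]:
    """RMSNorm over one row: y_i = gamma_i * x_i * rsigma. Returns (y, rsigma)."""
    mean_sq = sum(v * v for v in x) / len(x)
    rsigma = 1.0 / math.sqrt(mean_sq + eps)
    return [g * v * rsigma for g, v in zip(gamma, x)], rsigma


def rmsnorm_bwd_add(dy: List[float], x: List[float], gamma: List[float],
                    rsigma: float, extra_grad: List[float]) -> List[float]:
    """RMSNorm input gradient with extra_grad added elementwise (assumed semantics)."""
    n = len(x)
    # Row-wide reduction shared by every element: sum_i dy_i * gamma_i * x_i
    dot = sum(d * g * v for d, g, v in zip(dy, gamma, x))
    coeff = rsigma ** 3 / n
    return [rsigma * g * d - coeff * v * dot + e
            for d, g, v, e in zip(dy, gamma, x, extra_grad)]
```

A central finite-difference check on a small row is a quick way to confirm the closed-form gradient: subtracting extra_grad from the result should match the numerical derivative of sum(dy * y).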
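Sink attention is commonly formulated as softmax attention with one extra learnable logit (the "sink") that joins the softmax normalization but has no value vector, so a query can assign probability mass to "nothing". The single-query sketch below assumes that formulation; cuDNN's exact variant and the way Transformer Engine exposes it are not described in these notes, and attend_with_sink is a hypothetical name.

```python
import math
from typing import List


def attend_with_sink(scores: List[float], values: List[List[float]],
                     sink_logit: float) -> List[float]:
    """Softmax attention for one query; the sink logit appears only in the denominator."""
    # Subtract the max over all logits (including the sink) for numerical stability.
    m = max(max(scores), sink_logit)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights) + math.exp(sink_logit - m)
    dim = len(values[0])
    out = [0.0] * dim
    for w, v in zip(weights, values):
        for k in range(dim):
            out[k] += (w / denom) * v[k]
    return out
```

Passing sink_logit=float("-inf") recovers ordinary softmax attention, which makes the sink's effect easy to isolate: a larger sink logit shrinks every output toward zero.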
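NVFP4 stores 4-bit E2M1 values with one FP8 (E4M3) scale per block of 16 elements, plus a per-tensor FP32 scale. The sketch below models only the per-block step in plain Python: each block's scale maps its amax onto 6.0, the largest E2M1 magnitude, and each element rounds to the nearest representable E2M1 value. As simplifications, the block scale stays in full precision instead of being rounded to E4M3, the per-tensor scale is omitted, and ties round toward the smaller magnitude rather than to nearest even. The function names are hypothetical; this is not Transformer Engine's quantization kernel.

```python
from typing import List, Tuple

# Non-negative magnitudes representable by FP4 E2M1.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_MAX = 6.0
BLOCK_SIZE = 16  # elements that share one scale in NVFP4


def quantize_block(block: List[float]) -> Tuple[float, List[float]]:
    """Return (scale, e2m1_values) for one block; scale kept in full precision."""
    amax = max(abs(v) for v in block)
    scale = amax / E2M1_MAX if amax > 0 else 1.0
    values = []
    for v in block:
        target = abs(v) / scale
        # Nearest grid point; min() keeps the first (smaller) one on ties.
        mag = min(E2M1_GRID, key=lambda g: abs(target - g))
        values.append(mag if v >= 0 else -mag)
    return scale, values


def quantize_sketch(x: List[float]) -> List[Tuple[float, List[float]]]:
    """Split x into 16-element blocks and quantize each block independently."""
    return [quantize_block(x[i:i + BLOCK_SIZE]) for i in range(0, len(x), BLOCK_SIZE)]


def dequantize_sketch(blocks: List[Tuple[float, List[float]]]) -> List[float]:
    """Multiply each block's E2M1 values back by that block's scale."""
    return [scale * v for scale, values in blocks for v in values]
```

Because every block gets its own scale, one outlier only coarsens the 15 elements that share its block instead of the whole tensor, which is the motivation for small-block formats like this one.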
Fixed Issues
[Jax] Fixed a crash when the user calls global_shard_guard before setting the JAX mesh.
[Jax] Fixed an issue in the mesh logic so that, when an axis is undefined in the mesh, Transformer Engine still applies the sharding constraint for the given tensor on the other axes instead of skipping it.
[Jax] Fixed a crash in GroupedScaledTensor caused by incorrect arguments being passed.
[PyTorch] Fixed a bug in the cross-entropy loss kernel that resulted in vanishing gradients.
[C][PyTorch] Fixed incorrect calculation of tensor parallel rank when using userbuffers.
[PyTorch] Fixed redundant memory overheads when using FP8 all-gather with sequence parallelism.
Known Issues in This Release
[PyTorch] For distributed workloads that use the Float8CurrentScaling recipe without FP8 attention, redundant amax reductions across the tensor parallel and context parallel groups add some performance overhead. This issue has been fixed (https://github.com/NVIDIA/TransformerEngine/pull/2234), and the fix will be available in the next release (v2.9).
As a workaround, run the workload with the following variable set in the environment:
export NVTE_DPA_FP8_RECIPE="F16"
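When the job is launched from a Python entry point instead of a shell, the same workaround can be applied through os.environ. This is a minimal sketch; setting the variable before Transformer Engine is imported is a conservative choice, since these notes do not say when the variable is read.

```python
import os

# Workaround for the Float8CurrentScaling known issue in Transformer Engine 2.8.
# Set before importing Transformer Engine (conservative: read time is not documented here).
os.environ["NVTE_DPA_FP8_RECIPE"] = "F16"

# import transformer_engine.pytorch as te  # import only after the variable is set
```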
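For context on the known issue: with current scaling, each FP8 scale is derived from the amax of the tensor being quantized, so in distributed runs the per-rank amax values are combined with a max-reduction before the scale is computed. The plain-Python sketch below simulates the ranks of one group as a list. It assumes the FP8 E4M3 format (largest finite value 448) and a scale of fp8_max / amax; local_amax, allreduce_max, and current_scale are hypothetical helpers, not Transformer Engine functions.

```python
from typing import List

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def local_amax(shard: List[float]) -> float:
    """Largest absolute value held by one rank."""
    return max((abs(v) for v in shard), default=0.0)


def allreduce_max(per_rank: List[float]) -> List[float]:
    """Simulated all-reduce(MAX): every rank ends up with the group maximum."""
    m = max(per_rank)
    return [m] * len(per_rank)


def current_scale(amax: float, fp8_max: float = FP8_E4M3_MAX) -> float:
    """Current scaling: map this tensor's amax onto the top of the FP8 range."""
    return fp8_max / amax if amax > 0 else 1.0


shards = [[0.5, -3.0], [2.0, 1.5], [-7.0, 0.25], [4.0, 0.0]]  # one shard per rank
amaxes = allreduce_max([local_amax(s) for s in shards])
scales = [current_scale(a) for a in amaxes]
print(amaxes, scales)  # [7.0, 7.0, 7.0, 7.0] [64.0, 64.0, 64.0, 64.0]
```

Max is idempotent, so reducing an already-reduced amax again returns the same value: a redundant reduction leaves the scales unchanged and only adds communication cost.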
Breaking Changes in This Release
There are no breaking changes in this release.
Deprecated Features
There are no deprecated features in this release.
Miscellaneous
There are no miscellaneous issues in this release.