JAX Release 23.10


The NVIDIA container image for JAX, release 23.10, is available on NGC.

Contents of the JAX container

This container image contains the complete source of the version of JAX in /opt/jax-source. It is prebuilt and installed within the container image. The container also includes the following:

Driver Requirements

Release 23.10 is based on CUDA 12.2.0, which requires NVIDIA Driver release 535 or later. However, if you are running on a data center GPU (for example, a T4 or any other data center GPU), you can use NVIDIA driver release 450.51 (or later R450), 470.57 (or later R470), 510.47 (or later R510), 525.85 (or later R525), or 535.86 (or later R535). The CUDA driver's compatibility package only supports particular drivers. Thus, users should upgrade from all R418, R440, R460, and R520 drivers, which are not forward-compatible with CUDA 12.2. For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.
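A quick way to confirm which driver branch the host provides is to query nvidia-smi from inside the container. The snippet below is a minimal sketch; it assumes nvidia-smi is available on the PATH, which is normally the case when the container is started with GPU support.

    import subprocess

    # Query the host NVIDIA driver version via nvidia-smi. Release 23.10 expects
    # an R535 (or newer) driver, or one of the data center driver branches
    # listed above.
    driver = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout.strip()
    print("NVIDIA driver version:", driver)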

GPU Requirements

Release 23.10 supports CUDA compute capability 6.0 and later. This corresponds to GPUs in the NVIDIA Pascal, NVIDIA Volta™, NVIDIA Turing™, NVIDIA Ampere architecture, and NVIDIA Hopper™ architecture families. For a list of GPUs to which this compute capability corresponds, see CUDA GPUs. For additional support details, see Deep Learning Frameworks Support Matrix.
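As a quick check, the GPUs visible to JAX report their compute capability. The sketch below uses attributes exposed by JAX's CUDA devices (device_kind and compute_capability); exact device attributes can vary slightly between jaxlib versions.

    import jax

    # List the GPUs visible to JAX along with their compute capability.
    # Release 23.10 requires compute capability 6.0 (Pascal) or later.
    for dev in jax.devices("gpu"):
        print(dev.device_kind, dev.compute_capability)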

Key Features and Enhancements

This JAX release includes the following key features and enhancements.

Announcements

  • Transformer Engine is a library for accelerating Transformer models on NVIDIA GPUs. It includes support for 8-bit floating point (FP8) precision on Hopper GPUs, which provides better training and inference performance with lower memory utilization. Transformer Engine also includes a collection of highly optimized modules for popular Transformer architectures and an automatic mixed precision-like API that can be used seamlessly with your JAX code (see the sketch below).

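The following is a minimal sketch of using Transformer Engine's JAX/Flax API under FP8 autocast. The module and argument names shown (fp8_autocast, DenseGeneral, features) follow the Transformer Engine JAX API, but the exact signatures, FP8 recipe options, and FP8 metadata handling should be checked against the Transformer Engine version shipped in the container.

    import jax
    import jax.numpy as jnp
    import transformer_engine.jax as te
    import transformer_engine.jax.flax as te_flax

    # Illustrative input of shape (batch, sequence, hidden).
    x = jnp.zeros((8, 128, 1024), dtype=jnp.bfloat16)
    rng = jax.random.PRNGKey(0)

    # fp8_autocast enables FP8 execution for supported TE modules on Hopper GPUs.
    with te.fp8_autocast(enabled=True):
        layer = te_flax.DenseGeneral(features=4096)
        variables = layer.init(rng, x)   # includes FP8 scaling metadata collections
        y = layer.apply(variables, x)
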
NVIDIA JAX Container Versions

The following table shows what versions of CUDA and JAX are supported in each of the NVIDIA containers for JAX.

  Container Version    CUDA Toolkit          JAX
  23.10                NVIDIA CUDA 12.2.0    0.4.17
  23.08                NVIDIA CUDA 12.1.1    0.4.14

NVIDIA T5x and Paxml Container Versions

The following table shows the versions of CUDA and JAX supported in each of the NVIDIA T5x and Paxml containers.

  Container   Version   CUDA Toolkit          JAX
  Paxml       23.10     NVIDIA CUDA 12.2.2    0.4.17.dev20231010
              23.08     NVIDIA CUDA 12.1.1    0.4.14 (579808d98)
  T5x         23.10     NVIDIA CUDA 12.2.2    0.4.17.dev20231010
              23.08     NVIDIA CUDA 12.1.1    0.4.14 (603eeb190)


The JAX versions in the Paxml and T5x containers are development versions. See /opt/jax-source and /opt/xla-source for the exact revision used in each container.

Inspecting Source code in NVIDIA T5x and Paxml Containers

If you would like to inspect Pax's source code (paxml and praxis) to learn more about what is being run, you can do so by inspecting the source within the nvcr.io/nvidia/jax:23.10-paxml-py3 container. The locations within the container are:

  • Paxml: /opt/paxml
  • Praxis: /opt/praxis

Similarly, for T5x's source code in nvcr.io/nvidia/jax:23.10-t5x-py3 (a quick way to cross-check the installed versions is sketched after the list):

  • t5x: /opt/t5x
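In addition to browsing the /opt checkouts above, the installed package versions can be queried from inside either container. The package names below simply mirror the source locations listed above; packages installed only as source checkouts may not appear as Python distributions.

    import importlib.metadata as md

    # Report the installed versions of the relevant packages, where they are
    # installed as Python distributions rather than plain source checkouts.
    for pkg in ("jax", "paxml", "praxis", "t5x"):
        try:
            print(pkg, md.version(pkg))
        except md.PackageNotFoundError:
            print(pkg, "not found as an installed distribution; see the /opt checkouts")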

JAX Toolbox Examples

The JAX Toolbox examples focus on achieving the best performance and convergence from NVIDIA Hopper and NVIDIA Ampere architecture tensor cores by using the latest deep learning example networks and model scripts for training. These examples are tested nightly in CI as well as against each NGC container release to ensure consistent accuracy and performance over time.

  • GPT model: This is a decoder-only LLM architecture used in many modern text applications. This model script is available on GitHub. You can also find the latest performance and convergence results on the model card on GitHub.
  • T5 model: This model was introduced in the Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer paper. It is a popular model for sequence-to-sequence modeling tasks involving text. This model script is available on GitHub. You can also find the latest performance and convergence results on the model card on GitHub.
  • ViT model: The Vision Transformer was introduced in the An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale paper. It is a popular computer vision model used for image classification and is also a foundational model for GenAI applications such as multimodal language models. The ViT model also showcases NVIDIA DALI data pipelines, which are particularly helpful for smaller models where the data pipeline tends to be the bottleneck. The model script is available on GitHub. You can also find the latest performance and convergence results on the model card on GitHub.

Known Issues

  • Pipeline parallelism is not supported with NVIDIA Transformer Engine enabled in the Paxml container.
  • There could be random failures when running JAX in a single process with multiple GPUs. This is a known issue caused by the XLA version in the 23.10 containers. It will be resolved in the next release; the recommended workaround is to run with multi-processing enabled (one GPU per process).
  • There is a known accuracy degradation in the Paxml container when training with dropout enabled in Transformer Engine's transformer layers. This issue will be resolved in a future container release.
  • Performance degradation:
    • There is a 10% performance regression on Paxml’s 126M parameter model, trained with BF16 on A100s, when NVIDIA Transformer Engine is disabled. This will be fixed in a future release.
    • There is a 1-2% performance regression in the 23.10 Paxml container compared to the 23.08 Paxml container on A100s.
    • ViT has an 11% pretraining performance regression in the T5x container compared to the initial JAX-Toolbox release container.
    • The T5 model in T5x has a 6-9% performance regression.
  • The JAX, T5x, and Paxml containers are an early release and have some differences from the NGC TensorFlow container that will be addressed in future releases. Some differences include the CUDA minor version, support for third-party network devices via HPC-X, and the versions of CTK/cuDNN/NCCL, and so on. Some packages that do not currently ship include:
    • JupyterLab including Jupyter-TensorBoard
    • NVIDIA TensorRT
    • TensorBoard
    • NVIDIA HPC-X with UCX and OpenMPI
    • OpenMPI
    • NVIDIA RAPIDS
    • cuTENSOR
    • GDRCopy
  • The jax.experimental.sparse sparse-matrix package currently produces incorrect matrix multiplication results on NVIDIA Ampere and Hopper architecture GPUs. Volta GPUs appear not to be affected. See the detailed issue here.
  • There are known CVEs affecting the amd64 Paxml container that are related to TensorFlow 2.9.x, due to TensorFlow being pinned to 2.9.x in Paxml and Lingvo. We will fix these in an upcoming release. The known CVEs are: