JAX Release 23.08

The NVIDIA container image for JAX, release 23.08, is available on NGC.

Contents of the JAX container

This container image contains the complete source of the included JAX version in /opt/jax-source. JAX is prebuilt and installed within the container image. The container also includes the following:

Driver Requirements

Release 23.08 is based on CUDA 12.1.1, which requires NVIDIA Driver release 525.60.13 or later. However, if you are running on a data center GPU (for example, T4 or any other data center GPU), you can use NVIDIA driver release 450.51 (or later R450), 470.57 (or later R470), 510.47 (or later R510), or 525.85 (or later R525). The CUDA driver's compatibility package only supports particular drivers. Thus, users should upgrade from all R418, R440, R460, and R520 drivers, which are not forward-compatible with CUDA 12.1. For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.
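The driver rules above can be sketched as a small check. This is only an illustration of the stated minimums, not an NVIDIA tool: the function names and the `data_center_gpu` flag are assumptions introduced here, and the driver string is assumed to come from something like `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.

```python
# General minimum stated for CUDA 12.1.1.
GENERAL_MINIMUM = (525, 60, 13)

# Older driver branches that the text allows on data center GPUs only,
# mapped to the minimum version within each branch.
# R525 needs no entry: any 525.85+ driver already meets GENERAL_MINIMUM.
DATA_CENTER_MINIMUMS = {
    450: (450, 51),
    470: (470, 57),
    510: (510, 47),
}


def parse_version(text: str) -> tuple:
    """Turn a string such as '525.60.13' into (525, 60, 13)."""
    return tuple(int(part) for part in text.strip().split("."))


def driver_is_supported(driver: str, data_center_gpu: bool = False) -> bool:
    """Hypothetical helper: apply the release-note rules to a driver string."""
    version = parse_version(driver)
    if version >= GENERAL_MINIMUM:
        return True
    if data_center_gpu:
        # Look up the branch (first version component) in the allowed list.
        minimum = DATA_CENTER_MINIMUMS.get(version[0])
        return minimum is not None and version >= minimum
    return False
```

For example, `driver_is_supported("470.82.01", data_center_gpu=True)` passes under these rules, while the same driver on a non-data-center GPU, or any R460 driver, does not.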

GPU Requirements

Release 23.08 supports CUDA compute capability 6.0 and later. This corresponds to GPUs in the NVIDIA Pascal, NVIDIA Volta™, NVIDIA Turing™, NVIDIA Ampere architecture, and NVIDIA Hopper™ architecture families. For a list of GPUs to which this compute capability corresponds, see CUDA GPUs. For additional support details, see Deep Learning Frameworks Support Matrix.
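As a rough illustration of the compute capability requirement, the sketch below compares a capability string against the 6.0 minimum. The helper names are hypothetical. The optional `report_local_gpus` function assumes that JAX GPU device objects expose a `compute_capability` string (for example "8.0"); treat that attribute as an assumption and verify it in your container.

```python
MIN_COMPUTE_CAPABILITY = (6, 0)  # minimum stated for release 23.08


def meets_minimum(compute_capability: str) -> bool:
    """Return True if a capability string such as '8.0' is at least 6.0."""
    major, minor = (int(part) for part in compute_capability.split("."))
    return (major, minor) >= MIN_COMPUTE_CAPABILITY


def report_local_gpus() -> None:
    """Print a support verdict for each device that JAX can see."""
    try:
        import jax  # imported lazily so meets_minimum works without JAX
    except ImportError:
        print("JAX is not installed in this environment")
        return
    for device in jax.devices():
        # Assumption: GPU devices carry a compute_capability attribute.
        capability = getattr(device, "compute_capability", None)
        if capability is None:
            print(f"{device}: no compute capability reported")
        else:
            verdict = "supported" if meets_minimum(capability) else "too old"
            print(f"{device}: compute capability {capability} ({verdict})")
```

Calling `report_local_gpus()` inside the container prints one line per visible device; `meets_minimum` can be used on its own with a value looked up on the CUDA GPUs page.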

Key Features and Enhancements

This JAX release includes the following key features and enhancements.

  • JAX container images in version 23.08 are based on jaxlib==0.4.14.
  • T5x and Paxml containers are available in version 23.08.
  • Transformer Engine is a library for accelerating Transformer models on NVIDIA GPUs. It includes support for 8-bit floating point (FP8) precision on Hopper GPUs, which provides better training and inference performance with lower memory utilization. Transformer Engine also includes a collection of highly optimized modules for popular Transformer architectures and an automatic mixed precision-like API that can be used seamlessly with your JAX code.
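To confirm that an environment matches the jaxlib==0.4.14 pin listed above, a version check along these lines may help. It is a minimal sketch using the standard library only; the helper names are hypothetical, and it assumes that any local build suffix (such as "+cuda12.cudnn89") should be ignored when comparing.

```python
from importlib import metadata

EXPECTED_JAXLIB = "0.4.14"  # version stated for the 23.08 JAX container


def installed_version(package: str):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def matches_release(version, expected: str = EXPECTED_JAXLIB) -> bool:
    """Compare the public part of a version string with the expected release."""
    if version is None:
        return False
    # Drop a local suffix such as "+cuda12.cudnn89" before comparing.
    return version.split("+")[0] == expected
```

Inside the container, `matches_release(installed_version("jaxlib"))` would be expected to return True. Development builds, such as those in the Paxml and T5x containers, may report a different version string.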

NVIDIA JAX Container Versions

The following table shows what versions of CUDA and JAX are supported in each of the NVIDIA containers for JAX.

Container Version | CUDA Toolkit       | JAX
23.08             | NVIDIA CUDA 12.1.1 | 0.4.14

NVIDIA T5x and Paxml Container Versions

The following table shows which versions of CUDA and JAX are supported in each of the NVIDIA containers for T5x and Paxml.

Framework | Container Version | CUDA Toolkit       | JAX
Paxml     | 23.08             | NVIDIA CUDA 12.1.1 | 0.4.14 (579808d98)
T5x       | 23.08             | NVIDIA CUDA 12.1.1 | 0.4.14 (603eeb190)

The JAX versions in the Paxml and T5x containers are development versions. See /opt/jax-source and /opt/xla-source for the versions used in each container.

Inspecting Source Code in NVIDIA T5x and Paxml Containers

To inspect the Pax source code (Paxml and Praxis) and learn more about what is being run, look inside the nvcr.io/nvidia/jax:23.08-paxml-py3 container. The source locations within the container are:

  • Paxml: /opt/paxml
  • Praxis: /opt/praxis

Similarly, the T5x source code in nvcr.io/nvidia/jax:23.08-t5x-py3 is located at:

  • t5x: /opt/t5x
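The source locations listed above can be checked from inside a running container with a short script. This is a sketch only: the `SOURCE_DIRS` mapping simply restates the paths given above, and `present_sources` is a hypothetical helper name.

```python
from pathlib import Path

# Source directories as listed in these release notes, keyed by image name.
SOURCE_DIRS = {
    "nvcr.io/nvidia/jax:23.08-paxml-py3": {
        "Paxml": "/opt/paxml",
        "Praxis": "/opt/praxis",
    },
    "nvcr.io/nvidia/jax:23.08-t5x-py3": {
        "T5x": "/opt/t5x",
    },
}


def present_sources(paths: dict) -> dict:
    """Map each project name to whether its source directory exists."""
    return {name: Path(location).is_dir() for name, location in paths.items()}
```

Run inside the Paxml container, `present_sources(SOURCE_DIRS["nvcr.io/nvidia/jax:23.08-paxml-py3"])` reports whether each listed directory is present; outside the container, the entries will normally be reported as missing.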

JAX Toolbox Examples

The JAX Toolbox examples focus on achieving the best performance and convergence from NVIDIA Hopper and NVIDIA Ampere architecture tensor cores by using the latest deep learning example networks and model scripts for training. These examples are tested in a nightly CI and against each NGC container release to ensure consistent accuracy and performance over time.

Known Issues

  • Pipeline parallelism is not supported with NVIDIA Transformer Engine enabled in the Paxml container.
  • There is a 15% performance regression on Paxml’s 126M parameter model, trained with BF16 on A100s, when NVIDIA Transformer Engine is disabled. This will be fixed in a future release.
  • The Paxml container (nvcr.io/nvidia/jax:23.08-paxml-py3) does not fully support Hopper yet. Future releases will add Hopper support.
  • There is a known sporadic NCCL crash that happens when using the T5x container on 32 or more nodes. We will fix this in a future release. The issue is tracked here.
  • The JAX, T5x, and Paxml containers are early releases and have some differences from the NGC TensorFlow container that will be addressed in future releases. These differences include the CUDA minor version, support for third-party network devices via HPC-X, and the versions of CTK, cuDNN, NCCL, and other components. Some packages that do not currently ship include:
    • JupyterLab including Jupyter-TensorBoard
    • NVIDIA TensorRT
    • TensorBoard
    • NVIDIA HPC-X with UCX and OpenMPI
    • OpenMPI
    • cuTENSOR
    • GDRCopy
  • The jax.experimental.sparse sparse matrix package currently produces incorrect matrix multiplication results on NVIDIA Ampere architecture and Hopper GPUs. Volta GPUs do not appear to be affected. See detailed issue here.
  • There are known CVEs that affect the Paxml container related to TensorFlow 2.9.x due to pinning TensorFlow to 2.9.x in Paxml and Lingvo. We will fix these in an upcoming release. The known CVEs are:
© Copyright 2024, NVIDIA. Last updated on Apr 29, 2024.