JAX Release 23.08

The NVIDIA container image for JAX, release 23.08, is available on NGC.

Contents of the JAX container

This container image contains the complete source of its JAX version in /opt/jax-source. JAX is prebuilt and installed within the container image.

The container also includes the following:

Driver Requirements

Release 23.08 is based on CUDA 12.1.1, which requires NVIDIA Driver release 525.60.13 or later. However, if you are running on a data center GPU (for example, T4), you can use NVIDIA driver release 450.51 (or later R450), 470.57 (or later R470), 510.47 (or later R510), or 525.85 (or later R525). The CUDA driver's compatibility package supports only these particular drivers. Therefore, users should upgrade from all R418, R440, R460, and R520 drivers, which are not forward-compatible with CUDA 12.1. For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.
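The driver rules above can be sketched as a small check. This is a minimal illustration, not an NVIDIA tool: the helper names are hypothetical, the input is assumed to be a dotted version string such as the one reported by `nvidia-smi --query-gpu=driver_version --format=csv,noheader`, and only the branches listed in this section are encoded.

```python
# Hypothetical helper (not an NVIDIA API) encoding the 23.08 driver rules.

# Minimum driver for CUDA 12.1.1 without the compatibility package.
NATIVE_MINIMUM = (525, 60, 13)

# Data center GPUs only: minimum version within each older supported branch.
# The R525 entry mirrors the text; it is already covered by NATIVE_MINIMUM.
DATACENTER_BRANCH_MINIMUMS = {
    450: (450, 51),
    470: (470, 57),
    510: (510, 47),
    525: (525, 85),
}


def parse_version(text):
    """Turn a string like '525.60.13' into a tuple of ints (525, 60, 13)."""
    return tuple(int(part) for part in text.strip().split("."))


def driver_supported(version_text, datacenter_gpu=False):
    """Return True if the driver meets the release 23.08 requirements."""
    version = parse_version(version_text)
    if version >= NATIVE_MINIMUM:
        return True
    if datacenter_gpu:
        # Older branches are accepted only on data center GPUs.
        minimum = DATACENTER_BRANCH_MINIMUMS.get(version[0])
        return minimum is not None and version >= minimum
    return False
```

Versions are compared as integer tuples rather than strings, so "470.57.02" is correctly treated as newer than "470.57" and older than "525.60.13". Branches such as R460 or R520 are absent from the table and therefore always fail on older drivers.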

GPU Requirements

Release 23.08 supports CUDA compute capability 6.0 and later. This corresponds to GPUs in the NVIDIA Pascal, NVIDIA Volta™, NVIDIA Turing™, NVIDIA Ampere architecture, and NVIDIA Hopper™ architecture families. For a list of GPUs to which this compute capability corresponds, see CUDA GPUs. For additional support details, see Deep Learning Frameworks Support Matrix.
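The compute capability requirement reduces to a version comparison. The sketch below is hypothetical and assumes a "major.minor" string; on recent drivers, `nvidia-smi --query-gpu=compute_cap --format=csv,noheader` reports such a value (an assumption worth confirming with `nvidia-smi --help-query-gpu` on your system).

```python
# Hypothetical helper: checks a "major.minor" compute capability string
# against the 6.0 minimum stated for release 23.08.

MIN_COMPUTE_CAPABILITY = (6, 0)


def compute_capability_supported(capability):
    """Return True if a capability such as '8.0' is 6.0 or later."""
    major, minor = (int(part) for part in capability.strip().split("."))
    return (major, minor) >= MIN_COMPUTE_CAPABILITY
```

Parsing into integers avoids the string-comparison trap where "10.0" would sort before "6.0".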

Key Features and Enhancements

This JAX release includes the following key features and enhancements.

  • JAX container images in version 23.08 are based on jaxlib==0.4.14.
  • T5x and Paxml containers are available in version 23.08.

Announcements

  • Transformer Engine is a library for accelerating Transformer models on NVIDIA GPUs. It includes support for 8-bit floating point (FP8) precision on Hopper GPUs, which provides better training and inference performance with lower memory utilization. Transformer Engine also includes a collection of highly optimized modules for popular Transformer architectures and an API similar to automatic mixed precision that integrates with your JAX code.

NVIDIA JAX Container Versions

The following table shows the CUDA and JAX versions included in each NVIDIA container for JAX.

Container Version | CUDA Toolkit       | JAX
23.08             | NVIDIA CUDA 12.1.1 | 0.4.14
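To confirm that an environment matches the table, you can read the installed package versions with the standard library. This is a minimal sketch: the helper names are hypothetical, and it assumes JAX and jaxlib are installed as the pip distributions `jax` and `jaxlib`.

```python
# Sketch: compare installed jax/jaxlib versions with the 0.4.14 listed for
# container 23.08. Uses only the standard library.
from importlib.metadata import PackageNotFoundError, version

EXPECTED_VERSION = "0.4.14"  # JAX/jaxlib version listed for container 23.08


def installed_version(package):
    """Return the installed distribution version, or None if not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def compare(package, found, expected=EXPECTED_VERSION):
    """Describe how a found version relates to the expected one."""
    if found is None:
        return f"{package}: not installed"
    relation = "matches" if found == expected else "differs from"
    return f"{package} {found} {relation} expected {expected}"


if __name__ == "__main__":
    for package in ("jax", "jaxlib"):
        print(compare(package, installed_version(package)))
```

The comparison is an exact string match, so development builds (such as those in the Paxml and T5x containers) will be reported as differing even when they are based on 0.4.14.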

NVIDIA T5x and Paxml Container Versions

The following table shows the CUDA and JAX versions included in the NVIDIA T5x and Paxml containers.

Container | Version | CUDA Toolkit       | JAX
Paxml     | 23.08   | NVIDIA CUDA 12.1.1 | 0.4.14 (579808d98)
T5x       | 23.08   | NVIDIA CUDA 12.1.1 | 0.4.14 (603eeb190)

The JAX versions in the Paxml and T5x containers are development versions. See /opt/jax-source and /opt/xla-source for the exact source versions used in each container.
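One way to see which source revision a container was built from is to ask git for the checked-out commit. This sketch assumes (not confirmed by these notes) that /opt/jax-source and /opt/xla-source are git checkouts, that git is on the PATH, and that the parenthesized values in the table above are short JAX commit hashes; the helper names are hypothetical.

```python
# Sketch: read the HEAD commit of a source checkout and compare it with the
# short hash listed for a container. Assumes the directories are git repos.
import subprocess
from pathlib import Path

# Short hashes from the JAX column of the 23.08 Paxml/T5x table.
LISTED_JAX_COMMITS = {"paxml": "579808d98", "t5x": "603eeb190"}


def source_commit(path):
    """Return the full HEAD commit hash of a git checkout, or None."""
    if not Path(path).is_dir():
        return None
    try:
        result = subprocess.run(
            ["git", "-C", str(path), "rev-parse", "HEAD"],
            capture_output=True, text=True, check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # git missing, or the directory is not a git repository.
        return None
    return result.stdout.strip()


def matches_listed(container, commit):
    """True if a full commit hash starts with the short hash in the table."""
    return commit is not None and commit.startswith(LISTED_JAX_COMMITS[container])


if __name__ == "__main__":
    for path in ("/opt/jax-source", "/opt/xla-source"):
        print(path, source_commit(path) or "unavailable")
```

Returning None instead of raising keeps the check usable outside the container, where the directories do not exist.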

Inspecting Source Code in NVIDIA T5x and Paxml Containers

To learn more about what is being run, you can inspect the Pax source code (Paxml and Praxis) within the nvcr.io/nvidia/jax:23.08-paxml-py3 container. The source locations within the container are:

  • Paxml: /opt/paxml
  • Praxis: /opt/praxis

Similarly, the T5x source code is located in the nvcr.io/nvidia/jax:23.08-t5x-py3 container:

  • T5x: /opt/t5x
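The source locations listed above can be captured in a small lookup that reports which directories are present. This is a hypothetical sketch; the mapping only restates the paths given in these notes, and the `root` parameter is an illustrative addition that lets the check run against a copied or mounted filesystem.

```python
# Sketch: documented source directories, keyed by container image.
from pathlib import Path

SOURCE_DIRS = {
    "nvcr.io/nvidia/jax:23.08-paxml-py3": {
        "paxml": "/opt/paxml",
        "praxis": "/opt/praxis",
    },
    "nvcr.io/nvidia/jax:23.08-t5x-py3": {"t5x": "/opt/t5x"},
}


def available_sources(image, root="/"):
    """Return the documented source dirs for `image` that exist under `root`."""
    found = {}
    for name, path in SOURCE_DIRS[image].items():
        candidate = Path(root) / path.lstrip("/")
        if candidate.is_dir():
            found[name] = str(candidate)
    return found
```

Inside the Paxml container, `available_sources("nvcr.io/nvidia/jax:23.08-paxml-py3")` should list both Paxml and Praxis if the layout matches these notes.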

JAX Toolbox Examples

The JAX Toolbox examples focus on achieving the best performance and convergence on NVIDIA Hopper and NVIDIA Ampere architecture Tensor Cores by using the latest deep learning example networks and model scripts for training.

These examples are tested in nightly CI and against each NGC container release to ensure consistent accuracy and performance over time.

Known Issues

© Copyright 2025, NVIDIA. Last updated on May 30, 2025.