
JAX Release 23.10

The NVIDIA container image for JAX, release 23.10 is available on NGC.

Contents of the JAX container

This container image contains the complete source of the version of JAX in /opt/jax-source; it is prebuilt and installed within the container image.
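As a quick sanity check, the snippet below confirms the prebuilt JAX installation from inside the container. This is a minimal sketch using only standard JAX calls; the expected version number comes from the container version table later in these notes.

    # Print the installed JAX version and the accelerators it can see.
    import jax

    print(jax.__version__)   # expected to report 0.4.17 in the 23.10 release
    print(jax.devices())     # should list the NVIDIA GPUs visible to the container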

Driver Requirements

Release 23.10 is based on CUDA 12.2.0, which requires NVIDIA Driver release 535 or later. However, if you are running on a data center GPU (for example, a T4), you can use NVIDIA driver release 450.51 (or later R450), 470.57 (or later R470), 510.47 (or later R510), 525.85 (or later R525), or 535.86 (or later R535). The CUDA driver's compatibility package only supports particular drivers; users should therefore upgrade from all R418, R440, R460, and R520 drivers, which are not forward-compatible with CUDA 12.2. For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.

GPU Requirements

Release 23.10 supports CUDA compute capability 6.0 and later. This corresponds to GPUs in the NVIDIA Pascal, NVIDIA Volta™, NVIDIA Turing™, NVIDIA Ampere, and NVIDIA Hopper™ architecture families. For a list of GPUs to which this compute capability corresponds, see CUDA GPUs. For additional support details, see the Deep Learning Frameworks Support Matrix.
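To check which compute capability your GPUs report, you can query it from JAX. This is a minimal sketch, assuming the compute_capability attribute that jaxlib's CUDA device objects expose; verify against the jaxlib version shipped in the container.

    # Print the kind and compute capability of each locally visible GPU.
    import jax

    for d in jax.local_devices():
        # compute_capability is an assumption here; it is exposed by
        # jaxlib's CUDA device objects in recent releases.
        print(d.device_kind, d.compute_capability)  # e.g. "NVIDIA H100 ..." "9.0"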

Key Features and Enhancements

This JAX release includes the following key features and enhancements.

Announcements

  • Transformer Engine is a library for accelerating Transformer models on NVIDIA GPUs. It includes support for 8-bit floating point (FP8) precision on Hopper GPUs, which provides better training and inference performance with lower memory utilization. Transformer Engine also includes a collection of highly optimized modules for popular Transformer architectures and an automatic mixed precision-like API that can be used seamlessly with your JAX code, as illustrated in the sketch below.
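For illustration, here is a minimal sketch of running a Transformer Engine Flax layer under FP8 autocast in JAX. The names used below (fp8_autocast, flax.DenseGeneral, recipe.DelayedScaling) reflect Transformer Engine's JAX bindings rather than anything stated in this release note, so verify them against the library version shipped in the container.

    # Hedged sketch: run one Transformer Engine Flax layer with FP8 enabled.
    import jax
    import jax.numpy as jnp
    import transformer_engine.jax as te
    import transformer_engine.jax.flax as te_flax
    from transformer_engine.common import recipe

    # Delayed-scaling FP8 recipe; HYBRID uses E4M3 forward and E5M2 backward.
    fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)

    x = jnp.zeros((32, 1024), dtype=jnp.bfloat16)

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        layer = te_flax.DenseGeneral(features=1024)
        params = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(params, x)  # the matmul runs in FP8 on Hopper GPUs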

NVIDIA JAX Container Versions

The following table shows what versions of CUDA and JAX are supported in each of the NVIDIA containers for JAX.

Container Version | CUDA Toolkit       | JAX
------------------+--------------------+-------
23.10             | NVIDIA CUDA 12.2.0 | 0.4.17
23.08             | NVIDIA CUDA 12.1.1 | 0.4.14

NVIDIA T5x and Paxml Container Versions

The following table shows the versions of CUDA and JAX supported in each of the NVIDIA T5x and Paxml containers.

Container | Version | CUDA Toolkit       | JAX
----------+---------+--------------------+-------------------
Paxml     | 23.10   | NVIDIA CUDA 12.2.2 | 0.4.17.dev20231010
Paxml     | 23.08   | NVIDIA CUDA 12.1.1 | 0.4.14 (579808d98)
T5x       | 23.10   | NVIDIA CUDA 12.2.2 | 0.4.17.dev20231010
T5x       | 23.08   | NVIDIA CUDA 12.1.1 | 0.4.14 (603eeb190)


The JAX versions in the Paxml and T5x containers are development versions. See /opt/jax-source and /opt/xla-source for the exact versions used in each container.

Inspecting Source Code in NVIDIA T5x and Paxml Containers

If you would like to inspect the Pax source code (Paxml and Praxis) to learn more about what is being run, you can do so by inspecting the source within the nvcr.io/nvidia/jax:23.10-paxml-py3 container. The locations within the container are:

  • Paxml: /opt/paxml
  • Praxis: /opt/praxis

Similarly, the T5x source code is available in nvcr.io/nvidia/jax:23.10-t5x-py3:

  • T5x: /opt/t5x

JAX Toolbox Examples

The JAX Toolbox examples focus on achieving the best performance and convergence from NVIDIA Hopper and NVIDIA Ampere architecture Tensor Cores by using the latest deep learning example networks and model scripts for training. These examples are tested in a nightly CI, as well as against each NGC container release, to ensure consistent accuracy and performance over time.

  • GPT model: This is a decoder-only LLM architecture used in many modern text applications. The model script is available on GitHub, and you can find the latest performance and convergence results on its model card on GitHub.

  • T5 model: This model was introduced in the paper Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. It is a popular model for sequence-to-sequence modeling tasks involving text. The model script is available on GitHub, and you can find the latest performance and convergence results on its model card on GitHub.
  • ViT model: The Vision Transformer was introduced in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. It is a popular computer vision model used for image classification and is also a foundational model for GenAI applications such as multimodal language models. The ViT model also showcases NVIDIA DALI data pipelines, which are particularly helpful for smaller models, where the data pipeline tends to be the bottleneck. The model script is available on GitHub, and you can find the latest performance and convergence results on its model card on GitHub.

Known Issues

© Copyright 2023, NVIDIA. Last updated on Dec 20, 2023.