JAX Release 24.04

The NVIDIA container image for JAX, release 24.04, is available on NGC.

Contents of the JAX container

This container image contains the complete source for the following software:

  • JAX: /opt/jax
  • XLA: /opt/xla
  • Flax: /opt/flax
  • TransformerEngine: /opt/transformer-engine

The JAX runtime package jaxlib is prebuilt and installed in the default Python environment (/usr/local/lib/python3.10/dist-packages/jaxlib) in the container image. The container also includes the following:

Driver Requirements

Release 24.04 is based on CUDA 12.4.1, which requires NVIDIA Driver release 550 or later. However, if you are running on a data center GPU (for example, T4 or any other data center GPU), you can use NVIDIA driver release 470.57 (or later R470), 525.85 (or later R525), 535.86 (or later R535), or 545.23 (or later R545).

The CUDA driver's compatibility package supports only particular drivers. Users should therefore upgrade from all R418, R440, R450, R460, R510, and R520 drivers, which are not forward-compatible with CUDA 12.4. For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.
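The driver rule above can be expressed as a small check. This is a minimal sketch, not an NVIDIA tool: `parse_driver_version` and `driver_supported` are hypothetical helper names, the data center branch minimums are copied from the text, and the general minimum branch is passed in by the caller instead of being hard-coded.

```python
# Minimal sketch of the driver-requirement rule for this release.
# Hypothetical helpers; the data center branch minimums are copied from the
# text above, and the general minimum branch is supplied by the caller.

# Older driver branches accepted on data center GPUs: branch -> minimum version.
DATACENTER_MINIMUMS = {
    470: (470, 57),
    525: (525, 85),
    535: (535, 86),
    545: (545, 23),
}


def parse_driver_version(version: str) -> tuple[int, ...]:
    """Convert a driver string such as '535.104.05' into (535, 104, 5)."""
    return tuple(int(part) for part in version.split("."))


def driver_supported(version: str, general_min_branch: int, data_center: bool) -> bool:
    """Return True if `version` satisfies the stated driver requirement."""
    parsed = parse_driver_version(version)
    branch = parsed[0]
    # Any driver on the general minimum branch or a newer branch is accepted.
    if branch >= general_min_branch:
        return True
    # Older branches count only on data center GPUs, and only at or above the
    # point release listed for that branch (tuples compare lexicographically).
    if data_center and branch in DATACENTER_MINIMUMS:
        return parsed >= DATACENTER_MINIMUMS[branch]
    return False


if __name__ == "__main__":
    for driver in ("550.54.15", "535.104.05", "535.54.03", "520.61.05"):
        print(driver, driver_supported(driver, general_min_branch=550, data_center=True))
```

For example, with the general minimum set to R550 (the minimum driver branch for CUDA 12.4), `driver_supported("535.104.05", 550, data_center=True)` returns True, while the same call with `data_center=False` returns False.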

Key Features and Enhancements

This JAX release includes the following key features and enhancements.

  • JAX container images in version 24.04 are based on jaxlib==0.4.26.
  • PAXML and Maxtext containers are available in version 24.04.
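To confirm which jaxlib build a container provides, you can query the installed package metadata from Python. This is a minimal sketch using only the standard library; `installed_version` is a hypothetical helper name, and because source builds may carry a version suffix, a mismatch is reported rather than treated as an error.

```python
from importlib import metadata
from typing import Optional

# jaxlib version stated for the 24.04 container images.
EXPECTED_JAXLIB = "0.4.26"


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


found = installed_version("jaxlib")
if found is None:
    print("jaxlib is not installed in this environment")
elif found != EXPECTED_JAXLIB:
    print(f"jaxlib {found} differs from the expected {EXPECTED_JAXLIB}")
else:
    print(f"jaxlib {found} matches the 24.04 release")
```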

JAX Toolbox

The JAX Toolbox projects focus on achieving the best performance and convergence on NVIDIA Hopper and Ampere tensor cores and provide the latest deep learning models and scripts for training and fine-tuning. These examples are tested against a nightly CI as well as each NGC container release to ensure consistent accuracy and performance over time.

Projects

  • GPT: Decoder-only LLM used in many modern text applications. You can also find the latest performance and convergence results on the model card.
  • LLaMA-2: Decoder-only LLM ranging in scale from 7B to 70B parameters. You can find the latest performance and convergence results on our model card.

    • This release also introduces LoRA support for fine-tuning LLaMA models. We include an example script that fine-tunes LLaMA-2 7B on BoolQ and achieves significant metric improvements. You can find the convergence and performance results on the LLaMA-2 model card.
    • LLaMA-2 recipes are also available in Maxtext. You can find instructions and performance results on our model card.
  • MoE: GLaM-style MoE with configurations up to the 1.14T-parameter scale. You can find the latest information on our model card.
  • T5: This model was introduced in the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer" and is popular for sequence-to-sequence modeling tasks. You can find the latest performance and convergence results on our model card.
  • ViT: The Vision Transformer was introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale". It is a popular computer vision model for image classification and also a foundational model for GenAI applications such as multimodal language models. The ViT model also showcases NVIDIA DALI data pipelines, which are particularly helpful for smaller models, where the data pipeline tends to be the bottleneck.
  • Imagen: A text-to-image generative diffusion model based on the paper "Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding". You can find the latest performance and convergence results on our model card.

Nightly Containers

In addition to projects, JAX Toolbox includes nightly containers for libraries across the JAX ecosystem.

Container   Type             Image URI
jax         -                ghcr.io/nvidia/jax:jax
paxml       LLM framework    ghcr.io/nvidia/jax:pax
t5x         LLM framework    ghcr.io/nvidia/jax:t5x
levanter    LLM framework    ghcr.io/nvidia/jax:levanter
maxtext     LLM framework    ghcr.io/nvidia/jax:maxtext
triton      JAX extension    ghcr.io/nvidia/jax:triton
equinox     layer library    ghcr.io/nvidia/jax:equinox
grok        model            ghcr.io/nvidia/jax:grok
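The nightly images are regular container images, so they can be started with `docker run`. The sketch below only builds the command line. It assumes Docker with the NVIDIA Container Toolkit installed (needed for `--gpus all`); `NIGHTLY_IMAGES` and `docker_run_command` are hypothetical names, and the URIs are copied from the table above.

```python
import shlex

# Nightly image URIs, copied from the table above.
NIGHTLY_IMAGES = {
    "jax": "ghcr.io/nvidia/jax:jax",
    "paxml": "ghcr.io/nvidia/jax:pax",
    "t5x": "ghcr.io/nvidia/jax:t5x",
    "levanter": "ghcr.io/nvidia/jax:levanter",
    "maxtext": "ghcr.io/nvidia/jax:maxtext",
    "triton": "ghcr.io/nvidia/jax:triton",
    "equinox": "ghcr.io/nvidia/jax:equinox",
    "grok": "ghcr.io/nvidia/jax:grok",
}


def docker_run_command(container: str) -> list[str]:
    """Return an interactive `docker run` command for a nightly container."""
    image = NIGHTLY_IMAGES[container]  # raises KeyError for unknown names
    # --gpus all exposes every GPU; -it attaches an interactive terminal;
    # --rm deletes the container when it exits.
    return ["docker", "run", "--gpus", "all", "-it", "--rm", image]


print(shlex.join(docker_run_command("maxtext")))
```

Returning an argument list rather than a single string lets the same result be passed directly to `subprocess.run` without shell quoting concerns.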

NVIDIA T5x and Paxml Container Versions

The following table shows what versions of CUDA and JAX are supported in each of the NVIDIA containers for JAX.

The JAX versions in the Paxml and T5x containers are development versions. See /opt/jax-source and /opt/xla-source for the versions used in each container.
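One way to record exactly which JAX and XLA sources a container was built from is to read the checked-out commit. This sketch assumes that /opt/jax-source and /opt/xla-source are git checkouts and that `git` is on the PATH; neither is stated above, and `source_commit` is a hypothetical helper that returns None when either assumption fails.

```python
import subprocess
from pathlib import Path
from typing import Optional


def source_commit(path: str) -> Optional[str]:
    """Return the git commit checked out at `path`, or None if unavailable."""
    repo = Path(path)
    if not repo.is_dir():
        return None
    try:
        result = subprocess.run(
            ["git", "-C", str(repo), "rev-parse", "HEAD"],
            capture_output=True, text=True, check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # git is missing, or the directory is not a git checkout.
        return None
    return result.stdout.strip()


for src in ("/opt/jax-source", "/opt/xla-source"):
    print(src, source_commit(src))
```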

Known Issues

  • A JAX eigensolver unit test fails because the test itself uses a hard-coded workspace size. This should not affect real applications, where the workspace size is computed by the cuSOLVER API.
  • There is a known convergence issue with GPT 126M when Flash Attention is enabled.
  • There is a known divergence issue with Flash Attention in PAXML when it is enabled natively through XLA without TransformerEngine. If you train models with TransformerEngine disabled or absent, we recommend running without Flash Attention by passing --xla_gpu_enable_cudnn_fmha=false (for example, in the XLA_FLAGS environment variable).
  • LLaMA fine-tuning via PAXML does not currently support FP8.
  • MoE does not currently support TransformerEngine.
  • There is a known CVE affecting the amd64 Paxml container, related to TensorFlow 2.9.x, because Paxml pins TensorFlow to 2.9.x. We will fix this in an upcoming release.
  • There is a known issue in the Maxtext container where failures may occur when the --xla_gpu_enable_custom_fusions and --xla_gpu_enable_address_computation_fusion flags are enabled. We advise disabling both flags by adding --xla_gpu_enable_custom_fusions=false and --xla_gpu_enable_address_computation_fusion=false to the XLA_FLAGS environment variable.
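The Flash Attention and Maxtext workarounds in the Known Issues list both set flags through the XLA_FLAGS environment variable. A minimal Python sketch follows. `append_xla_flags` is a hypothetical helper that replaces any earlier value of the same flag instead of duplicating it, and it should run before JAX is imported, because XLA reads the variable when its backend starts.

```python
import os


def append_xla_flags(*flags: str) -> str:
    """Add flags to XLA_FLAGS, replacing earlier values of the same flag."""
    current = os.environ.get("XLA_FLAGS", "").split()
    for flag in flags:
        name = flag.split("=", 1)[0]
        # Drop any existing setting of this flag so the new value is unambiguous.
        current = [f for f in current if f.split("=", 1)[0] != name]
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]


# PAXML without TransformerEngine: turn off cuDNN fused attention (Flash Attention).
append_xla_flags("--xla_gpu_enable_cudnn_fmha=false")

# Maxtext container: turn off the two fusion features named in the known issue.
append_xla_flags(
    "--xla_gpu_enable_custom_fusions=false",
    "--xla_gpu_enable_address_computation_fusion=false",
)

# Import JAX only after XLA_FLAGS is set.
# import jax
```

In practice, apply only the workaround that matches the container you are running; the two calls are shown together here for brevity.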
© Copyright 2024, NVIDIA. Last updated on May 29, 2024.