NVIDIA Optimized Frameworks

JAX Release 24.10

The NVIDIA container image for JAX, release 24.10 is available on NGC.

Contents of the JAX container

This container image contains the complete source for the following software:

  • JAX: /opt/jax
  • XLA: /opt/xla
  • Flax: /opt/flax
  • TransformerEngine: /opt/transformer-engine

Additionally, the maxtext container image tag includes:

  • MaxText: /opt/maxtext

The JAX runtime package jaxlib is prebuilt and installed in the default Python environment (/usr/local/lib/python3.10/dist-packages/jaxlib) in the container image. The container also includes the CUDA 12.6.2 toolkit and its supporting libraries.
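
As a quick sanity check inside the container, you can confirm the bundled jax and jaxlib versions and that the GPUs are visible to XLA with a short Python snippet (illustrative only; it is not part of the image itself):

    import jax
    import jaxlib

    # Report the jax/jaxlib versions shipped in the image and the GPUs visible to XLA.
    print("jax    :", jax.__version__)
    print("jaxlib :", jaxlib.__version__)
    print("devices:", jax.devices())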

Driver Requirements

Release 24.10 is based on CUDA 12.6.2, which requires NVIDIA Driver release 560 or later. However, if you are running on a data center GPU (for example, a T4), you can use NVIDIA driver release 470.57 (or later R470), 525.85 (or later R525), 535.86 (or later R535), or 545.23 (or later R545).

The CUDA driver's compatibility package supports only specific drivers. Users should therefore upgrade from all R418, R440, R450, R460, R510, R520, R530, R545, and R555 drivers, which are not forward-compatible with CUDA 12.6. For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.

Key Features and Enhancements

This JAX release includes the following key features and enhancements.

  • JAX container images in version 24.10 are based on jaxlib==0.4.33.
  • Experimental support and testing for AWS and GCP networking. H100 instances on both AWS (P5) and GCP (A3-Mega) have been evaluated. For optimal performance in LLM training and other distributed workloads with high communication costs, NVIDIA recommends the following (a minimal multi-node initialization sketch follows this list):
    • AWS: Run the script at /usr/local/bin/install-efa.sh or include it in a new Dockerfile to leverage AWS EFA.
    • GCP: Follow the guide to set up a cluster and enable GCP's GPUDirect-TCPXO NCCL plugin.
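
Once the instance networking is configured, a multi-node run still needs to bring up JAX's distributed runtime in each process; the EFA and GPUDirect-TCPXO plugins described above operate underneath NCCL and require no JAX-level code changes. The following is a minimal sketch; the coordinator address, process count, and process ID are placeholder values that would normally be supplied per process by your launcher (for example, SLURM or a Kubernetes job):

    import jax

    # Minimal multi-process initialization sketch (placeholder values shown;
    # a real launcher supplies these per process).
    jax.distributed.initialize(
        coordinator_address="10.0.0.1:1234",  # assumed address of the coordinator node
        num_processes=2,                      # assumed: one process per node
        process_id=0,                         # this process's rank
    )

    print(f"process {jax.process_index()} of {jax.process_count()}, "
          f"local devices: {jax.local_device_count()}")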

JAX Toolbox

The JAX Toolbox projects focus on achieving the best performance and convergence on NVIDIA Hopper and Ampere Tensor Cores and provide the latest deep learning models and scripts for training and fine-tuning. These examples are tested in nightly CI runs as well as against each NGC container release to ensure consistent accuracy and performance over time.

Projects

  • GPT: Decoder-only LLM used in many modern text applications. You can also find the latest performance and convergence on the model card.
  • LLaMA-2: Decoder-only LLM ranging in scale from 7B to 70B parameters. You can find the latest performance and convergence on our model card.

    • This release also introduces LoRA support for fine-tuning LLaMA models. NVIDIA includes an example script that fine-tunes LLaMA-2 7B on BoolQ and achieves significant metric improvements. You can find the convergence and performance results on the above model card.
    • LLaMA-2 recipes are also available in maxtext. You can find instructions and performance results on our model card.
  • MoE: GLaM-style MoE with configs up to the 1.14T-parameter scale. You can find the latest information on our model card.
  • T5: This model was introduced in the Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer paper. This is a popular model for sequence-to-sequence modeling tasks. You can find the latest performance and convergence on our model card.
  • ViT: The Vision Transformer was introduced in the An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale paper. This is a popular computer vision model used for image classification, but it is also a foundational model for GenAI applications such as multimodal language models. The ViT model also showcases NVIDIA DALI data pipelines (a minimal pipeline sketch follows this list). This is particularly helpful for smaller models, where data pipelines tend to be the bottleneck.
  • Imagen: Text-to-image generative diffusion model based on the Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding paper. You can find the latest performance and convergence on our model card.
  • PaliGemma: A vision-language model (VLM) well suited for a variety of tasks that require visually situated text understanding and object localization.
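
As referenced in the ViT entry above, a DALI input pipeline keeps JPEG decoding and augmentation on the GPU so the data pipeline does not starve the accelerator. The sketch below is illustrative rather than the Toolbox's actual configuration; the data directory, batch size, and normalization constants are assumptions:

    from nvidia.dali import pipeline_def, fn, types

    @pipeline_def(batch_size=64, num_threads=4, device_id=0)
    def vit_input_pipeline(image_dir):
        # Read JPEGs from disk and decode them on the GPU ("mixed" = CPU read, GPU decode).
        jpegs, labels = fn.readers.file(file_root=image_dir, random_shuffle=True)
        images = fn.decoders.image(jpegs, device="mixed")
        # Resize and normalize to the 224x224 input commonly used for ViT image classification.
        images = fn.resize(images, resize_x=224, resize_y=224)
        images = fn.crop_mirror_normalize(
            images,
            dtype=types.FLOAT,
            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],  # assumed ImageNet statistics
            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
            output_layout="HWC",
        )
        return images, labels

    # Build the pipeline and pull one batch.
    pipe = vit_input_pipeline("/data/imagenet/train")  # assumed data location
    pipe.build()
    images, labels = pipe.run()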

Nightly Containers

In addition to projects, JAX Toolbox includes nightly containers for libraries across the JAX ecosystem.

Container   Type            Image URI
jax         -               ghcr.io/nvidia/jax:jax
paxml       LLM framework   ghcr.io/nvidia/jax:pax
t5x         LLM framework   ghcr.io/nvidia/jax:t5x
levanter    LLM framework   ghcr.io/nvidia/jax:levanter
maxtext     LLM framework   ghcr.io/nvidia/jax:maxtext
triton      JAX extension   ghcr.io/nvidia/jax:triton
equinox     layer library   ghcr.io/nvidia/jax:equinox
grok        model           ghcr.io/nvidia/jax:grok

Known Issues

  • A JAX eigensolver unit test is failing due to a flaw in the test itself, which uses a hard-coded workspace size. This should not impact real applications, where the workspace size is computed by the cuSolver API.
  • There is a known convergence issue with GPT 126M when Flash Attention is enabled.
  • There is a known divergence issue with Flash Attention in PAXML when it is enabled natively via XLA without Transformer Engine. If training models with Transformer Engine disabled or absent, NVIDIA recommends running without Flash Attention by passing --xla_gpu_enable_cudnn_fmha=false (see the example after this list).
  • LLaMA fine-tuning via Maxtext does not currently support FP8.
  • Transformer Engine is currently not supported with GLaM models. Future releases will include TE support with GLaM.
  • There is a known issue in the Maxtext container where failures may be observed if --xla_gpu_enable_custom_fusions and --xla_gpu_enable_address_computation_fusion flags are enabled. It is advised to disable these two flags by adding --xla_gpu_enable_custom_fusions=false and --xla_gpu_enable_address_computation_fusion=false to the XLA_FLAGS environment variable.
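
For the Flash Attention and Maxtext workarounds above, one way to apply the suggested flags from a Python entry point is sketched below. XLA reads XLA_FLAGS when the GPU backend is first initialized, so the environment variable must be set before jax is imported; in practice these flags are usually exported in the shell or job script instead:

    import os

    # Append the workaround flags from the known issues above to any existing XLA_FLAGS.
    os.environ["XLA_FLAGS"] = " ".join([
        os.environ.get("XLA_FLAGS", ""),
        "--xla_gpu_enable_cudnn_fmha=false",                  # Flash Attention workaround (PAXML without TE)
        "--xla_gpu_enable_custom_fusions=false",              # Maxtext container workaround
        "--xla_gpu_enable_address_computation_fusion=false",  # Maxtext container workaround
    ]).strip()

    import jax  # the GPU backend picks up XLA_FLAGS on first use
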
© Copyright 2024, NVIDIA. Last updated on Dec 2, 2024.