NVIDIA Optimized Frameworks

JAX Release 25.01

The NVIDIA JAX Release 25.01 is made up of two container images available on NGC: JAX and MaxText.

Contents of the JAX container

This container image contains the complete source for the following software:

  • JAX: /opt/jax
  • XLA: /opt/xla
  • Flax: /opt/flax
  • TransformerEngine: /opt/transformer-engine

The MaxText container image is based on the JAX container. Additionally, it includes:

  • MaxText: /opt/maxtext

The JAX runtime package jaxlib is prebuilt and installed in the default Python environment (/usr/local/lib/python3.10/dist-packages/jaxlib) in the container image.
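
As a quick sanity check, the shipped versions and visible accelerators can be queried from Python. This is only a minimal sketch, assuming it is run inside the 25.01 container with GPUs exposed to it:

    # Print the versions shipped in the container and the accelerators JAX sees.
    import jax
    import jaxlib

    print("jax:    ", jax.__version__)
    print("jaxlib: ", jaxlib.__version__)   # 0.4.38 in release 25.01
    print("devices:", jax.devices())        # e.g. [CudaDevice(id=0), ...]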

Versions of packages included in both of these containers:

Driver Requirements

Release 25.01 is based on CUDA 12.8.0, which requires NVIDIA Driver release 570 or later. However, if you are running on a data center GPU (for example, B100, L40, or any other data center GPU), you can use NVIDIA driver release 470.57 (or later R470), 525.85 (or later R525), 535.86 (or later R535), or 550.54 (or later R550) in forward-compatibility mode.

The CUDA driver's compatibility package only supports particular drivers. Thus, users should upgrade from all R418, R440, R450, R460, R510, R520, R530, R545, R555, and R560 drivers, which are not forward-compatible with CUDA 12.8. For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.
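
One rough, illustrative way to see which driver branch the host is running is to query NVML. This sketch assumes the optional nvidia-ml-py (pynvml) package is installed, and it does not by itself establish forward-compatibility eligibility, which also depends on the GPU being a data center part:

    # Report the host NVIDIA driver version via NVML (requires nvidia-ml-py,
    # i.e. `pip install nvidia-ml-py`).
    import pynvml

    pynvml.nvmlInit()
    try:
        driver = pynvml.nvmlSystemGetDriverVersion()
        if isinstance(driver, bytes):  # older bindings return bytes
            driver = driver.decode()
        major = int(driver.split(".")[0])
        if major >= 570:
            print(f"Driver {driver}: meets the native requirement for CUDA 12.8.")
        else:
            print(f"Driver {driver}: older than R570; forward-compatibility mode "
                  "may be needed (data center GPUs on R470/R525/R535/R550 branches).")
    finally:
        pynvml.nvmlShutdown()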

Key Features and Enhancements

This JAX release includes the following key features and enhancements.

  • JAX container images in version 25.01 are based on jaxlib==0.4.38.
  • Added Blackwell GPU Architecture support.
  • Experimental support for the MXFP8 dtype (a basic FP8 sketch follows this list).

  • Experimental support and testing for AWS networking. H100 instances on AWS (P5) have been evaluated. For optimal performance in LLM training and other distributed workloads with high communication costs, NVIDIA recommends the following:
    • AWS: the NCCL plugin supporting AWS EFA is included in the container and will be enabled automatically.
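
The FP8 dtypes that MXFP8 builds on are already exposed through jax.numpy. The following is only a rough sketch of plain per-tensor FP8 casting and an FP8 matmul, not the MXFP8 block-scaling recipe itself; on unsupported hardware XLA may simply fall back to a higher-precision dot:

    # Rough sketch: cast to an FP8 dtype and run a matmul that accumulates in
    # float32. This shows plain per-tensor FP8, not MXFP8 block scaling.
    import jax
    import jax.numpy as jnp

    key_a, key_b = jax.random.split(jax.random.key(0))
    a = jax.random.normal(key_a, (128, 128), dtype=jnp.bfloat16)
    b = jax.random.normal(key_b, (128, 128), dtype=jnp.bfloat16)

    a_fp8 = a.astype(jnp.float8_e4m3fn)  # common forward-pass FP8 format
    b_fp8 = b.astype(jnp.float8_e4m3fn)

    # Ask for a float32 result; on supported GPUs XLA can lower this to an FP8 GEMM.
    out = jnp.dot(a_fp8, b_fp8, preferred_element_type=jnp.float32)
    print(out.dtype, out.shape)  # float32 (128, 128)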

JAX Toolbox

The JAX Toolbox projects focus on achieving the best performance and convergence on the NVIDIA Ampere, Hopper, and Blackwell architecture families, and they provide the latest deep learning models and scripts for training and fine-tuning. These examples are tested in nightly CI as well as against each NGC container release to ensure consistent accuracy and performance over time.

Nightly Containers

In addition to projects, JAX Toolbox includes nightly containers for libraries across the JAX ecosystem.

Container   Type            Image URI
jax         -               ghcr.io/nvidia/jax:jax
t5x         LLM framework   ghcr.io/nvidia/jax:t5x
levanter    LLM framework   ghcr.io/nvidia/jax:levanter
maxtext     LLM framework   ghcr.io/nvidia/jax:maxtext
triton      JAX extension   ghcr.io/nvidia/jax:triton
equinox     layer library   ghcr.io/nvidia/jax:equinox
grok        model           ghcr.io/nvidia/jax:grok

Known Issues

  • The AWS EFA plugin crashes in one known case at 256 GPUs. You can work around the bug by setting the environment variable NCCL_RUNTIME_CONNECT=0 (see the sketch after this list).
  • This version of XLA can hang during compilation in rare cases. You can work around the issue for now by setting XLA_FLAGS=--xla_gpu_shard_autotuning=false.
  • On consumer Blackwell (sm_120) GPUs, the JAX random number generator is non-deterministic.
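
Both workarounds are plain environment variables; they can be exported in the shell or passed to the container runtime. One hedged option is to set them from Python before JAX is imported:

    # Apply the documented workarounds by setting the environment variables
    # before importing jax, so XLA and NCCL read them at initialization.
    import os

    os.environ["XLA_FLAGS"] = "--xla_gpu_shard_autotuning=false"  # rare compilation hang
    os.environ["NCCL_RUNTIME_CONNECT"] = "0"                      # AWS EFA plugin crash

    import jax  # imported after the flags on purpose

    print(jax.devices())
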
© Copyright 2025, NVIDIA. Last updated on Apr 9, 2025.