JAX Release 25.01
The NVIDIA JAX Release 25.01 is made up of two container images available on NGC: JAX and MaxText.
Contents of the JAX container
This container image contains the complete source for the following software:
- JAX: /opt/jax
- XLA: /opt/xla
- Flax: /opt/flax
- TransformerEngine: /opt/transformer-engine
The MaxText container image is based on the JAX container. Additionally, it includes:
- MaxText: /opt/maxtext
The JAX runtime package jaxlib is prebuilt and installed in the default Python environment (/usr/local/lib/python3.12/dist-packages/jaxlib) in the container image.
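As a quick sanity check from inside the container, the minimal sketch below imports the prebuilt packages and prints their versions. Module names follow the upstream projects, and the __version__ attributes are assumptions based on recent releases, not something these release notes specify.

```python
# Minimal sketch: verify the prebuilt packages inside the container.
# Module names follow the upstream projects; the __version__
# attributes are assumptions based on recent releases.
import jax
import jaxlib
import flax
import transformer_engine

print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("flax:", flax.__version__)
print("transformer_engine:", transformer_engine.__version__)
print("devices:", jax.devices())
```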
Versions of packages included in both of these containers:
- JAX 0.4.38 (includes release-specific patches and MXFP8 support)
- Flax 0.10.2
- maxtext.git@4651cb3c73de (only in the MaxText container)
- Ubuntu 24.04 including Python 3.12
- NVIDIA CUDA® 12.8.0.038
- NVIDIA cuBLAS 12.8.3.14
- NVIDIA cuDNN 9.7.0.66
- NVIDIA NCCL 2.25.1
- rdma-core 50.0
- NVIDIA HPC-X 2.21
- OpenMPI 4.1.7
- GDRCopy 2.3.1-1
- Nsight Compute 2025.1.0.14
- Nsight Systems 2024.2.6.2.225
- TransformerEngine 1.14
Driver Requirements
Release 25.01 is based on CUDA 12.8.0 which requires NVIDIA Driver release 570 or later. However, if you are running on a data center GPU (for example, B100, L40, or any other data center GPU), you can use NVIDIA driver release 470.57 (or later R470), 525.85 (or later R525), 535.86 (or later R535), or 550.54 (or later R550) in forward-compatibility mode.
The CUDA driver's compatibility package only supports particular drivers. Thus, users should upgrade from all R418, R440, R450, R460, R510, R520, R530, R545, R555, and R560 drivers, which are not forward-compatible with CUDA 12.8. For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.
Key Features and Enhancements
This JAX release includes the following key features and enhancements.
- JAX container images in version 25.01 are based on jaxlib==0.4.38.
- Added Blackwell GPU architecture support.
- Experimental support for the MXFP8 dtype.
- Experimental support and testing for AWS networking. H100 instances on AWS (P5) have been evaluated. For optimal performance in LLM training and other distributed workloads with high communication costs, NVIDIA recommends the following:
  - AWS: the NCCL plugin supporting AWS EFA is included in the container and is enabled automatically. A minimal multi-process initialization sketch follows this list.
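For distributed workloads such as multi-node LLM training, the sketch below initializes JAX's multi-process runtime. This is an illustration, not part of the official release notes: the coordinator address, process count, and rank are placeholders that a launcher (SLURM, mpirun, etc.) would normally supply, and in many managed environments jax.distributed.initialize() can infer them with no arguments at all.

```python
# Minimal sketch: multi-process JAX initialization for distributed
# training (e.g. across AWS P5 instances). All three arguments are
# placeholders; your launcher normally provides them.
import jax

jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # placeholder host:port of process 0
    num_processes=2,                      # placeholder total process count
    process_id=0,                         # placeholder rank of this process
)

# Cross-process collectives go through NCCL, which picks up the
# bundled EFA plugin automatically on AWS.
print(jax.process_index(), jax.local_devices())
```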
JAX Toolbox
The JAX Toolbox projects focus on achieving the best performance and convergence on NVIDIA Ampere, Hopper, and Blackwell architecture families and provide the latest deep learning models and scripts for training and fine-tuning. These examples are tested against a nightly CI as well as each NGC container release to ensure consistent accuracy and performance over time.
Nightly Containers
In addition to projects, JAX Toolbox includes nightly containers for libraries across the JAX ecosystem.
| Container | Type          | Image URI                   |
|-----------|---------------|-----------------------------|
| jax       | -             | ghcr.io/nvidia/jax:jax      |
| t5x       | LLM framework | ghcr.io/nvidia/jax:t5x      |
| levanter  | LLM framework | ghcr.io/nvidia/jax:levanter |
| maxtext   | LLM framework | ghcr.io/nvidia/jax:maxtext  |
| triton    | JAX extension | ghcr.io/nvidia/jax:triton   |
| equinox   | layer library | ghcr.io/nvidia/jax:equinox  |
| grok      | model         | ghcr.io/nvidia/jax:grok     |
Known Issues
- The AWS EFA plugin crashes in one known case with 256 GPUs. Set NCCL_RUNTIME_CONNECT=0 to work around the bug.
- This version of XLA can hang during compilation in rare cases. Set XLA_FLAGS=--xla_gpu_shard_autotuning=false to work around the issue for now.
- On consumer Blackwell (sm_120) GPUs, the JAX random number generator is non-deterministic.

Both environment-variable workarounds are shown in the sketch below.
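A minimal sketch of applying both workarounds from Python. The variables must be set before JAX (and, through it, XLA and NCCL) is first initialized; setting them in the shell that launches the program works equally well.

```python
# Minimal sketch: apply both workarounds before JAX initializes.
import os

os.environ["NCCL_RUNTIME_CONNECT"] = "0"                      # AWS EFA crash workaround
os.environ["XLA_FLAGS"] = "--xla_gpu_shard_autotuning=false"  # XLA compile-hang workaround

import jax  # import only after the variables are set

print(jax.devices())
```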