JAX Release 25.08
The NVIDIA JAX Release 25.08 is made up of two container images available on NGC: JAX and MaxText.
Contents of the JAX container
This container image contains the complete source for the following software:
- JAX: /opt/jax
- XLA: /opt/xla
- Flax: /opt/flax
- TransformerEngine: /opt/transformer-engine
The MaxText container image is based on the JAX container. Additionally, it includes:
- MaxText: /opt/maxtext
The JAX runtime package jaxlib is prebuilt and installed in the default Python environment (/usr/local/lib/python3.12/dist-packages/jaxlib) in the container image.
Versions of packages included in both of these containers:
- JAX 0.6.2
- Flax 0.10.5
- MaxText maxtext.git@0b1fde4513d9 (only in the MaxText container)
- Ubuntu 24.04 including Python 3.12
- NVIDIA CUDA® 13.0.0
- NVSHMEM 3.3.20
- NVIDIA cuBLAS 13.0.0.19
- NVIDIA cuDNN 9.12.0.46
- NVIDIA NCCL 2.27.7
- NVIDIA DALI 1.51.2
- rdma-core 50.0
- NVIDIA HPC-X 2.24
- GDRCopy 2.4.1
- Nsight Compute 2025.3.0.19
- Nsight Systems 2025.4.1.136
- TransformerEngine 2.5
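To confirm that a running container matches the versions listed above, one option is to compare them with the installed package metadata. The sketch below assumes the PyPI distribution names `jax` and `flax`; `EXPECTED_VERSIONS`, `release_segment`, and `find_mismatches` are hypothetical names. Because a source-built package may report a dev or local suffix, it compares only the leading numeric release segment.

```python
import re
from importlib import metadata

# Versions listed in these release notes, keyed by (assumed) PyPI distribution name.
EXPECTED_VERSIONS = {"jax": "0.6.2", "flax": "0.10.5"}


def release_segment(version):
    """Keep only the leading numeric release, e.g. '0.6.2.dev20250801+nv' -> '0.6.2'."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    return match.group(0) if match else version


def find_mismatches(expected, lookup=metadata.version):
    """Return {name: (expected, installed or None)} for every package that differs."""
    mismatches = {}
    for name, want in expected.items():
        try:
            have = lookup(name)
        except metadata.PackageNotFoundError:
            have = None  # package is not installed in this environment
        if have is None or release_segment(have) != want:
            mismatches[name] = (want, have)
    return mismatches


print(find_mismatches(EXPECTED_VERSIONS) or "all listed versions match")
```

The `lookup` parameter is only there so the comparison can be exercised without the real packages installed.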
Driver Requirements
Release 25.08 is based on CUDA 13.0.0 (Toolkit), which requires CUDA Driver release 580.65. Refer to the latest Drivers and CTK support table for additional information.
For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.
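A host can be checked against the 580.65 minimum before launching the container. This is a minimal sketch that assumes `nvidia-smi` is on the host's `PATH` and uses its `--query-gpu=driver_version --format=csv,noheader` query. The helper names are hypothetical, and the check deliberately ignores the forward-compatibility path covered in the CUDA Compatibility documentation.

```python
import subprocess

MIN_DRIVER = "580.65"  # minimum driver for CUDA 13.0.0, per the requirement above


def parse_version(text):
    """Turn a dotted driver version such as '580.65.06' into a tuple of ints."""
    return tuple(int(part) for part in text.strip().split("."))


def driver_meets_minimum(installed, minimum=MIN_DRIVER):
    """Compare component-wise; missing trailing components count as 0."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def query_driver_version():
    """Ask nvidia-smi for the driver version reported for the first GPU."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout
    return out.splitlines()[0]
```

On a GPU host, `driver_meets_minimum(query_driver_version())` returns `True` when the installed driver is 580.65 or newer.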
Key Features and Enhancements
This JAX release includes the following key features and enhancements.
- The current release is based on JAX 0.6.2.
- Compatibility with CUDA 13.0.
- New hardware: RTX PRO™ 6000 Blackwell Server Edition functional support.
- MaxText:
  - JAX SDPA API enabled for training and inference models.
  - Added context parallelism using TransformerEngine (TE), cuDNN flash attention, and an all-gather mechanism to support long-context model training.
- Added cuDNN paged attention in the jax._src.cudnn package.
- Enabled multi-stream collective overlap in XLA.
- XLA overlaps host offloading of array slices with computation for efficient activation offloading.
- New tutorial on Ray-based resilient training in JAX Toolbox.
- Added hermetic support for CUDA, cuDNN, NCCL, and NVSHMEM in XLA.
JAX Toolbox
The JAX Toolbox projects focus on achieving the best performance and convergence on NVIDIA Ampere, Hopper, and Blackwell architecture families and provide the latest deep learning models and scripts for training and fine-tuning. These examples are tested against a nightly CI as well as each NGC container release to ensure consistent accuracy and performance over time.
Nightly Containers
In addition to projects, JAX Toolbox includes nightly containers for libraries across the JAX ecosystem.
Container | Type | Image URI |
---|---|---|
jax | - | ghcr.io/nvidia/jax:jax |
t5x | LLM framework | ghcr.io/nvidia/jax:t5x |
maxtext | LLM framework | ghcr.io/nvidia/jax:maxtext |
equinox | layer library | ghcr.io/nvidia/jax:equinox |
axlearn | LLM framework | ghcr.io/nvidia/jax:axlearn |
Known Issues
- When running in a configuration with multiple GPUs per process, JAX can hang when loading a kernel to GPUs on the first run of the kernel. NVIDIA recommends running in a process-per-GPU configuration (this is the default when using SLURM automatic configuration).
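Under SLURM, the recommended layout usually comes from launching one task per GPU and calling `jax.distributed.initialize()` with no arguments, so that JAX's automatic cluster detection fills in the rest. Outside that path, the same one-GPU-per-process layout can be requested explicitly through the `local_device_ids` parameter. This sketch assumes the standard SLURM variables `SLURM_NTASKS`, `SLURM_PROCID`, and `SLURM_LOCALID`, plus a hypothetical `COORDINATOR_ADDRESS` variable holding `host:port` of process 0. It also assumes that each task's node-local rank matches its GPU index, which does not hold if `CUDA_VISIBLE_DEVICES` is already narrowed per task. Both helper names are hypothetical.

```python
import os

import jax


def process_per_gpu_kwargs(env=os.environ):
    """Hypothetical helper: jax.distributed.initialize() arguments for one GPU per process."""
    return {
        # Hypothetical variable set by the job script, e.g. "node0:1234".
        "coordinator_address": env["COORDINATOR_ADDRESS"],
        "num_processes": int(env["SLURM_NTASKS"]),
        "process_id": int(env["SLURM_PROCID"]),
        # Exactly one local device per process: the GPU at this task's node-local rank.
        "local_device_ids": [int(env["SLURM_LOCALID"])],
    }


def initialize_process_per_gpu():
    """Call once at startup, before running any other JAX computation."""
    jax.distributed.initialize(**process_per_gpu_kwargs())
```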
Known Issues Fixed
- Numerical correctness problems in jax.nn.scaled_matmul and jax.nn.scaled_dot_general for MXFP8 and NVFP4.
- Tensor memory deallocation in Mosaic GPU.