JAX Release 25.10
The NVIDIA JAX Release 25.10 is made up of two container images available on NGC: JAX and MaxText.
Contents of the JAX container
This container image contains the complete source for the following software:
- JAX: /opt/jax
- XLA: /opt/xla
- Flax: /opt/flax
- TransformerEngine: /opt/transformer-engine
The MaxText container image is based on the JAX container. Additionally, it includes:
- MaxText: /opt/maxtext
The JAX runtime package jaxlib is prebuilt and installed in the default Python environment (/usr/local/lib/python3.12/dist-packages/jaxlib) in the container image.
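As a quick sanity check after starting the container, a minimal sketch like the following (assuming the default Python environment and at least one GPU visible to the container) confirms that the prebuilt jaxlib is importable and that JAX can see the GPUs:

```python
# Sanity check for the prebuilt jaxlib in the container's default Python
# environment. Assumes at least one GPU is visible to the container.
import jax
import jaxlib

print("jax version:   ", jax.__version__)      # expected: 0.7.2
print("jaxlib version:", jaxlib.__version__)
print("devices:       ", jax.devices())        # should list CUDA devices, e.g. [CudaDevice(id=0), ...]
```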
Versions of packages included in both of these containers:
- JAX 0.7.2
- XLA (@0fccb8a)
- TransformerEngine 2.8
- Flax 0.11.2
- MaxText (@8fdac10)
- Ubuntu 24.04 including Python 3.12
- NVIDIA CUDA® 13.0.2
- NVSHMEM 3.3.20
- NVIDIA cuBLAS 13.1.0.3
- NVIDIA cuDNN 9.14.0.64
- NVIDIA NCCL 2.27.7
- NVIDIA DALI 1.51.2
- rdma-core 50.0
- NVIDIA HPC-X 2.24
- GDRCopy 2.4.1
- Nsight Compute 2025.3.1.4
- Nsight Systems 2025.5.1.121
Driver Requirements
Release 25.10 is based on CUDA 13.0.2 (Toolkit), which requires CUDA Driver release 580.65. Please refer to the latest Drivers and CTK support table for additional information.
For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.
Key Features and Enhancements
This JAX release includes the following key features and enhancements.
- The current release is based on JAX 0.7.2 and CUDA 13.0.2.
- Performant support for G/B300 GPUs.
- JAX now ships with the Shardy partitioner enabled by default. See the Shardy Migration guide for more details and for instructions on how to disable the new partitioner (a configuration sketch follows this list).
- Fixed global scale application in NVFP4 quantization.
- Added and improved Shardy sharding rules for JAX SDPA and JAX scaled matmul APIs.
- Enabled XLA support for symmetric memory kernels in NCCL 2.27 collective operations. Use the flag --xla_gpu_experimental_enable_nccl_symmetric_buffers to enable it (see the sketch after this list). Symmetric memory kernels improve NCCL communication latencies while reducing the number of SMs used for collectives, allowing for better overlap with compute operations.
- Transformer Engine added a fused swizzling operation for the scaling-factor inverse and the transpose calculation of the data.
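As a rough illustration of how the two options above can be toggled from Python, the following sketch passes the XLA flag through the XLA_FLAGS environment variable and switches Shardy off through JAX's config. The flag name is taken from the note above and the jax_use_shardy_partitioner config key from the Shardy migration guide; treat both as assumptions that may change between releases.

```python
# Sketch: enabling the experimental NCCL symmetric-memory kernels and
# disabling the Shardy partitioner. Option names are assumed to match
# the XLA/JAX versions shipped in this release.
import os

# XLA flags must be set in the environment before JAX initializes its backend.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_experimental_enable_nccl_symmetric_buffers=true"
)

import jax

# Disabling Shardy falls back to the previous (GSPMD) partitioner.
jax.config.update("jax_use_shardy_partitioner", False)
```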
JAX Toolbox
The JAX Toolbox projects focus on achieving the best performance and convergence on NVIDIA Ampere, Hopper, and Blackwell architecture families and provide the latest deep learning models and scripts for training and fine-tuning. These examples are tested against a nightly CI as well as each NGC container release to ensure consistent accuracy and performance over time.
Nightly Containers
In addition to these projects, JAX Toolbox provides nightly containers for libraries across the JAX ecosystem.
| Container | Type | Image URI |
|---|---|---|
| jax | - | ghcr.io/nvidia/jax:jax |
| t5x | LLM framework | ghcr.io/nvidia/jax:t5x |
| maxtext | LLM framework | ghcr.io/nvidia/jax:maxtext |
| equinox | layer library | ghcr.io/nvidia/jax:equinox |
| axlearn | LLM framework | ghcr.io/nvidia/jax:axlearn |
Known Issues
When running with multiple GPUs per process, JAX can hang while loading a kernel onto the GPUs on that kernel's first run. This has been mitigated by preloading Transformer Engine kernels. NVIDIA recommends running in a process-per-GPU configuration (the default when using SLURM automatic configuration), as sketched below.
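A minimal sketch of a process-per-GPU launch under SLURM, assuming jax.distributed.initialize() can infer the coordinator address, process count, and process index from the SLURM environment (its standard auto-detection path):

```python
# Sketch: process-per-GPU setup, one process bound to one GPU, launched e.g. with
#   srun --ntasks-per-node=<gpus per node> --gpus-per-task=1 python this_script.py
# Assumes SLURM exposes the variables that jax.distributed.initialize() auto-detects.
import jax

# Under SLURM, coordinator address, num_processes, and process_id are inferred.
jax.distributed.initialize()

print(f"process {jax.process_index()} of {jax.process_count()}, "
      f"local devices: {jax.local_devices()}")  # expect exactly one GPU per process
```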