JAX Release 25.04

The NVIDIA JAX Release 25.04 is made up of two container images available on NGC: JAX and MaxText.

Contents of the JAX container

This container image contains the complete source for the following software:

  • JAX: /opt/jax
  • XLA: /opt/xla
  • Flax: /opt/flax
  • TransformerEngine: /opt/transformer-engine

The MaxText container image is based on the JAX container. Additionally, it includes:

  • MaxText: /opt/maxtext

The JAX runtime package jaxlib is prebuilt and installed in the default Python environment (/usr/local/lib/python3.10/dist-packages/jaxlib) in the container image.
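To confirm which package versions a running container actually provides (for example, that jaxlib reports 0.5.3), a stdlib-only check such as the following can be used. It is a minimal sketch: `installed_versions` and `module_location` are hypothetical helper names, and the distribution name `transformer_engine` is an assumption about how TransformerEngine is registered, so it may differ between images.

```python
import importlib.util
from importlib import metadata

# Distributions to report. "transformer_engine" is an ASSUMED distribution
# name for TransformerEngine and may differ in a given image.
PACKAGES = ("jax", "jaxlib", "flax", "transformer_engine")


def installed_versions(packages=PACKAGES):
    """Return {distribution: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not present in this Python environment
    return versions


def module_location(name):
    """Return the file a top-level module would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    for dist, version in installed_versions().items():
        print(f"{dist:20s} {version or 'not installed'}")
    # Inside the container this should point into the default dist-packages.
    print("jaxlib location:", module_location("jaxlib"))
```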

Versions of packages included in both of these containers:

Driver Requirements

Release 25.04 is based on CUDA 12.9.0, which requires NVIDIA Driver release 575 or later. However, if you are running on a data center GPU (for example, B100, L40, or any other data center GPU), you can use NVIDIA driver release 470.57 (or later R470), 525.85 (or later R525), 535.86 (or later R535), or 550.54 (or later R550) in forward-compatibility mode.

The CUDA driver's compatibility package only supports particular drivers. Users should therefore upgrade from any R418, R440, R450, R460, R510, R520, R530, R545, R555, or R560 driver, as these are not forward-compatible with CUDA 12.9. For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.
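The driver rules above can be encoded as a small lookup. This is an illustrative sketch only: `parse_driver` and `driver_support` are hypothetical helpers, and they encode just the native minimum and the forward-compatibility branches listed in this section, not the full compatibility matrix.

```python
# Per the text above: CUDA 12.9 runs natively on driver branch R575 or later.
NATIVE_MIN = (575,)

# Forward-compatibility branches (data center GPUs only): branch -> minimum version.
FORWARD_COMPAT_MIN = {
    470: (470, 57),
    525: (525, 85),
    535: (535, 86),
    550: (550, 54),
}


def parse_driver(version: str) -> tuple:
    """Turn a driver string such as '535.104.05' into (535, 104, 5)."""
    return tuple(int(part) for part in version.strip().split("."))


def driver_support(version: str, data_center_gpu: bool) -> str:
    """Classify a driver as 'native', 'forward-compat', or 'unsupported'."""
    v = parse_driver(version)
    if v >= NATIVE_MIN:  # tuple comparison: (575, 51, 3) >= (575,)
        return "native"
    minimum = FORWARD_COMPAT_MIN.get(v[0])
    if data_center_gpu and minimum is not None and v >= minimum:
        return "forward-compat"
    # Any other branch (e.g. R418, R545, R560) must be upgraded.
    return "unsupported"
```

For example, `driver_support("550.127.05", data_center_gpu=True)` returns `"forward-compat"`, while the same driver on a non-data-center GPU is classified as `"unsupported"`.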

Key Features and Enhancements

This JAX release includes the following key features and enhancements.

  • JAX container images in version 25.04 are based on jaxlib==0.5.3.
  • Experimental support for MXFP8 and FP4 in jax.nn.scaled_matmul.
  • Enhanced the JAX SDPA API to support sequence packing.

  • Added an explicit collective grouping capability. It can be enabled with the set_xla_metadata(_collectives_group="", inlineable="false") context manager.

  • Enabled offloading of weights and optimizer state. Reasonable memory savings were observed with the scan operation.
  • Added the experimental compute_on API, which assigns compute or communication kernels to a specific stream when used together with shard_map. It is gated behind the experimental flag xla_gpu_experimental_stream_annotation.
  • Replaced stream capture-based graph creation with explicit graph construction for invoking cuDNN’s graph APIs.
  • NUMA-aware XLA delivers up to a 2x speedup on device-to-host (D2H) and host-to-device (H2D) transfers.
  • Improved sequence parallelism performance through common subexpression elimination (CSE) on collective ops.
  • Enabled a collective-permute optimization that combines small NCCL messages to reduce CPU overhead.
  • Observed a 10% (geomean) performance gain on MaxText's LLMs over version 25.01 on B200.
  • Added the JAX_COMPILATION_CACHE_EXPECT_PGLE flag to leverage PGLE-optimized compilation caches when using AutoPGLE. Turning on this flag allows AutoPGLE to work in tandem with the Nsight profiler.
  • nsys-jax is now pip-installable.
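For readers unfamiliar with MXFP8, the following pure-Python sketch shows the block-scaled format behind the scaled_matmul support, assuming the OCP Microscaling (MX) definition: each block of 32 values shares one power-of-two (E8M0) scale, and each element is stored as FP8 E4M3. It only illustrates the number format; it does not call or reproduce jax.nn.scaled_matmul, all helper names are hypothetical, and E8M0 range clamping and NaN handling are omitted for brevity.

```python
import math

BLOCK_SIZE = 32          # elements sharing one scale in an MX block
E4M3_MAX = 448.0         # largest finite FP8 E4M3 value (1.75 * 2**8)
E4M3_EMAX = 8            # exponent of the largest normal E4M3 value
E4M3_EMIN = -6           # exponent of the smallest normal E4M3 value
E4M3_MANTISSA_BITS = 3


def floor_log2(x: float) -> int:
    """Exact floor(log2(x)) for x > 0, via frexp (x = m * 2**e, 0.5 <= m < 1)."""
    _, e = math.frexp(x)
    return e - 1


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    mag = abs(x)
    # Subnormals use the same spacing as the smallest normal binade.
    exp = max(floor_log2(mag), E4M3_EMIN)
    quantum = 2.0 ** (exp - E4M3_MANTISSA_BITS)
    # Python's round() on floats rounds half to even.
    rounded = min(round(mag / quantum) * quantum, E4M3_MAX)
    return math.copysign(rounded, x)


def quantize_mxfp8_block(block):
    """Return (shared_exponent, E4M3 elements) for one block of floats."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0, [0.0] * len(block)
    # The scale maps the block maximum into [256, 512); values above 448 saturate.
    shared_exp = floor_log2(amax) - E4M3_EMAX
    scale = 2.0 ** shared_exp
    return shared_exp, [round_to_e4m3(v / scale) for v in block]


def quantize_mxfp8(values):
    """Split values into 32-element blocks and quantize each block."""
    return [quantize_mxfp8_block(values[i:i + BLOCK_SIZE])
            for i in range(0, len(values), BLOCK_SIZE)]


def dequantize_mxfp8(blocks):
    """Multiply each element by its block scale and flatten the result."""
    return [x * 2.0 ** exp for exp, elems in blocks for x in elems]
```

Because the scale is chosen per 32-element block rather than per tensor, a single outlier only costs precision for the 31 values that share its block.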
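The JAX_COMPILATION_CACHE_EXPECT_PGLE flag suggests a two-phase workflow: one run with AutoPGLE populates the persistent compilation cache, and a later run (for example, under `nsys profile`) reuses the PGLE-optimized executables without profiling again. The sketch below builds the environment for each phase. It assumes JAX's usual config-to-environment-variable mapping (JAX_ENABLE_PGLE, JAX_ENABLE_COMPILATION_CACHE, JAX_COMPILATION_CACHE_DIR); `autopgle_env` and the phase names are hypothetical.

```python
import os


def autopgle_env(cache_dir, phase, base=None):
    """Return a copy of `base` (default: os.environ) configured for one phase.

    phase="populate": enable AutoPGLE so that PGLE-optimized executables are
                      written to the persistent compilation cache.
    phase="profile":  leave AutoPGLE off and expect PGLE-optimized entries in
                      the cache, e.g. for a run wrapped in `nsys profile`.
    """
    env = dict(os.environ if base is None else base)
    env["JAX_ENABLE_COMPILATION_CACHE"] = "true"
    env["JAX_COMPILATION_CACHE_DIR"] = cache_dir
    if phase == "populate":
        env["JAX_ENABLE_PGLE"] = "true"
        env.pop("JAX_COMPILATION_CACHE_EXPECT_PGLE", None)
    elif phase == "profile":
        env["JAX_COMPILATION_CACHE_EXPECT_PGLE"] = "true"
        env.pop("JAX_ENABLE_PGLE", None)
    else:
        raise ValueError(f"unknown phase: {phase!r}")
    return env
```

A launcher could then run, for instance, `subprocess.run([sys.executable, "train.py"], env=autopgle_env("/tmp/jax_cache", "populate"), check=True)` and afterwards the same command prefixed with `nsys profile` using `phase="profile"`; `train.py` and the cache path are placeholders.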

JAX Toolbox

The JAX Toolbox projects focus on achieving the best performance and convergence on NVIDIA Ampere, Hopper, and Blackwell architecture families and provide the latest deep learning models and scripts for training and fine-tuning. These examples are tested against a nightly CI as well as each NGC container release to ensure consistent accuracy and performance over time.

Nightly Containers

In addition to projects, JAX Toolbox includes nightly containers for libraries across the JAX ecosystem.

Container   Type            Image URI
jax         -               ghcr.io/nvidia/jax:jax
t5x         LLM framework   ghcr.io/nvidia/jax:t5x
levanter    LLM framework   ghcr.io/nvidia/jax:levanter
maxtext     LLM framework   ghcr.io/nvidia/jax:maxtext
equinox     layer library   ghcr.io/nvidia/jax:equinox
gemma       model           ghcr.io/nvidia/jax:gemma
axlearn     LLM framework   ghcr.io/nvidia/jax:axlearn

Known Issues

  • In a single-process, multi-GPU setup, jax.nn.scaled_matmul with MXFP8 may produce incorrect values. The suggested workaround is to use cuBLASLt from CUDA 12.8 or to run one process per GPU.

  • Mosaic GPU kernels in this release do not deallocate tensor memory before kernel exit, which triggers a runtime check failure; this check is enabled starting with CUDA 12.8.1. The issue has been fixed in an upstream JAX PR. We recommend that Mosaic GPU users use the nightly releases of JAX to get the latest fixes.
  • To set visible GPU devices, we recommend setting the "CUDA_VISIBLE_DEVICES" environment variable directly. The JAX "jax_cuda_visible_devices" config option in Python is not a reliable way to set visible devices.
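Both known issues above point toward pinning devices through the environment rather than from inside Python: the MXFP8 workaround runs one process per GPU, and "CUDA_VISIBLE_DEVICES" is the recommended way to restrict visible devices. The minimal launcher-side sketch below, using a hypothetical helper `per_gpu_envs`, builds one environment per GPU so that each child process sees exactly one device before JAX initializes. A real multi-process JAX job would typically also call `jax.distributed.initialize(...)` in each child; that part is not shown.

```python
import os


def per_gpu_envs(num_gpus, base=None):
    """Build one environment per GPU, each exposing exactly one device.

    CUDA_VISIBLE_DEVICES is set in the child's environment before the child
    starts, so it is already in effect when JAX initializes CUDA there.
    """
    base_env = dict(os.environ if base is None else base)
    envs = []
    for gpu in range(num_gpus):
        env = dict(base_env)  # copy so the children do not share one dict
        env["CUDA_VISIBLE_DEVICES"] = str(gpu)
        envs.append(env)
    return envs
```

Each environment can then be passed to `subprocess.Popen([sys.executable, "worker.py"], env=env)`, where `worker.py` is a placeholder for the per-GPU training script.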

Bug Fixes

  • Fixed out-of-memory errors caused by incorrect handling of the memory limit (#23271).
  • Resolved a crash when both inter-node FSDP and inter-node DDP were enabled on Blackwell.
  • Resolved hangs when running sharded autotuning together with cached autotuning.
© Copyright 2025, NVIDIA. Last updated on May 1, 2025.