NVIDIA Optimized Frameworks

JAX Release 26.01

The NVIDIA JAX Release 26.01 is made up of two container images available on NGC: JAX and MaxText.

Contents of the JAX container

This container image contains the complete source for the following software:

  • JAX: /opt/jax
  • XLA: /opt/xla
  • Flax: /opt/flax
  • TransformerEngine: /opt/transformer-engine

The MaxText container image is based on the JAX container. Additionally, it includes:

  • MaxText: /opt/maxtext

The JAX runtime package jaxlib is prebuilt and installed in the default Python environment (/usr/local/lib/python3.10/dist-packages/jaxlib) in the container image.

Versions of packages included in both of these containers:

  • CUDA 13.1.1
    • Please refer to the CUDA DL 26.01 release notes section for the list of libraries inherited from the CUDA container.

Driver Requirements

Release 26.01 is based on CUDA 13.1 U1 (Toolkit), which requires CUDA Driver release 590.48 or later. Please refer to the latest Drivers and CTK support table for additional information.

For a complete list of supported drivers, see the CUDA Application Compatibility topic. For more information, see CUDA Compatibility and Upgrades.
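As a quick pre-flight check, the driver version on a host can be compared against the 590.48 minimum stated above. The following is a minimal sketch that assumes nvidia-smi is on PATH; the helpers parse_version, meets_requirement, and installed_driver_version are hypothetical names, and the check deliberately ignores the forward-compatibility options described in CUDA Compatibility and Upgrades, so a False result does not by itself prove the container cannot run.

```python
import subprocess

# Minimum driver stated in these release notes (CUDA 13.1 U1 requires 590.48).
REQUIRED_DRIVER = "590.48"


def parse_version(version: str) -> tuple[int, ...]:
    """Hypothetical helper: turn '590.48.01' into (590, 48, 1)."""
    return tuple(int(part) for part in version.strip().split("."))


def meets_requirement(installed: str, required: str = REQUIRED_DRIVER) -> bool:
    """Hypothetical helper: numeric comparison, padding the shorter version with zeros.

    Ignores CUDA forward compatibility; it only checks the plain minimum.
    """
    a, b = parse_version(installed), parse_version(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_driver_version() -> str:
    """Hypothetical helper: read the driver version of the first GPU via nvidia-smi."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return out.splitlines()[0]
```

For example, meets_requirement(installed_driver_version()) returns True on a host whose driver reports 590.48.01.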

Key Features and Enhancements

This JAX release includes the following key features and enhancements.

  • The current release is based on JAX 0.8.1 and CUDA 13.1 U1
  • Added support for DGX Spark and Jetson Thor
  • Added NVLink domain capability for NVSHMEM to accelerate long-context training in multi-node, multi-GPU environments. To enable it, set the --xla_gpu_experimental_enable_nvshmem=true flag.
  • Fixed a problem in configurations with multiple GPUs per process where JAX would hang while loading a kernel onto the GPUs the first time that kernel ran.
  • Enabled boolean, int8 and uint8 datatypes in NVSHMEM reduction collectives.
  • Improved collective optimization pipeline to enable all-reduces to lower into efficient reduce-scatters.
  • Optimizations for sub-byte data types:
    • Generate hardware intrinsics for efficient conversion of fp4 types.
    • Optimized collective communication operations on sub-byte types such as int4 and fp4 by eliminating redundant cast operations and reducing the volume of data moved.
  • Enabled int4 in cuDNN GEMM fusions. To use it, set the flag xla_gpu_cudnn_gemm_fusion_level to a value greater than 0.
  • Added support for forward convolution with dilation, and added a heuristic that distinguishes forward from backward convolutions and dispatches them to optimized cuDNN kernels.
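The NVSHMEM and cuDNN GEMM fusion options listed above are XLA flags, which JAX reads from the XLA_FLAGS environment variable. A minimal sketch, assuming a Python entry point where the variable can be set before JAX initializes its GPU backend (setting it before import jax is the simplest way to guarantee that). The helper merge_xla_flags is a hypothetical name, and the flag values are examples: the release notes only state that the fusion level must be greater than 0.

```python
import os

# Flags named in these release notes; values are examples only.
RELEASE_FLAGS = {
    "xla_gpu_experimental_enable_nvshmem": "true",
    # Any value > 0 enables int4 in cuDNN GEMM fusions.
    "xla_gpu_cudnn_gemm_fusion_level": "1",
}


def merge_xla_flags(existing: str, flags: dict[str, str]) -> str:
    """Hypothetical helper: append --name=value pairs, replacing any earlier setting."""
    kept = [
        tok
        for tok in existing.split()
        if tok.lstrip("-").split("=", 1)[0] not in flags
    ]
    kept += [f"--{name}={value}" for name, value in flags.items()]
    return " ".join(kept)


# Set XLA_FLAGS before JAX creates its GPU backend; importing jax afterwards is safest.
os.environ["XLA_FLAGS"] = merge_xla_flags(os.environ.get("XLA_FLAGS", ""), RELEASE_FLAGS)
# import jax  # import only after XLA_FLAGS is set
```

Merging instead of overwriting keeps any flags already exported in the container environment, while still letting the values above take precedence.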

JAX Toolbox

The JAX Toolbox projects focus on achieving the best performance and convergence on NVIDIA Ampere, Hopper, and Blackwell architecture families and provide the latest deep learning models and scripts for training and fine-tuning. These examples are tested against a nightly CI as well as each NGC container release to ensure consistent accuracy and performance over time.

Nightly Containers

In addition to projects, JAX Toolbox includes nightly containers for libraries across the JAX ecosystem.

Container   Type            Image URI
jax         -               ghcr.io/nvidia/jax:jax-YYYY-MM-DD
maxtext     LLM framework   ghcr.io/nvidia/jax:maxtext-YYYY-MM-DD
equinox     layer library   ghcr.io/nvidia/jax:equinox-YYYY-MM-DD
axlearn     LLM framework   ghcr.io/nvidia/jax:axlearn-YYYY-MM-DD
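The nightly tags follow a fixed <container>-YYYY-MM-DD pattern, so a tag for a given build date can be assembled programmatically. A minimal sketch; nightly_image_uri is a hypothetical helper, and it only formats the URI. It does not check that an image was actually published for that date.

```python
from datetime import date

# Container names from the nightly container list in these notes.
NIGHTLY_CONTAINERS = ("jax", "maxtext", "equinox", "axlearn")


def nightly_image_uri(container: str, day: date) -> str:
    """Hypothetical helper: build ghcr.io/nvidia/jax:<container>-YYYY-MM-DD."""
    if container not in NIGHTLY_CONTAINERS:
        raise ValueError(f"unknown nightly container: {container!r}")
    # date.isoformat() yields the zero-padded YYYY-MM-DD form used in the tags.
    return f"ghcr.io/nvidia/jax:{container}-{day.isoformat()}"
```

For example, nightly_image_uri("maxtext", date(2026, 1, 30)) gives ghcr.io/nvidia/jax:maxtext-2026-01-30, which can then be passed to docker pull.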

Known Issues

There are no known issues.

© Copyright 2026, NVIDIA. Last updated on Jan 30, 2026