NVIDIA Optimized Frameworks

JAX Release 26.06

The NVIDIA JAX Release 26.06 is made up of two container images available on NGC: JAX and MaxText.

Contents of the JAX container

This container image contains the complete source for the following software:

  • JAX: /opt/jax
  • XLA: /opt/xla
  • Flax: /opt/flax
  • TransformerEngine: /opt/transformer-engine

The MaxText container image is based on the JAX container. Additionally, it includes:

  • MaxText: /opt/maxtext

The JAX runtime package jaxlib is prebuilt and installed in the default Python environment (/usr/local/lib/python3.10/dist-packages/jaxlib) in the container image.
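A quick way to confirm which preinstalled packages and devices a container session is using is sketched below. It relies only on standard attributes (jax.__version__, jax.default_backend(), jax.devices()); the exact version strings and device list printed will depend on the container and hardware.

```python
import jax
import jaxlib

# Report the JAX version and where the prebuilt jaxlib package is installed.
print("jax:", jax.__version__)
print("jaxlib location:", jaxlib.__file__)

# Report the backend JAX selected and the devices it can see.
# In the GPU container this is expected to be "gpu"; on a CPU-only
# machine JAX falls back to "cpu".
backend = jax.default_backend()
devices = jax.devices()
print("backend:", backend)
print("devices:", devices)
```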

Versions of packages included in both of these containers:

  • CUDA 13.3.0
    • Please refer to the CUDA DL 26.06 release notes section for the list of libraries inherited from the CUDA container.

Driver Requirements

Release 26.06 is based on CUDA 13.3.0. For comprehensive and up-to-date driver compatibility information, please refer to the following documentation:

Key Features and Enhancements

New Model Support

  • Added support for Kimi K2-Thinking, K2.5, and K2.6 (text) models, including checkpoint conversion scripts.
  • Added Qwen3-30B-A3B-Base tokenizer and Qwen3.5 text-only decoder layer.
  • Added OLMo 3 7B/32B HuggingFace configs, stage-1 pretraining scripts, and numpy pretrain data pipeline.
  • Extended Gemma4 support: HuggingFace checkpoint conversion, vLLM adapter, layer-wise unit tests, MoE inference performance improvements, and multimodal evaluation (ChartQA).
  • Implemented DeepSeek, Gemma3, and Llama4 decoder layers in NNX.

Post-Training

  • Added standalone DPO training with configurable data hooks.
  • Added Generalized Learn-to-Init (LTI) for knowledge distillation, with support for distilling Llama attention layers.
  • Added offline distillation with teacher forward pass support under sequence packing.
  • Refactored SFT with shared post-training hooks for reuse across training modes.
  • Integrated Multi-Token Prediction (MTP) with batch-split config.
  • Added per-step TFLOPs logging during distillation (student forward+backward and teacher forward).
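For readers new to these post-training modes, the sketch below writes out two textbook objectives in plain JAX: the standard DPO loss computed from per-sequence log-probabilities, and a temperature-scaled KL distillation loss between teacher and student logits. These are generic formulations for illustration; the function names and default hyperparameters (beta, temperature) are assumptions, and this is not MaxText's DPO or LTI implementation.

```python
import jax
import jax.numpy as jnp

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Generic DPO loss from per-sequence log-probs (hypothetical helper)."""
    # Log-ratio of policy vs. frozen reference for each response.
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_ratio - rejected_ratio)
    # -log(sigmoid(margin)), averaged over the batch.
    return -jnp.mean(jax.nn.log_sigmoid(logits))

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Temperature-scaled KL(teacher || student) (hypothetical helper)."""
    t_logp = jax.nn.log_softmax(teacher_logits / temperature, axis=-1)
    s_logp = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    kl = jnp.sum(jnp.exp(t_logp) * (t_logp - s_logp), axis=-1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
    return jnp.mean(kl) * temperature ** 2

zeros = jnp.zeros((3,))
print(dpo_loss(zeros, zeros, zeros, zeros))  # equals log(2) when margins are zero
```

When the policy matches the reference, every margin is zero and the DPO loss is log 2; raising the policy's log-probability of the chosen response lowers it.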

MoE Performance

  • Added FP8 Megablox support for batch-split MoE.
  • Integrated a ragged gather-reduce kernel into MoE routing, improving efficiency over the prior scatter-based approach; added ragged sort for all-to-all (A2A) expert parallelism.
  • Added TP-fused MoE and MLP dim padding for Tensor Parallelism.
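Ragged MoE kernels generally operate on tokens that have been grouped contiguously by expert, together with a per-expert count. The sketch below shows that generic sort-based dispatch in plain JAX, assuming top-1 routing for simplicity; the helper names are hypothetical, and this is not the MaxText kernel itself.

```python
import jax.numpy as jnp

def sort_tokens_by_expert(tokens, expert_ids, num_experts):
    """Group tokens contiguously by expert (hypothetical helper, top-1 routing)."""
    order = jnp.argsort(expert_ids)        # permutation that groups experts
    sorted_tokens = tokens[order]          # rows are now contiguous per expert
    # Number of tokens routed to each expert ("ragged" group sizes).
    group_sizes = jnp.bincount(expert_ids, length=num_experts)
    return sorted_tokens, group_sizes, order

def unsort_tokens(sorted_outputs, order):
    """Restore the original token order (hypothetical helper)."""
    inverse = jnp.argsort(order)           # inverse of the sorting permutation
    return sorted_outputs[inverse]

tokens = jnp.arange(12.0).reshape(6, 2)    # 6 tokens, hidden size 2
expert_ids = jnp.array([2, 0, 1, 0, 2, 1])  # top-1 expert per token
sorted_tokens, group_sizes, order = sort_tokens_by_expert(tokens, expert_ids, 3)
print(group_sizes)
restored = unsort_tokens(sorted_tokens, order)
```

A grouped (ragged) matmul would consume sorted_tokens and group_sizes, and unsort_tokens would then map the expert outputs back to their original positions.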

Attention

  • Added out_sharding and qkv_sharding to MultiHeadAttention for explicit output sharding control.

Long-Context and Parallelism

  • Introduced CP-as-EP parallelism rule for long-context training and strong scaling.
  • Added 2D FSDP custom mesh and Zero1 AOT compilation support.
  • Added MLP dim padding for Tensor Parallelism.
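As background for the mesh-related items above, the sketch below builds a generic 2D device mesh with the public jax.sharding API and shards a parameter across one axis, FSDP-style. The axis names ("data", "fsdp") and the 1 x N layout are illustrative assumptions, not MaxText's mesh configuration or its custom-mesh options.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all visible devices into a 2D mesh (illustrative 1 x N layout).
devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices.reshape(1, n), axis_names=("data", "fsdp"))

# Shard a weight's first dimension across the "fsdp" axis so each device
# holds only a slice of the parameter.
weight = jax.device_put(jnp.ones((2 * n, 4), dtype=jnp.float32),
                        NamedSharding(mesh, P("fsdp", None)))

# Shard a batch of activations across both mesh axes.
batch = jax.device_put(jnp.ones((4 * n, 4), dtype=jnp.float32),
                       NamedSharding(mesh, P(("data", "fsdp"), None)))

print(weight.sharding)
print(batch.sharding)
```

On a single device both mesh axes have size 1, so the example still runs but nothing is actually split.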

Elastic Training

  • Added Grain elastic checkpointing support.
  • Added elastic replica resize capability.
  • Added goodput elastic event recording.

Checkpointing and Data

  • Added checkpoint resharding script and TFDS dataset resharding utility for large-scale training.
  • Improved HuggingFace checkpoint conversion for Gemma4 and Qwen3-MoE.
  • Added secure deserialization and checkpoint loading pipelines.
  • Added OLMo grain data pipeline for numpy pretrain mixes.

Compiler and Runtime

  • Upgraded JAX to 0.10.0.
  • Pallas/Triton kernels are now compiled to PTX by default, improving portability and load times.
  • Added CUDA 13 build target.
  • MLIR-to-LLVM IR compilation is now asynchronous for MlirKernelFusion.
  • Improved GPU dot fusion cost model and introduced tileable GemmFusion.
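For context on the Pallas item, a minimal Pallas kernel in the style of the JAX Pallas quickstart is sketched below. It uses interpret=True so the example also runs on CPU; how the kernel is lowered on GPU (including the PTX default noted above) is handled by JAX/XLA and is not controlled by anything shown here.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs point at the input/output buffers; read, add, and write back.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # interpret=True runs the kernel through the Pallas interpreter so the
        # sketch works without a GPU; remove it to lower for the GPU backend.
        interpret=True,
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))
```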

Evaluation and Tooling

  • Added vLLM-based offline Eval Framework.
  • Added nnx.EMA (Exponential Moving Average) module.
  • Added kwargs pass-through to base module in nnx.LoRA.
  • NNX training infrastructure: added TrainState, model creation utilities, and training loop support.
  • Added "From PyTorch to JAX and Flax" migration guide.
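The update rule behind an exponential moving average of model weights can be sketched in plain JAX over an arbitrary parameter pytree, as below. This illustrates the technique only; ema_update is a hypothetical helper, the decay value is an assumption, and this is not the nnx.EMA API.

```python
import jax
import jax.numpy as jnp

def ema_update(ema_params, params, decay=0.999):
    """One EMA step: ema <- decay * ema + (1 - decay) * params (hypothetical helper)."""
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params)

params = {"w": jnp.ones((2,)), "b": jnp.zeros(())}
ema = jax.tree_util.tree_map(jnp.zeros_like, params)  # start the average at zero
ema = ema_update(ema, params, decay=0.9)
print(ema["w"])
```

Because tree_map applies the rule leaf by leaf, the same function works for any nested dict or module state of arrays.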

JAX Toolbox

The JAX Toolbox projects focus on achieving the best performance and convergence on NVIDIA Ampere, Hopper, and Blackwell architecture families and provide the latest deep learning models and scripts for training and fine-tuning. These examples are tested against a nightly CI as well as each NGC container release to ensure consistent accuracy and performance over time.

Nightly Containers

In addition to projects, JAX Toolbox includes nightly containers for libraries across the JAX ecosystem.

Container   Type            Image URI
jax         -               ghcr.io/nvidia/jax:jax-YYYY-MM-DD
maxtext     LLM framework   ghcr.io/nvidia/jax:maxtext-YYYY-MM-DD
axlearn     LLM framework   ghcr.io/nvidia/jax:axlearn-YYYY-MM-DD
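The nightly image URIs follow the repository:flavor-YYYY-MM-DD pattern shown in the table. A small helper that fills in the date is sketched below; the helper name is hypothetical, and a tag existing for any particular date is an assumption, since a nightly build may not be published every day.

```python
from datetime import date

# Repository and flavors taken from the nightly-container table.
NIGHTLY_REPO = "ghcr.io/nvidia/jax"
FLAVORS = ("jax", "maxtext", "axlearn")

def nightly_image(flavor: str, day: date) -> str:
    """Build a nightly image URI for a given date (hypothetical helper)."""
    if flavor not in FLAVORS:
        raise ValueError(f"unknown flavor: {flavor!r}")
    return f"{NIGHTLY_REPO}:{flavor}-{day:%Y-%m-%d}"

print(nightly_image("maxtext", date(2026, 6, 1)))  # ghcr.io/nvidia/jax:maxtext-2026-06-01
```

The resulting string can be passed to docker pull or any other OCI-compatible tool.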

Known Issues

  • None.

© Copyright 2026, NVIDIA. Last updated on Jun 29, 2026