NVIDIA Optimized Frameworks

JAX Overview

JAX is a framework for high-performance numerical computing and machine learning research. It combines NumPy-like APIs, automatic differentiation, XLA acceleration, and simple primitives for scaling across GPUs.
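The following minimal sketch combines three of these ingredients: the NumPy-like API (`jax.numpy`), automatic differentiation (`jax.grad`), and XLA compilation (`jax.jit`). The least-squares loss and the toy inputs are illustrative choices and do not come from any particular workload.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # NumPy-style array code: mean squared error of a linear model x @ w.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jax.grad builds a new function that returns d(loss)/dw (the first argument).
grad_loss = jax.grad(loss)

# jax.jit traces the function and compiles it with XLA; later calls with
# the same shapes and dtypes reuse the compiled code.
fast_grad = jax.jit(grad_loss)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = fast_grad(w, x, y)
print(g)  # at w = 0 the gradient is -x.T @ y, i.e. [-7, -10]
```

The same `fast_grad` function runs unchanged on CPU or GPU, because XLA compiles it for whichever backend JAX selects.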

The JAX NGC Container includes all required dependencies, making it an easy starting point for developing applications in areas such as NLP, computer vision, multimodal learning, physics-based simulation, reinforcement learning, drug discovery, and neural rendering.

The JAX NGC Container is optimized for GPU acceleration and contains a validated set of libraries that enable and optimize GPU performance. It may also include modifications to the JAX source code to maximize performance and compatibility. The container also includes software for accelerated training (cuDNN, NCCL, TransformerEngine).
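Because the container targets GPUs, a useful sanity check after launching it is to ask JAX which backend and devices it found. The snippet below is a minimal sketch that uses standard JAX calls (`jax.default_backend`, `jax.devices`, `jax.device_count`). It assumes the container was started with GPUs visible to it, in which case the backend is reported as `gpu`. On a machine without GPUs, JAX falls back to `cpu`.

```python
import jax
import jax.numpy as jnp

# Report which backend XLA selected and which devices JAX can see.
backend = jax.default_backend()  # expected to be "gpu" when GPUs are visible
devices = jax.devices()
print(f"backend={backend}, device_count={jax.device_count()}")
for d in devices:
    print(d.platform, d.id, d.device_kind)

# Run a small matrix multiply on the default device to confirm the stack
# works end to end; block_until_ready waits for the asynchronous result.
a = jnp.ones((1024, 1024), dtype=jnp.float32)
b = (a @ a).block_until_ready()
print(b[0, 0])  # every entry equals 1024 (a sum of 1024 ones)
```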

For working with neural networks, the JAX NGC Container includes Flax, a neural network library with support for common deep learning models, layers, and optimizers. We also provide a separate container for training large language models (LLMs) with MaxText. You can use the JAX or MaxText containers for your deep learning workloads, or install your preferred libraries on top of them.

© Copyright 2024, NVIDIA. Last updated on Jan 29, 2025.