JAX Overview


JAX is a framework for high-performance numerical computing and machine learning research. It combines NumPy-like APIs, automatic differentiation, XLA acceleration, and simple primitives for scaling across GPUs.
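The following minimal sketch shows how these pieces fit together. It uses the public `jax.numpy`, `jax.grad`, and `jax.jit` APIs. The linear-model loss and the sample data are illustrative assumptions, not part of the container. `jax.grad` differentiates a NumPy-style function, `jax.jit` compiles the result with XLA, and `jax.devices()` lists the accelerators JAX can see.

```python
import jax
import jax.numpy as jnp

# Illustrative NumPy-style function: mean squared error of a linear model
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument (w);
# jax.jit compiles the gradient function with XLA
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # 2 examples, 2 features
y = jnp.array([1.0, 2.0])                # targets
w = jnp.zeros(2)                         # initial weights

g = grad_loss(w, x, y)
print(g)  # gradient of the loss at w = 0

# Devices visible to JAX: GPUs when run inside a GPU-enabled container, otherwise CPU
print(jax.devices())
```

At `w = 0` the residual is `-y`, so the gradient is `x.T @ (-y)`, which evaluates to `[-7., -10.]`.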

The JAX NGC Container comes with all dependencies included, providing an easy place to start developing applications in areas such as NLP, computer vision, multimodality, physics-based simulations, reinforcement learning, drug discovery, and neural rendering.

The JAX NGC Container is optimized for GPU acceleration and contains a validated set of libraries that enable and optimize GPU performance. It may also include modifications to the JAX source code to maximize performance and compatibility, and it includes software for accelerated training (cuDNN, NCCL, TransformerEngine).

For working with neural networks, the JAX NGC Container includes Flax, a neural network library with support for common deep learning models, layers, and optimizers. We also include additional containers for training LLMs with Paxml and MaxText. You can use the JAX, Paxml, or MaxText containers for your deep learning workloads, or install your preferred libraries on top of them.

© Copyright 2024, NVIDIA. Last updated on Apr 29, 2024.