JAX Overview


JAX is a framework for high-performance numerical computing and machine learning research. It provides NumPy-like APIs, automatic differentiation, XLA acceleration, and simple primitives for scaling computation across GPUs.
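The minimal sketch below illustrates these features using only core JAX: `jax.numpy` for NumPy-style arrays, `jax.grad` for automatic differentiation, `jax.jit` for XLA compilation, `jax.vmap` for vectorization, and `jax.pmap` as one example of a primitive for running work across several devices. The quadratic `loss` function and its data are hypothetical and chosen only for illustration. The sketch assumes a recent JAX installation and runs on either CPU or GPU.

```python
import jax
import jax.numpy as jnp

# NumPy-like API: jax.numpy mirrors much of NumPy's interface.
x = jnp.arange(5.0)  # [0., 1., 2., 3., 4.]

def loss(w):
    # Hypothetical quadratic loss: sum of squared differences from x.
    return jnp.sum((w - x) ** 2)

# Automatic differentiation: grad returns a function computing d(loss)/dw.
grad_loss = jax.grad(loss)

# XLA acceleration: jit compiles the function with XLA on its first call.
fast_grad = jax.jit(grad_loss)

w = jnp.zeros(5)
g = fast_grad(w)  # analytically 2 * (w - x), i.e. -2 * x here

# Vectorization: vmap maps loss over a leading batch axis.
ws = jnp.stack([jnp.zeros(5), x])
losses = jax.vmap(loss)(ws)  # [sum(x**2), 0.] = [30., 0.]

# Scaling across devices: pmap runs the function once per local device,
# splitting the leading axis of the input across those devices.
n = jax.local_device_count()
per_device = jnp.arange(n * 3.0).reshape(n, 3)
sums = jax.pmap(jnp.sum)(per_device)  # one sum per device

print(g, losses, sums)
```

On a machine with a single device, `jax.pmap` still runs with a leading axis of size one; on a multi-GPU node the same code places one row on each GPU.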

The JAX NGC Container comes with all dependencies included, providing an easy starting point for developing applications in areas such as NLP, computer vision, multimodal learning, physics-based simulation, reinforcement learning, drug discovery, and neural rendering.

The JAX NGC Container is optimized for GPU acceleration and contains a validated set of libraries that enable and optimize GPU performance. It may also include modifications to the JAX source code to maximize performance and compatibility. The container also includes software for accelerating ETL (DALI) and training (cuDNN, NCCL).
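A quick way to confirm that JAX is actually using the GPU inside the container is to ask which XLA backend and devices it selected. The snippet below is a minimal sketch using only the public `jax.default_backend()` and `jax.devices()` functions; the exact device names printed depend on your hardware and JAX version.

```python
import jax

# Platform name of the default XLA backend: typically "gpu" when CUDA GPUs
# are visible to JAX, otherwise "cpu".
backend = jax.default_backend()

# List of devices on the default backend (for example, one entry per GPU).
devices = jax.devices()

print(f"backend={backend}, device_count={len(devices)}")
for d in devices:
    print(d)
```

If this reports `cpu` inside the container on a GPU machine, the GPUs are likely not exposed to the container at launch.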

For working with neural networks, the JAX NGC Container includes Flax, a neural network library with support for common deep learning models, layers, and optimizers. We also provide a Paxml container for training GPT models and a T5x container for training T5 and ViT models. You can use the JAX, Paxml, or T5x containers for your deep learning workloads, or install your own favorite libraries on top of them.

© Copyright 2024, NVIDIA. Last updated on Apr 5, 2024.