MichaelHudgins/jax-ai-stack

JAX AI Stack

JAX is a Python package for array-oriented
computation and program transformation. Built around it is a growing ecosystem
of packages for specialized numerical computing across a range of domains; an
up-to-date list of such projects can be found at
Awesome JAX.

Though JAX is often compared to neural network libraries like PyTorch, the JAX
core package itself contains very little that is specific to neural network
models. Instead, JAX encourages modularity, where domain-specific libraries
are developed separately from the core package: this helps drive innovation
as researchers and other users explore what is possible.

Within this larger, distributed ecosystem, there are a number of projects that
Google researchers and engineers have found useful for implementing and deploying
the models behind generative AI tools like Imagen,
Gemini, and more. The JAX AI stack serves as a
single point-of-entry for this suite of libraries, so you can install and begin
using many of the same open source packages that Google developers are using
in their everyday work.

To get started with the JAX AI stack, you can check out Getting started with JAX.
This documentation is still a work in progress; please check back for more documentation
and tutorials in the coming weeks!

Installing the stack

The stack can be installed with the following command:

pip install jax-ai-stack

This pins specific versions of the component projects that the integration tests in this
repository have verified to work correctly together. Packages include:

  • JAX: the core JAX package, which includes array operations
    and program transformations such as jit, vmap, and grad.
  • flax: build neural networks with JAX.
  • ml_dtypes: NumPy dtype extensions for machine learning.
  • optax: gradient processing and optimization in JAX.
  • orbax: checkpointing and persistence utilities for JAX.
  • chex: utilities for writing reliable JAX code.
  • grain: data loading.
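
The JAX entry above names three program transformations: jit, vmap, and grad. The following is a minimal sketch of how they compose, assuming only that `jax` is installed; the function `loss` and its inputs are illustrative and not part of any package in the stack.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Illustrative scalar function (not from any library): (w * x) ** 2, summed.
    return jnp.sum((w * x) ** 2)


# grad differentiates with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)

# vmap maps grad_loss over the leading axis of x while keeping w fixed (None).
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# jit compiles the batched gradient function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array(2.0)
xs = jnp.array([1.0, 2.0, 3.0])

# d/dw (w * x) ** 2 = 2 * w * x ** 2, i.e. 4, 16, 36 for these inputs.
result = fast_batched_grad(w, xs)
print(result)
```

Each transformation returns an ordinary Python function, which is why they can be stacked in any order; wrapping the outermost function in `jax.jit` lets XLA compile the whole batched gradient at once.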

Optional packages

Additionally, there are optional packages you can install with pip extras.

The following command installs compatible versions of tensorflow and tensorflow-datasets:

pip install "jax-ai-stack[tfds]"
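
Because tensorflow and tensorflow-datasets are installed only with the tfds extra, code that depends on them may want to check for them before importing. The following is a minimal sketch using only the standard library. It assumes the import names `tensorflow` and `tensorflow_datasets`, and the helpers `has_module` and `has_tfds_extra` are hypothetical names chosen for illustration.

```python
import importlib.util


def has_module(name: str) -> bool:
    # find_spec returns None when a top-level module cannot be found,
    # without actually importing it.
    return importlib.util.find_spec(name) is not None


def has_tfds_extra() -> bool:
    # Hypothetical helper: True only if both packages pulled in by the
    # tfds extra (assumed import names) are importable.
    return has_module("tensorflow") and has_module("tensorflow_datasets")


if has_tfds_extra():
    import tensorflow_datasets as tfds  # noqa: F401
    print("tensorflow-datasets is available")
else:
    print('Optional extra missing; install with: pip install "jax-ai-stack[tfds]"')
```

Checking with `find_spec` avoids paying TensorFlow's import cost just to discover whether the extra is present.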

Hardware support

To install jax-ai-stack with hardware-specific JAX support, include the JAX requirement with
the appropriate hardware extra in the same pip install invocation. For example:

pip install jax-ai-stack "jax[cuda]"  # JAX + AI stack with GPU/CUDA support
pip install jax-ai-stack "jax[tpu]"  # JAX + AI stack with TPU support

For more on the available hardware-specific installation options, refer
to JAX installation.
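
After installing with a hardware extra, one way to confirm that JAX found the accelerator is to query the default backend and the visible devices. The following is a minimal sketch using `jax.default_backend()` and `jax.devices()`. On a machine without a GPU or TPU, or if the hardware-specific install did not take effect, it reports the CPU backend.

```python
import jax
import jax.numpy as jnp

# Platform name of the default backend, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()
devices = jax.devices()
print(f"backend={backend}, devices={devices}")

# New arrays are placed on the default device, and jitted code runs there.
x = jnp.ones((256, 256))
y = jax.jit(lambda a: a @ a)(x)
print(y.shape, y.dtype)
```

If `backend` reports `cpu` on a machine that has an accelerator, the hardware-specific JAX packages were most likely not installed or not detected.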

Apache License 2.0
Created August 14, 2025
Updated August 14, 2025