POC for creating reproducible JAX-based Python environments for MaxText
Some non-obvious notes
- The general idea of this approach is to rely on the already vast and well-tested locked set of dependencies that JAX itself uses for a big chunk of its own testing (including presubmit and continuous jobs). Propagating those dependencies into the MaxText environment ensures that, for a specific JAX version, MaxText runs in an environment as close as possible to the one that same JAX version was tested against when it was released.
- To get an idea of how this all works under the hood, check `build_seed_env.sh`; it is pretty short and mostly self-explanatory.
- CUDA dependencies are pulled as Python wheels, which is the recommended way for JAX to get CUDA; no system-wide CUDA packages are needed except the driver.
- The presence of `libtpu` in an environment makes JAX assume that it must run on TPU, so for any GPU-based workflow `libtpu` must be excluded (hence the `constraints_tpu_only.txt` file).
- CUDA wheels are big and heavy, and installing them for TPU workflows is an unnecessary waste of resources (hence the `constraints_gpu_only.txt` file).
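The `libtpu` and CUDA notes above boil down to a per-target filter over the locked dependency set. The Python sketch below only illustrates that intended outcome; it is not how `build_seed_env.sh` or the constraints files actually work. The helper names `filter_for_target` and `package_name`, and the name prefixes used to recognize CUDA wheels (`nvidia-`, `jax-cuda`), are assumptions for illustration.

```python
from __future__ import annotations

import re

# Assumed name prefixes of the CUDA wheels JAX pulls in (e.g. nvidia-cublas-cu12,
# jax-cuda12-plugin); check them against your actual lock file.
CUDA_PREFIXES = ("nvidia-", "jax-cuda")
# Package whose presence makes JAX assume it must run on TPU.
TPU_PACKAGES = {"libtpu"}

# Leading distribution name of a requirement line, e.g. "jax" in "jax[cuda12]==0.6.1".
_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*")


def normalize(name: str) -> str:
    """Normalize a distribution name as specified in PEP 503."""
    return re.sub(r"[-_.]+", "-", name).lower()


def package_name(line: str) -> str | None:
    """Return the normalized package name, or None for blank, comment and option lines."""
    stripped = line.split("#", 1)[0].strip()
    match = _NAME_RE.match(stripped)
    return normalize(match.group(0)) if match else None


def filter_for_target(lines: list[str], target: str) -> list[str]:
    """Hypothetical helper: drop libtpu for GPU targets and CUDA wheels for TPU targets."""
    if target not in ("gpu", "tpu"):
        raise ValueError(f"unknown target: {target!r}")
    kept = []
    for line in lines:
        name = package_name(line)
        if name is None:
            continue  # blank lines, comments and pip options are not packages
        if target == "gpu" and name in TPU_PACKAGES:
            continue  # libtpu would make JAX assume it must run on TPU
        if target == "tpu" and name.startswith(CUDA_PREFIXES):
            continue  # CUDA wheels are large and unused on TPU
        kept.append(line.strip())
    return kept


if __name__ == "__main__":
    # Illustrative pins only, not taken from a real lock file.
    lock = ["jax==0.6.1", "libtpu==0.0.15", "nvidia-cublas-cu12==12.9.0.13", "numpy==2.2.6"]
    print(filter_for_target(lock, "gpu"))  # ['jax==0.6.1', 'nvidia-cublas-cu12==12.9.0.13', 'numpy==2.2.6']
    print(filter_for_target(lock, "tpu"))  # ['jax==0.6.1', 'libtpu==0.0.15', 'numpy==2.2.6']
```

Normalizing names first (PEP 503) means spellings such as `Jax_CUDA12.Plugin` and `jax-cuda12-plugin` are treated as the same package.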
Quick start
- Install uv.

1. Always start in a directory with a minimal `pyproject.toml` (as in this repo) and no `uv.lock` file present.
2. Run `./build_seed_env.sh`.
3. The script produces `maxtext_requirements_lock_3_12.txt`, which contains the full set of locked MaxText Python dependencies, pinned to the highest versions available at the time you ran it.
4. Use `maxtext_requirements_lock_3_12.txt` to set up any virtual environment or Docker container you want to run MaxText in.
5. Re-running `./build_seed_env.sh` at any later point in time is not reproducible.
6. The script also produces a `pyproject.toml`, which lists the same dependencies as the lock `.txt` file but in lower-bound form.
7. The `pyproject.toml` should be committed to the source tree every time it is updated (see step 1).
8. If `pyproject.toml` is committed, running `uv export --managed-python --locked --no-hashes --no-annotate --resolution=lowest --output-file=maxtext_requirements_lock_3_12.txt` on that commit at any future point in time is reproducible.
9. MaxText may have different `pyproject.toml` files (in different folders), each corresponding to a specific workflow.
10. For any MaxText commit (assuming `pyproject.toml` is checked in), use the command from step 8 to recreate the MaxText Python environment for that commit.
11. To generate `pyproject.toml` and `requirements_lock.txt` for a different Python version, change the `requires-python` line in `pyproject.toml`, pull the matching JAX `requirements_lock_<py_ver>.txt` in `build_seed_env.sh`, and repeat the process from scratch (`pyproject.toml` should have no dependencies and no `uv.lock` file should be present).
12. TBD: Use `pyproject.toml` to generate a MaxText meta wheel with all its dependencies lower-bounded, but not pinned.
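The reproducibility claim in the steps above rests on one property: resolving a lower bound `name>=X` with `--resolution=lowest` selects `X` itself (assuming `X` is still available on the index), and because the lock lists the full dependency set, every package resolves back to its locked version. The Python sketch below shows only the pin-to-lower-bound conversion; `lower_bound_dependencies` and `render_dependencies` are hypothetical helpers, not what `build_seed_env.sh` actually does, and the sketch assumes the lock contains only `name==version` lines (optionally with extras and environment markers), comments, and pip options.

```python
from __future__ import annotations

import json
import re

# An exact pin, optionally with extras and an environment marker, e.g.
# "jax[cuda12]==0.6.1" or "libtpu==0.0.15 ; sys_platform == 'linux'".
_PIN_RE = re.compile(
    r"^(?P<name>[A-Za-z0-9][A-Za-z0-9._-]*(?:\[[^\]]*\])?)"
    r"==(?P<version>[^\s;]+)\s*(?P<marker>;.*)?$"
)


def lower_bound_dependencies(lock_lines: list[str]) -> list[str]:
    """Hypothetical helper: turn exact pins ("name==X") into lower bounds ("name>=X")."""
    deps = []
    for raw in lock_lines:
        line = raw.split("#", 1)[0].strip()  # drop comments and "# via ..." annotations
        if not line or line.startswith("-"):
            continue  # skip blank lines and pip options such as "-e ."
        match = _PIN_RE.match(line)
        if match is None:
            raise ValueError(f"not an exact pin: {raw!r}")
        dep = f"{match.group('name')}>={match.group('version')}"
        if match.group("marker"):
            dep += " " + match.group("marker")  # keep environment markers unchanged
        deps.append(dep)
    return deps


def render_dependencies(deps: list[str]) -> str:
    """Hypothetical helper: render a TOML `dependencies = [...]` array.

    For ASCII requirement strings, json.dumps output is also a valid TOML basic string.
    """
    body = "".join(f"    {json.dumps(dep)},\n" for dep in deps)
    return f"dependencies = [\n{body}]\n"


if __name__ == "__main__":
    # Illustrative lock lines, not taken from a real lock file.
    lock = ["jax==0.6.1", "    # via maxtext", "libtpu==0.0.15 ; sys_platform == 'linux'"]
    print(render_dependencies(lower_bound_dependencies(lock)), end="")
```

For the illustrative input, the printed array contains `"jax>=0.6.1"` and `"libtpu>=0.0.15 ; sys_platform == 'linux'"`. Running the `uv export ... --resolution=lowest` command against a `pyproject.toml` built this way is what turns the bounds back into the original pins.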
Apache License 2.0
Created June 11, 2025
Updated June 11, 2025