diff --git a/requirements-tpu.txt b/requirements-tpu.txt index e627bb9b9e6b5..395e362bba3eb 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -1,5 +1,6 @@ # Common dependencies -r requirements-common.txt +torch jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html flax >= 0.8