diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 624880ffa00e4..e627bb9b9e6b5 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -1,2 +1,5 @@ # Common dependencies -r requirements-common.txt + +jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +flax >= 0.8