# Common dependencies (pulled in from requirements-common.txt).
-r requirements-common.txt

torch

# JAX with the TPU extra. The find-links URL below is where pip looks for the
# libtpu wheels that jax[tpu] depends on.
# NOTE: -f/--find-links is a global pip option; it applies to the whole
# resolution, not only to the requirement on the preceding line.
jax[tpu]
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Lower bound only; no upper pin.
flax >= 0.8