# Common dependencies
-r requirements-common.txt

torch
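# libtpu wheels for jax[tpu] are served from this find-links page.
# pip ignores -f/--find-links when it follows a requirement on the same line.
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
jax[tpu]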
flax >= 0.8
