Add JAX to requirements.txt

This commit is contained in:
Woosuk Kwon 2024-04-01 08:23:59 +00:00
parent 38e3d33a62
commit 6894d3efef

View File

@ -4,6 +4,7 @@ sentencepiece # Required for LLaMA tokenizer.
numpy
torch ~= 2.2.0
torch_xla[tpu] ~= 2.2.0
jax[tpu] # Required for Pallas kernels.
requests
py-cpuinfo
transformers >= 4.39.1 # Required for StarCoder2 & Llava.