mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 12:44:31 +08:00
[TPU] update torch_xla pin (#19231)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent
f8a1a2d108
commit
b61dc5f972
@ -18,9 +18,9 @@ setuptools==78.1.0
|
|||||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||||
torch==2.8.0.dev20250529
|
torch==2.8.0.dev20250605
|
||||||
torchvision==0.22.0.dev20250529
|
torchvision==0.23.0.dev20250605
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250605-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250605-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250605-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@ TOP_KS = [2, 6]
|
|||||||
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
|
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
|
||||||
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
|
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
|
||||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||||
@pytest.mark.parametrize("k", [128, 512, 1024])
|
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||||
|
|||||||
@ -100,7 +100,8 @@ class TPUWorker:
|
|||||||
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
|
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
|
||||||
# fix this. It will be removed after the bug in XLA compiler is fixed.
|
# fix this. It will be removed after the bug in XLA compiler is fixed.
|
||||||
os.environ["LIBTPU_INIT_ARGS"] = (
|
os.environ["LIBTPU_INIT_ARGS"] = (
|
||||||
"--xla_tpu_force_1d_allreduce_at_chunk_count=1")
|
os.environ.get("LIBTPU_INIT_ARGS", "") +
|
||||||
|
" --xla_tpu_force_1d_allreduce_at_chunk_count=1")
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_default_dtype(self.model_config.dtype)
|
torch.set_default_dtype(self.model_config.dtype)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user