diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index 105c45db70df9..75ebbc4ed9403 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -17,10 +17,10 @@ ray[data]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py
index 54eab145efb47..8faa5270b5930 100644
--- a/tests/v1/tpu/test_pallas.py
+++ b/tests/v1/tpu/test_pallas.py
@@ -4,9 +4,7 @@
 from unittest.mock import ANY, patch
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionType
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               NUM_QUERIES_PER_BLOCK,
-                                               PallasAttentionBackendImpl,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackendImpl,
                                                PallasMetadata)
 
@@ -32,8 +30,6 @@ def test_ragged_paged_attention():
         logits_soft_cap=logits_soft_cap,
         attn_type=AttentionType.DECODER,
     )
-    mock_vmem_limit_bytes = 1024
-    attn_impl.vmem_limit_bytes = mock_vmem_limit_bytes
 
     class FakeAttentionLayer:
         _k_scale_float: float
 
@@ -88,9 +84,9 @@ def test_ragged_paged_attention():
         ANY,  # block_tables
         ANY,  # query_start_loc
         ANY,  # num_seqs
-        num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
-        num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-        vmem_limit_bytes=mock_vmem_limit_bytes,
+        num_kv_pages_per_block=None,
+        num_queries_per_block=None,
+        vmem_limit_bytes=None,
         use_kernel=True,
         sm_scale=scale,
         sliding_window=sliding_window,
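
For context, the test change above checks that the backend now forwards None for the kernel tuning knobs (num_kv_pages_per_block, num_queries_per_block, vmem_limit_bytes) so the torch_xla ragged paged attention kernel can pick its own defaults, instead of asserting against the previously hard-coded constants. Below is a minimal, self-contained sketch of that mocking pattern; run_attention and kernel here are hypothetical stand-ins for illustration only, not vLLM or torch_xla APIs.

# Sketch of the mock-based assertion style used by the updated test,
# under the assumption that the backend simply passes None through for
# the tuning parameters. Not vLLM code.
from unittest.mock import ANY, MagicMock


def run_attention(kernel, query, scale, sliding_window):
    # Analogous to the backend after this change: no hard-coded
    # NUM_KV_PAGES_PER_BLOCK / NUM_QUERIES_PER_BLOCK / vmem limit,
    # just None so the kernel chooses its own defaults.
    return kernel(
        query,
        num_kv_pages_per_block=None,
        num_queries_per_block=None,
        vmem_limit_bytes=None,
        use_kernel=True,
        sm_scale=scale,
        sliding_window=sliding_window,
    )


kernel = MagicMock()
run_attention(kernel, query="q", scale=0.125, sliding_window=None)

# ANY matches the positional tensor argument; the keyword arguments are
# asserted to be the pass-through None values.
kernel.assert_called_once_with(
    ANY,  # query
    num_kv_pages_per_block=None,
    num_queries_per_block=None,
    vmem_limit_bytes=None,
    use_kernel=True,
    sm_scale=0.125,
    sliding_window=None,
)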