diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index 7246fc19bfa97..35d5db6c46006 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -17,9 +17,9 @@ ray[data]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250319-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250319-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250319-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250319-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250319-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250319-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index bbbdf50ac0cc7..14d3664db0d64 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> tuple[int, ...]:
-        return (num_blocks, block_size, num_kv_heads, head_size)
+        return (num_blocks, block_size, num_kv_heads * head_size)
 
     @staticmethod
     def swap_blocks(
@@ -142,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
-            kv_cache = ([num_blocks, block_size, num_kv_heads, head_size],
-                        [num_blocks, block_size, num_kv_heads, head_size])
+            kv_cache = ([num_blocks, block_size, num_kv_heads * head_size],
+                        [num_blocks, block_size, num_kv_heads * head_size])
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -157,8 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
-        key = key.view(num_tokens, self.num_kv_heads, self.head_size)
-        value = value.view(num_tokens, self.num_kv_heads, self.head_size)
 
         key_cache, value_cache = kv_cache
         if kv_cache[0].numel() > 0:
@@ -192,10 +190,10 @@ def write_to_kv_cache(
     """ Write the key and values to the KV cache.
 
     Args:
-        key: shape = [num_tokens, num_kv_heads, head_size]
-        value: shape = [num_tokens, num_kv_heads, head_size]
-        k_cache = [num_blocks, block_size, num_kv_heads, head_size]
-        v_cache = [num_blocks, block_size, num_kv_heads, head_size]
+        key: shape = [num_tokens, num_kv_heads * head_size]
+        value: shape = [num_tokens, num_kv_heads * head_size]
+        k_cache = [num_blocks, block_size, num_kv_heads * head_size]
+        v_cache = [num_blocks, block_size, num_kv_heads * head_size]
 
     """
     torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
@@ -203,6 +201,5 @@
 
     key_cache = key_cache.flatten(0, 1)
     value_cache = value_cache.flatten(0, 1)
-    slot_mapping = slot_mapping.flatten()
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)
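For reviewers, a minimal sketch of what the new flattened layout does at the `write_to_kv_cache` step. This is plain PyTorch with made-up sizes, not the PR's code: the XLA-specific `torch.ops.xla.dynamo_set_buffer_donor_` call is omitted, and all tensor names and dimensions below are illustrative. The point is that the cache stays `[num_blocks, block_size, num_kv_heads * head_size]`, key/value stay 2-D, and the write becomes a single `index_copy_` against a flat view, which is also why the now already 1-D `slot_mapping` no longer needs flattening.

```python
import torch

# Illustrative sizes only (not from the PR).
num_blocks, block_size = 4, 16
num_kv_heads, head_size = 2, 8
num_tokens = 5

# Cache in the new layout: head and head_size dims folded together.
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads * head_size)
# New keys arrive already 2-D, so no per-head view() is needed.
key = torch.randn(num_tokens, num_kv_heads * head_size)

# Each entry is a flat index into the num_blocks * block_size slots
# (block_idx * block_size + offset), one per token.
slot_mapping = torch.tensor([0, 1, 2, 16, 17])

# flatten(0, 1) on a contiguous tensor returns a view, so the in-place
# index_copy_ below updates key_cache itself: one scatter, no head dim.
flat_cache = key_cache.flatten(0, 1)          # [num_blocks * block_size, kv]
flat_cache.index_copy_(0, slot_mapping, key)

assert torch.equal(key_cache[0, 0], key[0])   # slot 0  -> block 0, offset 0
assert torch.equal(key_cache[1, 1], key[4])   # slot 17 -> block 1, offset 1
```

Folding `num_kv_heads` and `head_size` into one trailing dimension means each token's write is a single contiguous row copy, which avoids the reshapes the diff removes from `forward`.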