From 70ad3f9e98687776772c446530a091c4e2019e7b Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 4 Apr 2025 17:31:19 -0600 Subject: [PATCH] [Bugfix][TPU] Fix V1 TPU worker for sliding window (#16059) Signed-off-by: Michael Goin --- vllm/v1/worker/tpu_worker.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index bd24072f4c1a1..67902b41b2844 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -18,7 +18,7 @@ from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, +from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import bind_kv_cache @@ -137,7 +137,7 @@ class TPUWorker: kv_caches: dict[str, torch.Tensor] = {} kv_cache_spec = self.model_runner.get_kv_cache_spec() for layer_name, layer_spec in kv_cache_spec.items(): - if isinstance(layer_spec, FullAttentionSpec): + if isinstance(layer_spec, AttentionSpec): dtype = layer_spec.dtype # Use an empty tensor instead of `None`` to force Dynamo to pass @@ -147,7 +147,8 @@ class TPUWorker: device=self.device) kv_caches[layer_name] = tpu_kv_cache else: - raise NotImplementedError + raise NotImplementedError( + f"Unsupported KV cache spec '{type(layer_spec)}'") runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache(