From 0d49483ea97705f531dd42383ecbb2476d7dfa2b Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Fri, 6 Jun 2025 01:20:16 -0700 Subject: [PATCH] [TPU] fix kv cache dtype in model runner (#19244) Signed-off-by: Chengji Yao --- vllm/v1/worker/tpu_model_runner.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 4c8ef0eaa781f..843bc36953b57 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, PlaceholderRange) from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv, + is_pin_memory_available) from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, PallasMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget @@ -138,6 +139,11 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] self._hidden_states_dtype = self.dtype self.is_multimodal_model = model_config.is_multimodal_model @@ -480,7 +486,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=attn_module.dtype, + dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, use_mla=False, ) @@ -489,7 +495,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=attn_module.dtype, + dtype=self.kv_cache_dtype, use_mla=False, ) elif attn_module.attn_type in (AttentionType.ENCODER,