From 863d315c867e9cc455ac63a0372af71c8c63312b Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Tue, 11 Mar 2025 16:12:26 -0700 Subject: [PATCH] [V1][TPU] Pad the block_table.shape[1] so the ragged paged attention can handle it correctly (#14597) --- vllm/v1/worker/tpu_model_runner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 00869467be341..effcac7e7bdef 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -23,7 +23,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available -from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, +from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK, + PallasAttentionBackend, PallasMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -138,8 +139,10 @@ class TPUModelRunner: device="cpu") self.slot_mapping_np = self.slot_mapping_cpu.numpy() + padded_max_num_blocks_per_req = _get_padded_number( + self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK) self.block_table_cpu = torch.zeros( - (self.max_num_tokens, self.max_num_blocks_per_req), + (self.max_num_tokens, padded_max_num_blocks_per_req), dtype=self.input_batch.block_table.get_cpu_tensor().dtype, device="cpu")