[V1][TPU] Pad block_table.shape[1] so that ragged paged attention can handle it correctly (#14597)

iefgnoix 2025-03-11 16:12:26 -07:00 committed by GitHub
parent d374f04a33
commit 863d315c86


@@ -23,7 +23,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
+                                               PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -138,8 +139,10 @@ class TPUModelRunner:
                                             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
+        padded_max_num_blocks_per_req = _get_padded_number(
+            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, self.max_num_blocks_per_req),
+            (self.max_num_tokens, padded_max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")