mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-30 14:07:07 +08:00
[V1][TPU] Pad the block_table.shape[1] so the ragged paged attention can handle correctly (#14597)
This commit is contained in:
parent
d374f04a33
commit
863d315c86
@ -23,7 +23,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
||||
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
|
||||
PallasAttentionBackend,
|
||||
PallasMetadata)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
@ -138,8 +139,10 @@ class TPUModelRunner:
|
||||
device="cpu")
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
|
||||
padded_max_num_blocks_per_req = _get_padded_number(
|
||||
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
|
||||
self.block_table_cpu = torch.zeros(
|
||||
(self.max_num_tokens, self.max_num_blocks_per_req),
|
||||
(self.max_num_tokens, padded_max_num_blocks_per_req),
|
||||
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
|
||||
device="cpu")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user