[TPU][Bugfix] fix kv cache padding (#20048)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
parent 9f0608fc16
commit 2cc2069970
vllm/v1/attention/backends/pallas.py

@@ -48,13 +48,7 @@ class PallasAttentionBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         padded_head_size = cdiv(
             head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
-        num_blocks = num_blocks * head_size // padded_head_size
-        if padded_head_size != head_size:
-            logger.warning_once(
-                "head size is padded to %d, and num_blocks is adjusted to %d"
-                " accordingly", padded_head_size, num_blocks)
-        head_size = padded_head_size
-        return (num_blocks, block_size, num_kv_heads * 2, head_size)
+        return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
 
     @staticmethod
     def swap_blocks(
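With this change, get_kv_cache_shape only rounds the head dimension up to the TPU alignment boundary and no longer rescales num_blocks; the memory-budget adjustment moves to the TPU worker (second file below). A minimal standalone sketch of the padding arithmetic follows; the alignment value of 128 and the local cdiv helper are illustrative stand-ins, not taken from this diff.

# Sketch of the head-size padding done by get_kv_cache_shape above.
# TPU_HEAD_SIZE_ALIGNMENT = 128 is an assumed, illustrative value; cdiv is a
# local stand-in for vllm.utils.cdiv (ceiling division).

def cdiv(a: int, b: int) -> int:
    # Ceiling division without floating point.
    return -(-a // b)

TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed for illustration

def padded_kv_cache_shape(num_blocks: int, block_size: int,
                          num_kv_heads: int,
                          head_size: int) -> tuple[int, ...]:
    # Round head_size up to the nearest multiple of the alignment; after this
    # commit, num_blocks is passed through unchanged.
    padded_head_size = cdiv(head_size,
                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)

# head_size 96 pads to 128, so each cache block now occupies 128/96 of the
# bytes the unpadded shape would have needed; the worker-side budget
# adjustment below compensates for that growth.
print(padded_kv_cache_shape(1000, 16, 8, 96))  # -> (1000, 16, 16, 128)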
vllm/v1/worker/tpu_worker.py

@@ -18,7 +18,8 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -221,7 +222,17 @@ class TPUWorker:
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
         tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
-
+        head_size = self.model_config.get_head_size()
+        if head_size > 0:
+            padded_head_size = cdiv(
+                head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+            if padded_head_size != head_size:
+                logger.warning_once("head size is padded to %d",
+                                    padded_head_size)
+            # We adjust the usable memory size for the KV cache to prevent OOM
+            # errors, even after padding the head_size.
+            tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size //
+                                  padded_head_size)
         return int(tpu_kv_cache_bytes)
 
     def execute_model(
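The worker-side hunk shrinks the usable KV-cache byte budget by the ratio head_size / padded_head_size, so that blocks allocated with the padded head dimension still fit in the memory measured during profiling. A rough worked example with made-up numbers (an 8 GiB post-profiling budget, head size 96, and an assumed alignment of 128):

# Worked example of the budget adjustment above. All numbers are hypothetical;
# TPU_HEAD_SIZE_ALIGNMENT = 128 is assumed for illustration.

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

TPU_HEAD_SIZE_ALIGNMENT = 128          # assumed alignment
tpu_kv_cache_bytes = 8 * 1024**3       # hypothetical 8 GiB left after profiling
head_size = 96                         # hypothetical model head size

padded_head_size = cdiv(head_size,
                        TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
if padded_head_size != head_size:
    # Each padded block is padded_head_size / head_size times larger, so the
    # byte budget is scaled down by the inverse ratio to avoid OOM.
    tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size

print(tpu_kv_cache_bytes)  # 6442450944 bytes, i.e. 6 GiB instead of 8 GiB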