diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 1069578cfd29..e0aeea439794 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -48,13 +48,7 @@ class PallasAttentionBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         padded_head_size = cdiv(
             head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
-        num_blocks = num_blocks * head_size // padded_head_size
-        if padded_head_size != head_size:
-            logger.warning_once(
-                "head size is padded to %d, and num_blocks is adjusted to %d"
-                " accordingly", padded_head_size, num_blocks)
-        head_size = padded_head_size
-        return (num_blocks, block_size, num_kv_heads * 2, head_size)
+        return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
 
     @staticmethod
     def swap_blocks(
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 87af8e476707..a64ce881fe31 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -18,7 +18,8 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -221,7 +222,17 @@ class TPUWorker:
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
         tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
-
+        head_size = self.model_config.get_head_size()
+        if head_size > 0:
+            padded_head_size = cdiv(
+                head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+            if padded_head_size != head_size:
+                logger.warning_once("head size is padded to %d",
+                                    padded_head_size)
+            # We adjust the usable memory size for the KV cache to prevent OOM
+            # errors, even after padding the head_size.
+            tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size //
+                                  padded_head_size)
         return int(tpu_kv_cache_bytes)
 
     def execute_model(
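
For context, here is a minimal sketch of the padding arithmetic the two hunks rely on: `get_kv_cache_shape` now always reports the padded head size, so `determine_available_memory` scales the byte budget down by `head_size / padded_head_size` to compensate. The concrete alignment value, head size, and budget below are assumptions chosen for illustration, not values taken from the diff.

```python
# Illustrative sketch of head-size padding and KV-cache budget scaling.
# TPU_HEAD_SIZE_ALIGNMENT, head_size, and the byte budget are assumed values.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the behavior of vllm.utils.cdiv.
    return -(-a // b)

TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed alignment for this example

head_size = 96  # hypothetical model head size not aligned to 128
padded_head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
assert padded_head_size == 128

# Each KV-cache block is now padded_head_size / head_size times larger than an
# unpadded block, so the worker shrinks the usable byte budget by the inverse
# ratio to keep the allocation within memory.
tpu_kv_cache_bytes = 1_000_000_000  # hypothetical budget before adjustment
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
print(tpu_kv_cache_bytes)  # 750000000
```

With this split, the shape reported by the backend stays simple (always padded), and the memory accounting lives in the worker where the budget is computed.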