[TPU] Optimize kv cache update kernel (#20415)

Signed-off-by: Yifei Teng <tengyifei88@gmail.com>
Authored by Yifei Teng on 2025-07-15 03:56:43 -07:00, committed by GitHub
parent 33d560001e
commit c586b55667
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
3 changed files with 63 additions and 16 deletions

@@ -947,6 +947,13 @@ def next_power_of_2(n) -> int:
     return 1 << (n - 1).bit_length()
 
 
+def prev_power_of_2(n: int) -> int:
+    """The previous power of 2 (inclusive)"""
+    if n <= 0:
+        return 0
+    return 1 << (n.bit_length() - 1)
+
+
 def round_up(x: int, y: int) -> int:
     return ((x + y - 1) // y) * y
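For reference, the behavior of the new prev_power_of_2 helper is illustrated below as a standalone sketch (not part of the diff):

# Standalone copy of the helper added above, with a few example values.
def prev_power_of_2(n: int) -> int:
    """The previous power of 2 (inclusive)."""
    if n <= 0:
        return 0
    return 1 << (n.bit_length() - 1)

assert prev_power_of_2(0) == 0      # non-positive inputs map to 0
assert prev_power_of_2(64) == 64    # inclusive: an exact power of 2 maps to itself
assert prev_power_of_2(100) == 64   # otherwise, round down to the nearest power of 2
assert prev_power_of_2(512) == 512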

@@ -324,3 +324,9 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                                page_size: int,
                                num_slices_per_block: int) -> torch.Tensor:
     return kv_cache
+
+
+def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
+                        kv_cache_dtype: torch.dtype) -> int:
+    """Returns the size in bytes of one page of the KV cache."""
+    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize
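As a worked example of get_page_size_bytes (the shapes below are hypothetical and chosen only for illustration):

import torch

# One KV cache page holds block_size tokens, each contributing
# num_kv_heads * head_size elements at kv_cache_dtype.itemsize bytes apiece.
block_size, num_kv_heads, head_size = 32, 8, 128
page_size_bytes = block_size * num_kv_heads * head_size * torch.bfloat16.itemsize
assert page_size_bytes == 64 * 1024  # 64 KiB per page in bfloat16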

@@ -31,9 +31,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
-                        is_pin_memory_available)
+                        is_pin_memory_available, prev_power_of_2)
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
-                                               PallasMetadata)
+                                               PallasMetadata,
+                                               get_page_size_bytes)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -56,8 +57,6 @@ logger = init_logger(__name__)
 
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
-# Block size used for kv cache updating kernel
-NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
 
 #########################################################
@@ -139,7 +138,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
         if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
+            model_dtype = self.dtype
+            if isinstance(model_dtype, str):
+                self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            else:
+                self.kv_cache_dtype = model_dtype
         else:
             self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
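A minimal sketch of the "auto" cache dtype resolution above, assuming only that STR_DTYPE_TO_TORCH_DTYPE maps dtype strings to torch dtypes; the dict contents and function name below are illustrative, not vLLM's exact code:

import torch

STR_DTYPE_TO_TORCH_DTYPE = {  # illustrative subset for this sketch
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float32,
}

def resolve_kv_cache_dtype(cache_dtype: str, model_dtype) -> torch.dtype:
    # "auto" falls back to the model dtype, which may itself be a string.
    if cache_dtype == "auto":
        if isinstance(model_dtype, str):
            return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
        return model_dtype
    return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]

assert resolve_kv_cache_dtype("auto", "bfloat16") is torch.bfloat16
assert resolve_kv_cache_dtype("auto", torch.float32) is torch.float32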
@@ -192,6 +195,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size
 
+        self._num_slices_per_kv_cache_update_block = \
+            _get_num_slices_per_kv_cache_update_block(get_page_size_bytes(
+                block_size=self.block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                kv_cache_dtype=self.kv_cache_dtype,
+            ))
+
         # Lazy initialization
         self.model: nn.Module  # Set after load_model
         self.kv_caches: list[torch.Tensor] = []
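Putting the two new helpers together, the value cached here can be reproduced with the sketch below; the shapes are hypothetical, and the 32 MiB budget and cap of 64 mirror _get_num_slices_per_kv_cache_update_block further down in this diff:

import torch

def get_page_size_bytes(block_size, num_kv_heads, head_size, kv_cache_dtype):
    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize

def prev_power_of_2(n: int) -> int:
    return 1 << (n.bit_length() - 1) if n > 0 else 0

def num_slices_per_block(page_size_bytes: int,
                         vmem_limit: int = 32 * 1024 * 1024) -> int:
    # Largest power-of-2 number of slices that fits the VMEM budget, capped at 64.
    return min(prev_power_of_2(vmem_limit // page_size_bytes), 64)

# Hypothetical config: 64 KiB pages -> 512 slices fit in 32 MiB -> capped at 64.
page_bytes = get_page_size_bytes(32, 8, 128, torch.bfloat16)  # 65536 bytes
assert num_slices_per_block(page_bytes) == 64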
@@ -719,7 +730,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         num_kv_update_slices = slot_mapping_metadata.shape[0]
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             padded_total_num_scheduled_tokens, self.max_num_reqs,
-            self.block_size)
+            self.block_size, self._num_slices_per_kv_cache_update_block)
         slot_mapping_metadata = np.pad(
             slot_mapping_metadata,
             [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@@ -750,8 +761,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             num_kv_update_slices=torch.tensor([num_kv_update_slices],
                                               dtype=torch.int32,
                                               device=self.device),
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
@@ -1197,7 +1208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
-            num_tokens, self.max_num_reqs, self.block_size)
+            num_tokens, self.max_num_reqs, self.block_size,
+            self._num_slices_per_kv_cache_update_block)
         num_kv_update_slices = torch.tensor([padded_num_slices],
                                             dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros((3, padded_num_slices),
@@ -1220,8 +1232,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
             num_kv_update_slices=num_kv_update_slices,
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
 
         if self.is_multimodal_model:
@@ -1826,19 +1838,41 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]
 
 
-def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
-                                           page_size: int) -> int:
+def _get_padded_num_kv_cache_update_slices(
+        num_tokens: int, max_num_reqs: int, page_size: int,
+        num_slices_per_kv_cache_update_block: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
     recompilation."""
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
     padded_num_slices = (
-        padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
-    ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
-        NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
+        padded_num_slices + num_slices_per_kv_cache_update_block - 1
+    ) // num_slices_per_kv_cache_update_block * \
+        num_slices_per_kv_cache_update_block
     return padded_num_slices
+
+
+def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
+    """Find the optimum number of slices to copy per Pallas program instance.
+
+    Increasing the number of slices copied in one instance of the kernel program
+    will increase HBM bandwidth utilization via more in-flight DMAs.
+    However, it will also use more VMEM, and experimentally, we observed
+    performance regression at 128 slices on v6e, likely due to running
+    out of scalar registers. Thus this function will limit the number of
+    slices to 64.
+    """
+    # Conservative VMEM usage limit: 32 MiB
+    vmem_limit = 32 * 1024 * 1024
+    num_slices_per_block = vmem_limit // page_size_bytes
+    assert num_slices_per_block > 0, "Number of slices should be positive"
+    num_slices_per_block = prev_power_of_2(num_slices_per_block)
+    if num_slices_per_block > 64:
+        num_slices_per_block = 64
+    return num_slices_per_block
 
 
 def replace_set_lora(model):
 
     def _tpu_set_lora(
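A quick sanity check of the padding arithmetic in _get_padded_num_kv_cache_update_slices, using hypothetical sizes (the helper below restates the formula from the diff under an illustrative name):

def padded_num_kv_cache_update_slices(num_tokens, max_num_reqs, page_size,
                                      slices_per_block):
    # Bound the slice count by the token count, then round up to a multiple of
    # the per-block slice count so compiled kernel shapes are reused across steps.
    padded = min(2 * max_num_reqs + num_tokens // page_size, num_tokens)
    return (padded + slices_per_block - 1) // slices_per_block * slices_per_block

# Hypothetical: 8192 tokens, 256 max requests, 32-token pages, 64 slices per block:
# 2 * 256 + 8192 // 32 = 768; min(768, 8192) = 768; already a multiple of 64.
assert padded_num_kv_cache_update_slices(8192, 256, 32, 64) == 768
# With 4000 tokens instead: 2 * 256 + 125 = 637, which rounds up to 640.
assert padded_num_kv_cache_update_slices(4000, 256, 32, 64) == 640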