mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-29 00:54:00 +08:00
[TPU] fix kv_cache_update kernel block size choosing logic (#21007)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent
c11013db8b
commit
85431bd9ad
@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
return kv_cache
|
||||
|
||||
|
||||
# We can move this function to a common utils file if it's also useful for other
# hardware.
def dtype_bits(dtype: torch.dtype) -> int:
    """Return the bit width of a torch dtype.

    Uses torch.finfo / torch.iinfo where available and falls back to
    parsing the dtype's string name for low-bit dtypes (e.g. torch.uint4)
    that finfo/iinfo cannot describe.

    Raises:
        TypeError: if the bit width of ``dtype`` cannot be determined.
    """
    if dtype.is_floating_point:
        try:
            return torch.finfo(dtype).bits
        except TypeError:
            pass
    elif dtype.is_complex:
        if dtype is torch.complex32:
            return 32
        elif dtype is torch.complex64:
            return 64
        elif dtype is torch.complex128:
            return 128
    else:
        try:
            return torch.iinfo(dtype).bits
        # torch.iinfo cannot support int4, int2, bits8...
        except TypeError:
            pass
    # Fall back to the textual name, e.g. "torch.uint4" -> 4.
    # Parse the full digit suffix (the original `int(str_dtype[-1])` only
    # read the last character and would mis-parse multi-digit widths).
    str_dtype = str(dtype)
    for prefix in ("torch.uint", "torch.int"):
        if str_dtype.startswith(prefix):
            suffix = str_dtype[len(prefix):]
            if suffix.isdigit():
                return int(suffix)
    raise TypeError(f"Getting the bit width of {dtype} is not supported")
|
||||
|
||||
|
||||
def get_dtype_packing(dtype: torch.dtype) -> int:
    """Return how many values of ``dtype`` pack into one 32-bit word.

    Raises:
        ValueError: if the dtype's bit width does not evenly divide 32.
    """
    bits = dtype_bits(dtype)
    if 32 % bits != 0:
        # The check requires `bits` to divide 32; the previous message
        # stated the relation backwards and dropped the f-prefix on its
        # second line, emitting a literal "{dtype}".
        raise ValueError(
            f"32 must be divisible by the dtype bit width, but got "
            f"bits={bits}, dtype={dtype}")
    return 32 // bits
|
||||
|
||||
|
||||
def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        kv_cache_dtype: torch.dtype) -> int:
    """Returns the size in bytes of one page of the KV cache.

    Accounts for the TPU memory layout: the head size is padded up to
    TPU_HEAD_SIZE_ALIGNMENT and the combined K/V head count is padded up
    to the per-32-bit-word packing factor of the cache dtype.
    """
    # NOTE: the stale pre-padding formula
    # (block_size * num_kv_heads * head_size * itemsize) that previously
    # returned here made everything below unreachable; removed so the
    # padded computation is actually used.
    padded_head_size = cdiv(head_size,
                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    # K and V are stored together in one page, hence the factor of 2.
    num_combined_kv_heads = num_kv_heads * 2

    # NOTE: for the implicit padding in XLA
    packing = get_dtype_packing(kv_cache_dtype)
    num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing

    kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
    return (block_size * num_combined_kv_heads * padded_head_size *
            kv_cache_dtype_bits // 8)
|
||||
|
||||
@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
|
||||
out of scalar registers. Thus this function will limit the number of
|
||||
slices to 64.
|
||||
"""
|
||||
# Conservative VMEM usage limit: 32 MiB
|
||||
vmem_limit = 32 * 1024 * 1024
|
||||
# The default vmem_limit_bytes of a pallas kernel is 32MB. Here we
|
||||
# calculate num_slices_per_block based on 16MB in case any register spills.
|
||||
vmem_limit = 16 * 1024 * 1024
|
||||
num_slices_per_block = vmem_limit // page_size_bytes
|
||||
assert num_slices_per_block > 0, "Number of slices should be positive"
|
||||
num_slices_per_block = prev_power_of_2(num_slices_per_block)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user