[V1] Do not allocate beyond the max_model_len (#10730)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Authored by Woosuk Kwon on 2024-11-28 00:13:15 -08:00; committed by GitHub
parent d9b4b3f069
commit a79b122400
3 changed files with 44 additions and 18 deletions


@@ -23,7 +23,8 @@ def test_prefill():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
sliding_window=False,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@@ -121,7 +122,8 @@ def test_decode():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
sliding_window=False,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@@ -172,7 +174,8 @@ def test_evict():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
sliding_window=False,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@@ -220,7 +223,8 @@ def test_hash_block_correct_reuse():
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=1,
sliding_window=False,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@@ -256,7 +260,8 @@ def test_computed_blocks_not_evicted():
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=2,
sliding_window=False,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@@ -303,7 +308,8 @@ def test_basic_prefix_caching_disabled():
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=4,
sliding_window=False,
max_model_len=8192,
sliding_window=None,
enable_caching=False,
num_preallocate_tokens=0,
)
@@ -342,7 +348,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=10,
sliding_window=False,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=num_preallocate_tokens,
)
@@ -370,7 +377,8 @@ def test_cache_blocks():
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=5,
sliding_window=False,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
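
For context (not part of the commit): with the values used throughout these tests (block_size=16, max_model_len=8192), the new per-request cap works out to 512 blocks, well above the tiny num_gpu_blocks pools the tests use, so the free-block pool remains the binding limit here. A quick illustrative check:

# Quick arithmetic for the test parameters above (illustrative only).
block_size = 16
max_model_len = 8192
max_num_blocks_per_req = -(-max_model_len // block_size)  # ceiling division
assert max_num_blocks_per_req == 512  # >> num_gpu_blocks (1-10) in these tests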


@@ -17,12 +17,15 @@ class KVCacheManager:
self,
block_size: int,
num_gpu_blocks: int,
max_model_len: int,
sliding_window: Optional[int] = None,
enable_caching: bool = True,
num_preallocate_tokens: int = 64,
) -> None:
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.max_model_len = max_model_len
self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
self.sliding_window = sliding_window
self.enable_caching = enable_caching
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
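
For reference, a minimal sketch (not from the diff) of the ceiling division the new attribute relies on; cdiv here is assumed to behave like the helper in vllm.utils:

def cdiv(a: int, b: int) -> int:
    # Ceiling division: smallest integer >= a / b.
    return -(a // -b)

# max_num_blocks_per_req = cdiv(max_model_len, block_size): a request can
# never hold more than max_model_len tokens, so it can never need more
# KV-cache blocks than this; the count rounds up because a partially filled
# block still occupies a whole block.
assert cdiv(16, 16) == 1
assert cdiv(17, 16) == 2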
@@ -132,7 +135,14 @@ class KVCacheManager:
num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks,
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
# TODO(woosuk): Check and reject requests if
# num_prompt_tokens + max_tokens > max_model_len.
self.max_num_blocks_per_req - len(req_blocks),
)
assert num_new_blocks > 0
new_blocks = self._get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
@@ -212,7 +222,14 @@ class KVCacheManager:
num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks,
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
# TODO(woosuk): Check and reject requests if
# num_prompt_tokens + max_tokens > max_model_len.
self.max_num_blocks_per_req - len(computed_blocks),
)
assert num_new_blocks > 0
# Concatenate the computed block IDs and the new block IDs.
new_blocks = self._get_new_blocks(num_new_blocks)
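
Both call sites above apply the same three-way clamp. The sketch below is illustrative only, with parameter names loosely following the diff rather than vLLM's actual signatures:

def clamp_num_new_blocks(num_needed: int,
                         num_preallocate_blocks: int,
                         num_free_blocks: int,
                         num_blocks_already_held: int,
                         max_num_blocks_per_req: int) -> int:
    # Preallocate a few extra blocks to avoid frequent allocation calls,
    # but never hand out more blocks than are currently free, and never
    # grow a request past its block-table capacity (the block table has
    # shape [..., max_num_blocks_per_req]).
    return min(
        num_needed + num_preallocate_blocks,
        num_free_blocks,
        max_num_blocks_per_req - num_blocks_already_held,
    )

# Example: a request already holding 510 of its 512 allowed blocks can get
# at most 2 more blocks, no matter how many free blocks remain.
assert clamp_num_new_blocks(4, 4, 100, 510, 512) == 2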


@@ -33,22 +33,23 @@ class Scheduler:
# TODO: Support LoRA.
assert lora_config is None, "V1 does not support LoRA yet."
num_gpu_blocks = cache_config.num_gpu_blocks
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
# Create the block space manager.
self.kv_cache_manager = KVCacheManager(
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
self.block_size = self.cache_config.block_size
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = \
self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len
num_gpu_blocks = cache_config.num_gpu_blocks
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
max_model_len=self.max_model_len,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
self.block_size = self.cache_config.block_size
# req_id -> Request
self.requests: Dict[str, Request] = {}
# Priority queues for requests.
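
The TODOs in the manager above point out the remaining gap: requests whose prompt plus requested output can never fit within max_model_len should eventually be rejected before scheduling. A hypothetical sketch of such a check (not part of this commit):

def exceeds_model_len(num_prompt_tokens: int, max_tokens: int,
                      max_model_len: int) -> bool:
    # Such a request can never complete within the model's context window,
    # so it should be rejected up front rather than clamped block by block.
    return num_prompt_tokens + max_tokens > max_model_len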