mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-28 17:37:15 +08:00
[BugFix][Spec Decode] Improve Prefix Caching Logic in Speculative Decoding (#18668)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
b554ab736e
commit
6825d9a998
@@ -174,6 +174,7 @@ class KVCacheManager:
|
||||
num_new_tokens: int,
|
||||
num_new_computed_tokens: int = 0,
|
||||
new_computed_blocks: Optional[KVCacheBlocks] = None,
|
||||
+ num_draft_tokens: int = 0,
|
||||
num_lookahead_tokens: int = 0,
|
||||
delay_cache_blocks: bool = False,
|
||||
) -> Optional[KVCacheBlocks]:
|
||||
@@ -273,7 +274,7 @@ class KVCacheManager:
|
||||
# generated (accepted) tokens.
|
||||
self.single_type_manager.cache_blocks(
|
||||
request, self.req_to_block_hashes[request.request_id],
|
||||
- num_computed_tokens + num_new_tokens - len(request.spec_token_ids))
|
||||
+ num_computed_tokens + num_new_tokens - num_draft_tokens)
|
||||
|
||||
return KVCacheBlocks(new_blocks)
|
||||
|
||||
|
||||
@@ -227,10 +227,15 @@ class Scheduler(SchedulerInterface):
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
+ num_draft_tokens = max(
|
||||
+ num_new_tokens + request.num_computed_tokens -
|
||||
+ request.num_tokens, 0)
|
||||
|
||||
while True:
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens,
|
||||
+ num_draft_tokens=num_draft_tokens,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens)
|
||||
if new_blocks is None:
|
||||
# The request cannot be scheduled.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user