mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-29 04:00:54 +08:00
[BugFix] Honor enable_caching in connector-delayed kvcache load case (#19435)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
d70bc7c029
commit
7e8d97dd3f
@ -381,10 +381,11 @@ class KVCacheManager:
|
||||
self.coordinator.get_blocks(request_id)).get_block_ids()
|
||||
|
||||
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
|
||||
"""Cache the blocks for the request."""
|
||||
block_hashes = self.req_to_block_hashes[request.request_id]
|
||||
self.coordinator.cache_blocks(request, block_hashes,
|
||||
num_computed_tokens)
|
||||
"""Cache the blocks for the request, if enabled."""
|
||||
if self.enable_caching:
|
||||
block_hashes = self.req_to_block_hashes[request.request_id]
|
||||
self.coordinator.cache_blocks(request, block_hashes,
|
||||
num_computed_tokens)
|
||||
|
||||
def create_empty_block_list(self) -> KVCacheBlocks:
|
||||
"""Creates a new KVCacheBlocks instance with no blocks."""
|
||||
|
||||
@ -1015,6 +1015,7 @@ class Scheduler(SchedulerInterface):
|
||||
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
|
||||
if num_computed_tokens == request.num_tokens:
|
||||
num_computed_tokens -= 1
|
||||
# This will cache the blocks iff caching is enabled.
|
||||
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
|
||||
|
||||
# Update the request state for scheduling.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user