[Core] Don't count preempted tokens in prefix cache hit rate (#25787)
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 93ba7648d0
commit 806b292c0e
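Why the change matters, on illustrative numbers (a sketch, not vLLM code): when a preempted request is rescheduled it re-queries its whole prefix and hits on nearly all of it, since its blocks were just cached, so folding those lookups into the headline counters inflates the reported hit rate.

```python
# Illustrative numbers only: one preemption of a 1000-token request
# whose true prefix cache hit rate is 10%.
queries, hits = 1000, 100                      # fresh pass: 10% hit rate

# On resumption the request re-queries all 1000 tokens and hits ~all
# of them, since its blocks were just cached.
queries, hits = queries + 1000, hits + 1000

print(hits / queries)                          # 0.55 -- rate more than quintupled
```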
@@ -27,8 +27,8 @@ class KVCacheBlocks:
     `blocks[i][j]` refers to the i-th kv_cache_group
     and the j-th block of tokens.We don't use block of
     tokens as the outer dimension because it assumes all
     kv_cache_groups have the same number of blocks, which is true for now but
     will be broken if we want to give different block_size to different
     kv_cache_groups in the future.
     """
 
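A toy illustration of the layout described in that docstring (plain lists, not vLLM's KVCacheBlock type): the outer index is the kv_cache_group, the inner index is the block of tokens.

```python
# blocks[i][j]: i indexes the kv_cache_group, j the block of tokens.
blocks = [
    ["g0-blk0", "g0-blk1"],  # kv_cache_group 0
    ["g1-blk0", "g1-blk1"],  # kv_cache_group 1
]
print(blocks[1][0])  # g1-blk0: group 1, block 0
```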
@@ -184,9 +184,17 @@ class KVCacheManager:
 
         if self.log_stats:
             assert self.prefix_cache_stats is not None
-            self.prefix_cache_stats.requests += 1
-            self.prefix_cache_stats.queries += request.num_tokens
-            self.prefix_cache_stats.hits += num_new_computed_tokens
+            if request.num_preemptions > 0:
+                # Previously preempted request
+                self.prefix_cache_stats.preempted_requests += 1
+                self.prefix_cache_stats.preempted_queries += request.num_tokens
+                self.prefix_cache_stats.preempted_hits += (
+                    num_new_computed_tokens)
+            else:
+                # New request
+                self.prefix_cache_stats.requests += 1
+                self.prefix_cache_stats.queries += request.num_tokens
+                self.prefix_cache_stats.hits += num_new_computed_tokens
 
         return KVCacheBlocks(computed_blocks), num_new_computed_tokens
 
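A runnable sketch of the branch above, with a stand-in stats object; the field names mirror the patch, while the `record` helper and the numbers are illustrative.

```python
from dataclasses import dataclass

@dataclass
class PrefixCacheStats:
    requests: int = 0
    queries: int = 0
    hits: int = 0
    preempted_requests: int = 0
    preempted_queries: int = 0
    preempted_hits: int = 0

def record(stats, num_tokens, num_hits, num_preemptions):
    # Route the lookup into the preempted_* bucket iff the request has
    # been preempted before, as the patched branch does.
    if num_preemptions > 0:
        stats.preempted_requests += 1
        stats.preempted_queries += num_tokens
        stats.preempted_hits += num_hits
    else:
        stats.requests += 1
        stats.queries += num_tokens
        stats.hits += num_hits

stats = PrefixCacheStats()
record(stats, num_tokens=1000, num_hits=100, num_preemptions=0)   # fresh
record(stats, num_tokens=1000, num_hits=1000, num_preemptions=1)  # resumed
print(stats.hits / stats.queries)                       # 0.1, unchanged
print(stats.preempted_hits / stats.preempted_queries)   # 1.0, tracked apart
```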
@@ -209,10 +217,10 @@ class KVCacheManager:
                 already been computed locally (i.e. new_computed_blocks).
             num_new_computed_tokens: The number of new computed tokens just
                 hitting the prefix caching, excluding external tokens.
             new_computed_blocks: The cached blocks for the above new computed
                 tokens.
             num_lookahead_tokens: The number of speculative tokens to allocate.
                 This is used by spec decode proposers with kv-cache such
                 as eagle.
             delay_cache_blocks: Whether to skip caching the blocks. This is
                 used by P/D when allocating blocks used in a KV transfer
@@ -365,7 +373,7 @@ class KVCacheManager:
             requests in the current step.
 
         Returns:
             list[int]: The number of common prefix blocks for each kv cache
                 group.
         """
         assert request.status == RequestStatus.RUNNING
@@ -251,46 +251,48 @@ class Scheduler(SchedulerInterface):
                 req_index += 1
                 continue
 
+            # Schedule newly needed KV blocks for the request.
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
                     num_new_tokens,
                     num_lookahead_tokens=self.num_lookahead_tokens)
-                if new_blocks is None:
-                    # The request cannot be scheduled.
-                    # Preempt the lowest-priority request.
-                    if self.policy == SchedulingPolicy.PRIORITY:
-                        preempted_req = max(
-                            self.running,
-                            key=lambda r: (r.priority, r.arrival_time),
-                        )
-                        self.running.remove(preempted_req)
-                        if preempted_req in scheduled_running_reqs:
-                            scheduled_running_reqs.remove(preempted_req)
-                    else:
-                        preempted_req = self.running.pop()
 
-                    self.kv_cache_manager.free(preempted_req)
-                    self.encoder_cache_manager.free(preempted_req)
-                    preempted_req.status = RequestStatus.PREEMPTED
-                    preempted_req.num_computed_tokens = 0
-                    if self.log_stats:
-                        preempted_req.record_event(
-                            EngineCoreEventType.PREEMPTED, scheduled_timestamp)
-
-                    self.waiting.prepend_request(preempted_req)
-                    preempted_reqs.append(preempted_req)
-                    if preempted_req == request:
-                        # No more request to preempt.
-                        can_schedule = False
-                        break
-                else:
+                if new_blocks is not None:
                     # The request can be scheduled.
-                    can_schedule = True
                     break
-            if not can_schedule:
+
+                # The request cannot be scheduled.
+                # Preempt the lowest-priority request.
+                if self.policy == SchedulingPolicy.PRIORITY:
+                    preempted_req = max(
+                        self.running,
+                        key=lambda r: (r.priority, r.arrival_time),
+                    )
+                    self.running.remove(preempted_req)
+                    if preempted_req in scheduled_running_reqs:
+                        scheduled_running_reqs.remove(preempted_req)
+                else:
+                    preempted_req = self.running.pop()
+
+                self.kv_cache_manager.free(preempted_req)
+                self.encoder_cache_manager.free(preempted_req)
+                preempted_req.status = RequestStatus.PREEMPTED
+                preempted_req.num_computed_tokens = 0
+                preempted_req.num_preemptions += 1
+                if self.log_stats:
+                    preempted_req.record_event(EngineCoreEventType.PREEMPTED,
+                                               scheduled_timestamp)
+
+                self.waiting.prepend_request(preempted_req)
+                preempted_reqs.append(preempted_req)
+                if preempted_req == request:
+                    # No more request to preempt. Cannot schedule this request.
+                    break
+
+            if new_blocks is None:
+                # Cannot schedule this request.
                 break
-            assert new_blocks is not None
 
             # Schedule the request.
             scheduled_running_reqs.append(request)
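Note the hunk also drops the old `can_schedule` flag: the loop now exits either because allocation succeeded or because the request itself was preempted, and the `new_blocks is None` check after the loop tells the two apart. A toy sketch of that control flow (stand-in classes, not vLLM's scheduler):

```python
class Req:
    def __init__(self, name):
        self.name = name
        self.num_preemptions = 0

def schedule_one(request, running, can_allocate):
    # Mirrors the refactored loop: break on success, or preempt from the
    # tail of `running` until the request itself gets preempted.
    while True:
        new_blocks = "blocks" if can_allocate() else None
        if new_blocks is not None:
            # The request can be scheduled.
            break
        preempted = running.pop()
        preempted.num_preemptions += 1
        if preempted is request:
            # No more request to preempt. Cannot schedule this request.
            break
    return new_blocks

budget = iter([False, False, True])  # allocation succeeds on the 3rd try
running = [Req("target"), Req("a"), Req("b")]
print(schedule_one(running[0], running, lambda: next(budget)))  # blocks
```

Had allocation kept failing, the loop would eventually pop `request` itself and return None, which is exactly the case the post-loop `if new_blocks is None: break` handles.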
@@ -17,13 +17,19 @@ class PrefixCacheStats:
     """Stores prefix cache hit statistics."""
     # Whether reset_prefix_cache was invoked.
     reset: bool = False
-    # The number of requests in this update.
+    # The number of new requests in this update.
     requests: int = 0
     # The number of queries in these requests. Note that "queries" here
     # means the number of tokens that were queried from the cache.
     queries: int = 0
     # The number of hits in these requests.
     hits: int = 0
+    # The number of previously preempted requests in this update.
+    preempted_requests: int = 0
+    # The `queries` number for preempted requests.
+    preempted_queries: int = 0
+    # The `hits` number for preempted requests.
+    preempted_hits: int = 0
 
 
 @dataclass
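With the split counters, a metrics consumer can report the headline rate over new requests and the resumed-request rate separately; the helper below is hypothetical, not part of vLLM.

```python
from types import SimpleNamespace

def hit_rates(stats):
    # Headline rate over new requests, plus the separate rate for
    # previously preempted (resumed) requests.
    def rate(hits, queries):
        return hits / queries if queries else 0.0
    return (rate(stats.hits, stats.queries),
            rate(stats.preempted_hits, stats.preempted_queries))

stats = SimpleNamespace(hits=100, queries=1000,
                        preempted_hits=950, preempted_queries=1000)
print(hit_rates(stats))  # (0.1, 0.95)
```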
@@ -115,6 +115,9 @@ class Request:
         # indicates that the output is corrupted
         self.num_nans_in_logits = 0
 
+        # The number of requests being preempted by the scheduler
+        self.num_preemptions = 0
+
         self.block_hashes: list[BlockHash] = []
         self.get_hash_new_full_blocks: Optional[Callable[
             [], list[BlockHash]]] = None
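The new counter is what ties the hunks together: the scheduler increments it on preemption, and `get_computed_blocks` reads it to pick the stats bucket. A compact lifecycle sketch with a stand-in `Request`:

```python
# Stand-in Request, illustrating only the counter's lifecycle.
class Request:
    def __init__(self):
        self.num_preemptions = 0

req = Request()
assert req.num_preemptions == 0   # fresh request -> headline counters

req.num_preemptions += 1          # scheduler preempts it (Scheduler hunk)

# On the next lookup, stats go to the preempted_* bucket
# (KVCacheManager hunk).
print(req.num_preemptions > 0)    # True
```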