fix(v1/kv_cache): resolve async KV transfer bug in cascade attention (#23485)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
parent 067da2d1df
commit cd9890544b
@@ -148,27 +148,22 @@ class KVCacheCoordinator(ABC):
         for manager in self.single_type_managers:
             manager.free(request_id)
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> list[int]:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
         """
-        Get the number of common prefix blocks for all requests in the RUNNING
-        state for each kv cache group.
+        Get the number of common prefix blocks for all requests with allocated
+        KV cache for each kv cache group.
 
         Args:
-            request_id: The request ID.
-            num_running_requests: The total number of requests in the RUNNING
-                state.
+            running_request_id: The request ID of any running request, used to
+                identify the common prefix blocks.
 
         Returns:
-            list[int]: The number of common prefix blocks for all requests in
-            the RUNNING state for each kv cache group.
+            list[int]: The number of common prefix blocks for each kv cache group.
         """
-        num_blocks_per_group = [
-            manager.get_num_common_prefix_blocks(request_id, num_running_requests)
+        return [
+            manager.get_num_common_prefix_blocks(running_request_id)
             for manager in self.single_type_managers
         ]
-        return num_blocks_per_group
 
     def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         """
@@ -226,9 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
         )
         self.num_single_type_manager = len(self.single_type_managers)
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> list[int]:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
         return [0] * self.num_single_type_manager
 
     def find_longest_cache_hit(
@@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
 from vllm.v1.core.kv_cache_utils import KVCacheBlock
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request
 
 logger = init_logger(__name__)
 
@@ -344,49 +344,39 @@ class KVCacheManager:
             self.prefix_cache_stats.reset = True
         return True
 
-    def get_num_common_prefix_blocks(
-        self,
-        request: Request,
-        num_running_requests: int,
-    ) -> list[int]:
-        """Calculate the number of common prefix blocks shared by all requests
-        in the RUNNING state for each kv cache group.
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
+        """Calculate the number of common prefix blocks for each kv cache group.
 
-        The function determines this by selecting any request and iterating
-        through its blocks. A block is considered a common prefix block if its
-        `ref_cnt` equals the total number of requests in the RUNNING state.
+        The function selects a running request and iterates through its blocks.
+        A block is considered a common prefix block if ALL requests with
+        allocated KV cache share it (i.e., ref_cnt equals the number of entries
+        in req_to_blocks).
 
-        NOTE(woosuk): The number of requests in the RUNNING state is **greater
+        NOTE(woosuk): The number of requests with allocated KV cache is **greater
         than or equal to** the number of requests scheduled in the current step.
-        This is because the RUNNING state only indicates that:
+        This is because having allocated KV cache only indicates that:
         1. The request has not yet finished, and
         2. The request holds its blocks unfreed.
 
-        While all scheduled requests must be in the RUNNING state, the inverse
-        is not necessarily true. There may be RUNNING requests that are not
-        scheduled in the current step.
+        While all scheduled requests must have allocated KV cache, the inverse
+        is not necessarily true. There may be requests with allocated KV cache
+        that are not scheduled in the current step.
 
         This can result in an edge case where the number of common prefix blocks
         is 0, even though all scheduled requests share a common prefix. This
-        occurs because there may be unscheduled RUNNING requests that do not
-        share the common prefix. Currently, this case cannot be easily detected,
-        so the function returns 0 in such cases.
+        occurs because there may be unscheduled requests that do not share the
+        common prefix. Currently, this case cannot be easily detected, so the
+        function returns 0 in such cases.
 
         Args:
-            request: Any request in the RUNNING state, used to identify the
-                common prefix blocks.
-            num_running_requests: The total number of requests in the RUNNING
-                state. This can be different from the number of scheduled
-                requests in the current step.
+            running_request_id: The request ID of any running request, used to
+                identify the common prefix blocks.
 
         Returns:
             list[int]: The number of common prefix blocks for each kv cache
             group.
         """
-        assert request.status == RequestStatus.RUNNING
-        return self.coordinator.get_num_common_prefix_blocks(
-            request.request_id, num_running_requests
-        )
+        return self.coordinator.get_num_common_prefix_blocks(running_request_id)
 
     def take_events(self) -> list[KVCacheEvent]:
         """Take the KV cache events from the block pool.
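The docstring above describes the counting rule only in prose. The following is a minimal, self-contained sketch (simplified and assumed, not vLLM's actual classes; Block, alloc, and count_common_prefix are hypothetical names) of counting common prefix blocks against the number of requests with allocated KV cache, including the documented edge case where one such request does not share the prefix and the count drops to 0.

# Minimal sketch of ref_cnt-based common-prefix counting; simplified and
# hypothetical, not vLLM's implementation.
from dataclasses import dataclass


@dataclass
class Block:
    block_id: int
    ref_cnt: int = 0  # number of requests currently holding this block


req_to_blocks: dict[str, list[Block]] = {}


def alloc(request_id: str, blocks: list[Block]) -> None:
    # Record that the request holds these blocks and bump their ref counts.
    for block in blocks:
        block.ref_cnt += 1
    req_to_blocks[request_id] = blocks


def count_common_prefix(running_request_id: str) -> int:
    # A block is a common prefix block iff every request with allocated
    # KV cache holds it, i.e. ref_cnt == len(req_to_blocks).
    num_common = 0
    for block in req_to_blocks[running_request_id]:
        if block.ref_cnt == len(req_to_blocks):
            num_common += 1
        else:
            break
    return num_common


prefix = [Block(0), Block(1)]
alloc("req-a", prefix + [Block(2)])
alloc("req-b", prefix + [Block(3)])
print(count_common_prefix("req-a"))  # 2: both requests share blocks 0 and 1

# Edge case from the docstring: a request that holds KV cache but does not
# share the prefix drives the count to 0, even though the other requests do.
alloc("req-c", [Block(4)])
print(count_common_prefix("req-a"))  # 0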
@@ -597,7 +597,7 @@ class Scheduler(SchedulerInterface):
             any_request = self.running[0]
             num_common_prefix_blocks = (
                 self.kv_cache_manager.get_num_common_prefix_blocks(
-                    any_request, len(self.running)
+                    any_request.request_id
                 )
             )
 
@@ -182,21 +182,17 @@ class SingleTypeKVCacheManager(ABC):
         self.num_cached_block.pop(request_id, None)
 
     @abstractmethod
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
-        Get the number of common prefix blocks for all requests in the RUNNING
-        state.
+        Get the number of common prefix blocks for all requests with allocated
+        KV cache.
 
         Args:
-            request_id: The request ID.
-            num_running_requests: The total number of requests in the RUNNING
-                state.
+            running_request_id: The request ID.
 
         Returns:
-            The number of common prefix blocks for all requests in the RUNNING
-            state.
+            The number of common prefix blocks for all requests with allocated
+            KV cache.
         """
 
         raise NotImplementedError
@@ -302,13 +298,11 @@ class FullAttentionManager(SingleTypeKVCacheManager):
         # No need to remove blocks for full attention.
         pass
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
-        blocks = self.req_to_blocks[request_id]
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
+        blocks = self.req_to_blocks[running_request_id]
         num_common_blocks = 0
         for block in blocks:
-            if block.ref_cnt == num_running_requests:
+            if block.ref_cnt == len(self.req_to_blocks):
                 num_common_blocks += 1
             else:
                 break
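To make the changed comparison concrete, here is a hedged sketch with hypothetical numbers (not taken from vLLM) contrasting the old predicate ref_cnt == num_running_requests with the new ref_cnt == len(self.req_to_blocks) when a request, for example one whose KV cache is still arriving via an async transfer, holds blocks without being counted as running.

# Hypothetical numbers only, to contrast the old and new checks; these names
# are illustrative, not vLLM's data structures.
ref_cnt = 3                # three requests hold the candidate prefix block
num_running_requests = 2   # but only two of them are in the RUNNING list
num_requests_with_kv = 3   # len(req_to_blocks): all three hold KV cache

old_is_common = ref_cnt == num_running_requests  # False: prefix block missed
new_is_common = ref_cnt == num_requests_with_kv  # True: counted correctly
print(old_is_common, new_is_common)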
@@ -408,9 +402,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
                 blocks[i] = self._null_block
         self.block_pool.free_blocks(removed_blocks)
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
         So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -544,9 +536,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
                 blocks[i] = self._null_block
         self.block_pool.free_blocks(removed_blocks)
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         cascade attention is not supported by chunked local attention.
         """
@@ -596,9 +586,7 @@ class MambaManager(SingleTypeKVCacheManager):
         # (for which find_longest_cache_hit returns block_pool.null_block)
         pass
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         cascade attention is not supported by mamba
         """
@@ -648,9 +636,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
         # requests, so this method is not relevant.
         raise ValueError("Should not be called as prefix caching is disabled.")
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         # Cross-attention blocks contain request-specific encoder states
         # and are not shared between different requests
         return 0