fix(v1/kv_cache): resolve async KV transfer bug in cascade attention (#23485)

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
This commit is contained in:
Ayush Satyam 2025-10-08 10:16:33 +05:30 committed by GitHub
parent 067da2d1df
commit cd9890544b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 41 additions and 72 deletions

View File

@@ -148,27 +148,22 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers:
manager.free(request_id)
def get_num_common_prefix_blocks(
self, request_id: str, num_running_requests: int
) -> list[int]:
def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
"""
Get the number of common prefix blocks for all requests in the RUNNING
state for each kv cache group.
Get the number of common prefix blocks for all requests with allocated
KV cache for each kv cache group.
Args:
request_id: The request ID.
num_running_requests: The total number of requests in the RUNNING
state.
running_request_id: The request ID of any running request, used to
identify the common prefix blocks.
Returns:
list[int]: The number of common prefix blocks for all requests in
the RUNNING state for each kv cache group.
list[int]: The number of common prefix blocks for each kv cache group.
"""
num_blocks_per_group = [
manager.get_num_common_prefix_blocks(request_id, num_running_requests)
return [
manager.get_num_common_prefix_blocks(running_request_id)
for manager in self.single_type_managers
]
return num_blocks_per_group
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
"""
@@ -226,9 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
)
self.num_single_type_manager = len(self.single_type_managers)
def get_num_common_prefix_blocks(
self, request_id: str, num_running_requests: int
) -> list[int]:
def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
return [0] * self.num_single_type_manager
def find_longest_cache_hit(

View File

@@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
from vllm.v1.request import Request
logger = init_logger(__name__)
@@ -344,49 +344,39 @@ class KVCacheManager:
self.prefix_cache_stats.reset = True
return True
def get_num_common_prefix_blocks(
self,
request: Request,
num_running_requests: int,
) -> list[int]:
"""Calculate the number of common prefix blocks shared by all requests
in the RUNNING state for each kv cache group.
def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
"""Calculate the number of common prefix blocks for each kv cache group.
The function determines this by selecting any request and iterating
through its blocks. A block is considered a common prefix block if its
`ref_cnt` equals the total number of requests in the RUNNING state.
The function selects a running request and iterates through its blocks.
A block is considered a common prefix block if ALL requests with
allocated KV cache share it (i.e., ref_cnt equals the number of entries
in req_to_blocks).
NOTE(woosuk): The number of requests in the RUNNING state is **greater
NOTE(woosuk): The number of requests with allocated KV cache is **greater
than or equal to** the number of requests scheduled in the current step.
This is because the RUNNING state only indicates that:
This is because having allocated KV cache only indicates that:
1. The request has not yet finished, and
2. The request holds its blocks unfreed.
While all scheduled requests must be in the RUNNING state, the inverse
is not necessarily true. There may be RUNNING requests that are not
scheduled in the current step.
While all scheduled requests must have allocated KV cache, the inverse
is not necessarily true. There may be requests with allocated KV cache
that are not scheduled in the current step.
This can result in an edge case where the number of common prefix blocks
is 0, even though all scheduled requests share a common prefix. This
occurs because there may be unscheduled RUNNING requests that do not
share the common prefix. Currently, this case cannot be easily detected,
so the function returns 0 in such cases.
occurs because there may be unscheduled requests that do not share the
common prefix. Currently, this case cannot be easily detected, so the
function returns 0 in such cases.
Args:
request: Any request in the RUNNING state, used to identify the
common prefix blocks.
num_running_requests: The total number of requests in the RUNNING
state. This can be different from the number of scheduled
requests in the current step.
running_request_id: The request ID of any running request, used to
identify the common prefix blocks.
Returns:
list[int]: The number of common prefix blocks for each kv cache
group.
"""
assert request.status == RequestStatus.RUNNING
return self.coordinator.get_num_common_prefix_blocks(
request.request_id, num_running_requests
)
return self.coordinator.get_num_common_prefix_blocks(running_request_id)
def take_events(self) -> list[KVCacheEvent]:
"""Take the KV cache events from the block pool.

View File

@@ -597,7 +597,7 @@ class Scheduler(SchedulerInterface):
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running)
any_request.request_id
)
)

View File

@@ -182,21 +182,17 @@ class SingleTypeKVCacheManager(ABC):
self.num_cached_block.pop(request_id, None)
@abstractmethod
def get_num_common_prefix_blocks(
self, request_id: str, num_running_requests: int
) -> int:
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
Get the number of common prefix blocks for all requests in the RUNNING
state.
Get the number of common prefix blocks for all requests with allocated
KV cache.
Args:
request_id: The request ID.
num_running_requests: The total number of requests in the RUNNING
state.
running_request_id: The request ID.
Returns:
The number of common prefix blocks for all requests in the RUNNING
state.
The number of common prefix blocks for all requests with allocated
KV cache.
"""
raise NotImplementedError
@@ -302,13 +298,11 @@ class FullAttentionManager(SingleTypeKVCacheManager):
# No need to remove blocks for full attention.
pass
def get_num_common_prefix_blocks(
self, request_id: str, num_running_requests: int
) -> int:
blocks = self.req_to_blocks[request_id]
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
blocks = self.req_to_blocks[running_request_id]
num_common_blocks = 0
for block in blocks:
if block.ref_cnt == num_running_requests:
if block.ref_cnt == len(self.req_to_blocks):
num_common_blocks += 1
else:
break
@@ -408,9 +402,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
def get_num_common_prefix_blocks(
self, request_id: str, num_running_requests: int
) -> int:
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -544,9 +536,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
def get_num_common_prefix_blocks(
self, request_id: str, num_running_requests: int
) -> int:
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
cascade attention is not supported by chunked local attention.
"""
@@ -596,9 +586,7 @@ class MambaManager(SingleTypeKVCacheManager):
# (for which find_longest_cache_hit returns block_pool.null_block)
pass
def get_num_common_prefix_blocks(
self, request_id: str, num_running_requests: int
) -> int:
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
cascade attention is not supported by mamba
"""
@@ -648,9 +636,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
# requests, so this method is not relevant.
raise ValueError("Should not be called as prefix caching is disabled.")
def get_num_common_prefix_blocks(
self, request_id: str, num_running_requests: int
) -> int:
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
# Cross-attention blocks contain request-specific encoder states
# and are not shared between different requests
return 0