fix(v1/kv_cache): resolve async KV transfer bug in cascade attention (#23485)

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
This commit is contained in:
Ayush Satyam 2025-10-08 10:16:33 +05:30 committed by GitHub
parent 067da2d1df
commit cd9890544b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 41 additions and 72 deletions

View File

@ -148,27 +148,22 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers: for manager in self.single_type_managers:
manager.free(request_id) manager.free(request_id)
def get_num_common_prefix_blocks( def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
self, request_id: str, num_running_requests: int
) -> list[int]:
""" """
Get the number of common prefix blocks for all requests in the RUNNING Get the number of common prefix blocks for all requests with allocated
state for each kv cache group. KV cache for each kv cache group.
Args: Args:
request_id: The request ID. running_request_id: The request ID of any running request, used to
num_running_requests: The total number of requests in the RUNNING identify the common prefix blocks.
state.
Returns: Returns:
list[int]: The number of common prefix blocks for all requests in list[int]: The number of common prefix blocks for each kv cache group.
the RUNNING state for each kv cache group.
""" """
num_blocks_per_group = [ return [
manager.get_num_common_prefix_blocks(request_id, num_running_requests) manager.get_num_common_prefix_blocks(running_request_id)
for manager in self.single_type_managers for manager in self.single_type_managers
] ]
return num_blocks_per_group
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
""" """
@ -226,9 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
) )
self.num_single_type_manager = len(self.single_type_managers) self.num_single_type_manager = len(self.single_type_managers)
def get_num_common_prefix_blocks( def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
self, request_id: str, num_running_requests: int
) -> list[int]:
return [0] * self.num_single_type_manager return [0] * self.num_single_type_manager
def find_longest_cache_hit( def find_longest_cache_hit(

View File

@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import KVCacheBlock from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
@ -344,49 +344,39 @@ class KVCacheManager:
self.prefix_cache_stats.reset = True self.prefix_cache_stats.reset = True
return True return True
def get_num_common_prefix_blocks( def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
self, """Calculate the number of common prefix blocks for each kv cache group.
request: Request,
num_running_requests: int,
) -> list[int]:
"""Calculate the number of common prefix blocks shared by all requests
in the RUNNING state for each kv cache group.
The function determines this by selecting any request and iterating The function selects a running request and iterates through its blocks.
through its blocks. A block is considered a common prefix block if its A block is considered a common prefix block if ALL requests with
`ref_cnt` equals the total number of requests in the RUNNING state. allocated KV cache share it (i.e., ref_cnt equals the number of entries
in req_to_blocks).
NOTE(woosuk): The number of requests in the RUNNING state is **greater NOTE(woosuk): The number of requests with allocated KV cache is **greater
than or equal to** the number of requests scheduled in the current step. than or equal to** the number of requests scheduled in the current step.
This is because the RUNNING state only indicates that: This is because having allocated KV cache only indicates that:
1. The request has not yet finished, and 1. The request has not yet finished, and
2. The request holds its blocks unfreed. 2. The request holds its blocks unfreed.
While all scheduled requests must be in the RUNNING state, the inverse While all scheduled requests must have allocated KV cache, the inverse
is not necessarily true. There may be RUNNING requests that are not is not necessarily true. There may be requests with allocated KV cache
scheduled in the current step. that are not scheduled in the current step.
This can result in an edge case where the number of common prefix blocks This can result in an edge case where the number of common prefix blocks
is 0, even though all scheduled requests share a common prefix. This is 0, even though all scheduled requests share a common prefix. This
occurs because there may be unscheduled RUNNING requests that do not occurs because there may be unscheduled requests that do not share the
share the common prefix. Currently, this case cannot be easily detected, common prefix. Currently, this case cannot be easily detected, so the
so the function returns 0 in such cases. function returns 0 in such cases.
Args: Args:
request: Any request in the RUNNING state, used to identify the running_request_id: The request ID of any running request, used to
common prefix blocks. identify the common prefix blocks.
num_running_requests: The total number of requests in the RUNNING
state. This can be different from the number of scheduled
requests in the current step.
Returns: Returns:
list[int]: The number of common prefix blocks for each kv cache list[int]: The number of common prefix blocks for each kv cache
group. group.
""" """
assert request.status == RequestStatus.RUNNING return self.coordinator.get_num_common_prefix_blocks(running_request_id)
return self.coordinator.get_num_common_prefix_blocks(
request.request_id, num_running_requests
)
def take_events(self) -> list[KVCacheEvent]: def take_events(self) -> list[KVCacheEvent]:
"""Take the KV cache events from the block pool. """Take the KV cache events from the block pool.

View File

@ -597,7 +597,7 @@ class Scheduler(SchedulerInterface):
any_request = self.running[0] any_request = self.running[0]
num_common_prefix_blocks = ( num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks( self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running) any_request.request_id
) )
) )

View File

@ -182,21 +182,17 @@ class SingleTypeKVCacheManager(ABC):
self.num_cached_block.pop(request_id, None) self.num_cached_block.pop(request_id, None)
@abstractmethod @abstractmethod
def get_num_common_prefix_blocks( def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
self, request_id: str, num_running_requests: int
) -> int:
""" """
Get the number of common prefix blocks for all requests in the RUNNING Get the number of common prefix blocks for all requests with allocated
state. KV cache.
Args: Args:
request_id: The request ID. running_request_id: The request ID.
num_running_requests: The total number of requests in the RUNNING
state.
Returns: Returns:
The number of common prefix blocks for all requests in the RUNNING The number of common prefix blocks for all requests with allocated
state. KV cache.
""" """
raise NotImplementedError raise NotImplementedError
@ -302,13 +298,11 @@ class FullAttentionManager(SingleTypeKVCacheManager):
# No need to remove blocks for full attention. # No need to remove blocks for full attention.
pass pass
def get_num_common_prefix_blocks( def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
self, request_id: str, num_running_requests: int blocks = self.req_to_blocks[running_request_id]
) -> int:
blocks = self.req_to_blocks[request_id]
num_common_blocks = 0 num_common_blocks = 0
for block in blocks: for block in blocks:
if block.ref_cnt == num_running_requests: if block.ref_cnt == len(self.req_to_blocks):
num_common_blocks += 1 num_common_blocks += 1
else: else:
break break
@ -408,9 +402,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
blocks[i] = self._null_block blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks) self.block_pool.free_blocks(removed_blocks)
def get_num_common_prefix_blocks( def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
self, request_id: str, num_running_requests: int
) -> int:
""" """
NOTE(Chen): The prefix blocks are null blocks for sliding window layers. NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
So it's not correct to count ref_cnt like FullAttentionManager. Return So it's not correct to count ref_cnt like FullAttentionManager. Return
@ -544,9 +536,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
blocks[i] = self._null_block blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks) self.block_pool.free_blocks(removed_blocks)
def get_num_common_prefix_blocks( def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
self, request_id: str, num_running_requests: int
) -> int:
""" """
cascade attention is not supported by chunked local attention. cascade attention is not supported by chunked local attention.
""" """
@ -596,9 +586,7 @@ class MambaManager(SingleTypeKVCacheManager):
# (for which find_longest_cache_hit returns block_pool.null_block) # (for which find_longest_cache_hit returns block_pool.null_block)
pass pass
def get_num_common_prefix_blocks( def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
self, request_id: str, num_running_requests: int
) -> int:
""" """
cascade attention is not supported by mamba cascade attention is not supported by mamba
""" """
@ -648,9 +636,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
# requests, so this method is not relevant. # requests, so this method is not relevant.
raise ValueError("Should not be called as prefix caching is disabled.") raise ValueError("Should not be called as prefix caching is disabled.")
def get_num_common_prefix_blocks( def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
self, request_id: str, num_running_requests: int
) -> int:
# Cross-attention blocks contain request-specific encoder states # Cross-attention blocks contain request-specific encoder states
# and are not shared between different requests # and are not shared between different requests
return 0 return 0