fix(v1/kv_cache): resolve async KV transfer bug in cascade attention (#23485)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
parent 067da2d1df
commit cd9890544b
@@ -148,27 +148,22 @@ class KVCacheCoordinator(ABC):
         for manager in self.single_type_managers:
             manager.free(request_id)
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> list[int]:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
         """
-        Get the number of common prefix blocks for all requests in the RUNNING
-        state for each kv cache group.
+        Get the number of common prefix blocks for all requests with allocated
+        KV cache for each kv cache group.
 
         Args:
-            request_id: The request ID.
-            num_running_requests: The total number of requests in the RUNNING
-                state.
+            running_request_id: The request ID of any running request, used to
+                identify the common prefix blocks.
 
         Returns:
-            list[int]: The number of common prefix blocks for all requests in
-                the RUNNING state for each kv cache group.
+            list[int]: The number of common prefix blocks for each kv cache group.
         """
-        num_blocks_per_group = [
-            manager.get_num_common_prefix_blocks(request_id, num_running_requests)
+        return [
+            manager.get_num_common_prefix_blocks(running_request_id)
             for manager in self.single_type_managers
         ]
-        return num_blocks_per_group
 
     def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         """
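For readers skimming the diff: the coordinator change above only threads the new single-argument signature through to the per-group managers and returns one count per kv cache group. Below is a minimal sketch of that fan-out using hypothetical stubs (the _StubManager/_StubCoordinator names are illustrative and not part of vLLM):

# Hypothetical stubs illustrating the coordinator's per-group fan-out.
from dataclasses import dataclass


@dataclass
class _StubManager:
    # Pretend per-group manager that already knows its common-prefix count.
    common_prefix_blocks: int

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        return self.common_prefix_blocks


class _StubCoordinator:
    def __init__(self, managers: list[_StubManager]) -> None:
        self.single_type_managers = managers

    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
        # One entry per kv cache group, in manager order.
        return [
            m.get_num_common_prefix_blocks(running_request_id)
            for m in self.single_type_managers
        ]


coordinator = _StubCoordinator([_StubManager(3), _StubManager(0)])
assert coordinator.get_num_common_prefix_blocks("req-0") == [3, 0]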
@@ -226,9 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
         )
         self.num_single_type_manager = len(self.single_type_managers)
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> list[int]:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
         return [0] * self.num_single_type_manager
 
     def find_longest_cache_hit(
@@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
 from vllm.v1.core.kv_cache_utils import KVCacheBlock
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request
 
 logger = init_logger(__name__)
 
@@ -344,49 +344,39 @@ class KVCacheManager:
         self.prefix_cache_stats.reset = True
         return True
 
-    def get_num_common_prefix_blocks(
-        self,
-        request: Request,
-        num_running_requests: int,
-    ) -> list[int]:
-        """Calculate the number of common prefix blocks shared by all requests
-        in the RUNNING state for each kv cache group.
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
+        """Calculate the number of common prefix blocks for each kv cache group.
 
-        The function determines this by selecting any request and iterating
-        through its blocks. A block is considered a common prefix block if its
-        `ref_cnt` equals the total number of requests in the RUNNING state.
+        The function selects a running request and iterates through its blocks.
+        A block is considered a common prefix block if ALL requests with
+        allocated KV cache share it (i.e., ref_cnt equals the number of entries
+        in req_to_blocks).
 
-        NOTE(woosuk): The number of requests in the RUNNING state is **greater
+        NOTE(woosuk): The number of requests with allocated KV cache is **greater
         than or equal to** the number of requests scheduled in the current step.
-        This is because the RUNNING state only indicates that:
+        This is because having allocated KV cache only indicates that:
         1. The request has not yet finished, and
         2. The request holds its blocks unfreed.
 
-        While all scheduled requests must be in the RUNNING state, the inverse
-        is not necessarily true. There may be RUNNING requests that are not
-        scheduled in the current step.
+        While all scheduled requests must have allocated KV cache, the inverse
+        is not necessarily true. There may be requests with allocated KV cache
+        that are not scheduled in the current step.
 
         This can result in an edge case where the number of common prefix blocks
         is 0, even though all scheduled requests share a common prefix. This
-        occurs because there may be unscheduled RUNNING requests that do not
-        share the common prefix. Currently, this case cannot be easily detected,
-        so the function returns 0 in such cases.
+        occurs because there may be unscheduled requests that do not share the
+        common prefix. Currently, this case cannot be easily detected, so the
+        function returns 0 in such cases.
 
         Args:
-            request: Any request in the RUNNING state, used to identify the
-                common prefix blocks.
-            num_running_requests: The total number of requests in the RUNNING
-                state. This can be different from the number of scheduled
-                requests in the current step.
+            running_request_id: The request ID of any running request, used to
+                identify the common prefix blocks.
 
         Returns:
            list[int]: The number of common prefix blocks for each kv cache
                group.
        """
-        assert request.status == RequestStatus.RUNNING
-        return self.coordinator.get_num_common_prefix_blocks(
-            request.request_id, num_running_requests
-        )
+        return self.coordinator.get_num_common_prefix_blocks(running_request_id)
 
     def take_events(self) -> list[KVCacheEvent]:
         """Take the KV cache events from the block pool.
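The rewritten docstring above keys the common-prefix test to requests with allocated KV cache rather than requests in the RUNNING state. Below is a standalone sketch of why the denominator matters, built on an assumed scenario with simplified Block/req_to_blocks stand-ins (not vLLM internals): one request holds its blocks, e.g. while its KV cache is still arriving asynchronously, without being counted as running.

# Illustrative sketch, not vLLM code: a request can hold KV blocks without
# being in the scheduler's running list, so counting against the RUNNING
# total and counting against requests with allocated KV cache can disagree.
from dataclasses import dataclass


@dataclass
class Block:
    ref_cnt: int = 0


shared = Block()
req_to_blocks = {}          # request_id -> list of allocated blocks
for req_id in ("A", "B", "C"):
    req_to_blocks[req_id] = [shared]
    shared.ref_cnt += 1     # every request with allocated KV holds a reference

running = ["A", "B"]        # "C" holds blocks but is not counted as RUNNING

# Old-style check: compares against the RUNNING count and misses the shared block.
assert not (shared.ref_cnt == len(running))
# New-style check: compares against the number of requests with allocated KV cache.
assert shared.ref_cnt == len(req_to_blocks)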
@@ -597,7 +597,7 @@ class Scheduler(SchedulerInterface):
             any_request = self.running[0]
             num_common_prefix_blocks = (
                 self.kv_cache_manager.get_num_common_prefix_blocks(
-                    any_request, len(self.running)
+                    any_request.request_id
                 )
             )
 
@@ -182,21 +182,17 @@ class SingleTypeKVCacheManager(ABC):
         self.num_cached_block.pop(request_id, None)
 
     @abstractmethod
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
-        Get the number of common prefix blocks for all requests in the RUNNING
-        state.
+        Get the number of common prefix blocks for all requests with allocated
+        KV cache.
 
         Args:
-            request_id: The request ID.
-            num_running_requests: The total number of requests in the RUNNING
-                state.
+            running_request_id: The request ID.
 
         Returns:
-            The number of common prefix blocks for all requests in the RUNNING
-            state.
+            The number of common prefix blocks for all requests with allocated
+            KV cache.
         """
 
         raise NotImplementedError
@@ -302,13 +298,11 @@ class FullAttentionManager(SingleTypeKVCacheManager):
         # No need to remove blocks for full attention.
         pass
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
-        blocks = self.req_to_blocks[request_id]
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
+        blocks = self.req_to_blocks[running_request_id]
         num_common_blocks = 0
         for block in blocks:
-            if block.ref_cnt == num_running_requests:
+            if block.ref_cnt == len(self.req_to_blocks):
                 num_common_blocks += 1
             else:
                 break
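FullAttentionManager now compares each block's ref_cnt against len(self.req_to_blocks), i.e. the number of requests that currently hold allocated blocks, and stops at the first block that is not shared by all of them. Below is a minimal sketch of that contiguous-prefix rule with assumed simplified types (not vLLM internals):

# Minimal sketch: walk the chosen request's blocks from the front and stop at
# the first block that is not held by every request with allocated KV cache.
from dataclasses import dataclass


@dataclass
class Block:
    ref_cnt: int


def count_common_prefix_blocks(blocks: list[Block], num_requests_with_kv: int) -> int:
    num_common = 0
    for block in blocks:
        if block.ref_cnt == num_requests_with_kv:
            num_common += 1
        else:
            break  # only a contiguous prefix counts; later shared blocks do not
    return num_common


# With 3 requests holding KV cache, the third block breaks the prefix, so the
# shared fourth block is not counted.
blocks = [Block(3), Block(3), Block(1), Block(3)]
assert count_common_prefix_blocks(blocks, num_requests_with_kv=3) == 2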
@@ -408,9 +402,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
                 blocks[i] = self._null_block
         self.block_pool.free_blocks(removed_blocks)
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
         So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -544,9 +536,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
                 blocks[i] = self._null_block
         self.block_pool.free_blocks(removed_blocks)
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         cascade attention is not supported by chunked local attention.
         """
@@ -596,9 +586,7 @@ class MambaManager(SingleTypeKVCacheManager):
         # (for which find_longest_cache_hit returns block_pool.null_block)
         pass
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         cascade attention is not supported by mamba
         """
@@ -648,9 +636,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
         # requests, so this method is not relevant.
         raise ValueError("Should not be called as prefix caching is disabled.")
 
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         # Cross-attention blocks contain request-specific encoder states
         # and are not shared between different requests
         return 0