[Core] Support reset_prefix_cache (#12284)

Author: Cody Yu, 2025-01-22 10:52:27 -08:00 (committed by GitHub)
Parent: 96f6a7596f
Commit: 7206ce4ce1
27 changed files with 300 additions and 21 deletions

View File

@@ -796,6 +796,44 @@ class TestPrefixCachingBlockAllocator:
            block_hashes=block_hashes_seq1)
        assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks

    # Test reset prefix cache
    @staticmethod
    @pytest.mark.parametrize("num_blocks", [10])
    @pytest.mark.parametrize("block_size", [16])
    def test_reset_prefix_cache(num_blocks: int, block_size: int):
        """This test case simulates the case of resetting the prefix cache."""
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        token_ids = list(range(3 * block_size))

        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )
        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        # Free each block in the first chain.
        for block in first_chain:
            allocator.free(block)

        # Failed to reset prefix cache because some blocks are not freed yet.
        assert not allocator.reset_prefix_cache()
        assert allocator.get_prefix_cache_hit_rate() > 0.0

        # Free each block in the second chain.
        for block in second_chain:
            allocator.free(block)

        # Reset prefix cache.
        assert allocator.reset_prefix_cache()
        assert allocator.get_prefix_cache_hit_rate() == 0.0

    @staticmethod
    def create_immutable_chain(
        block_size: int,

View File

@@ -587,3 +587,42 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
    assert {block.ref_cnt for block in block_part1[:3]} == {1}
    # Block 3-5 are free.
    assert {block.ref_cnt for block in block_part1[3:]} == {0}


def test_reset_prefix_cache():
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    full_block_token_ids = [i for i in range(3) for _ in range(16)]
    unique_token_ids = [3] * 7
    all_token_ids = full_block_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids)
    blocks = manager.allocate_slots(req0, 55, [])
    assert [b.block_id for b in blocks] == [0, 1, 2, 3]

    unique_token_ids = [4] * 7
    all_token_ids = full_block_token_ids + unique_token_ids
    req1 = make_request("1", all_token_ids)
    computed_blocks, _ = manager.get_computed_blocks(req1)
    assert len(req1.kv_block_hashes) == 3
    assert len(computed_blocks) == 3
    blocks = manager.allocate_slots(req1, 7, computed_blocks)
    assert [b.block_id for b in blocks] == [4]

    # Failed to reset prefix cache because some blocks are not freed yet.
    assert not manager.reset_prefix_cache()
    assert manager.cached_block_hash_to_block

    # Free the blocks.
    manager.free(req0)
    manager.free(req1)

    assert manager.reset_prefix_cache()
    assert not manager.cached_block_hash_to_block
    assert all([blk.block_hash is None for blk in manager.block_pool])

View File

@@ -339,6 +339,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
        assert device in self._allocators
        return self._allocators[device].get_prefix_cache_hit_rate()

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache for all devices."""
        success = True
        for allocator in self._allocators.values():
            success = success and allocator.reset_prefix_cache()
        return success

    def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
        """Returns and clears the mapping of source to destination block IDs.
        Will be called after every swapping operations for now, and after every

View File

@@ -192,6 +192,11 @@ class BlockAllocator(ABC):
        """Prefix cache hit rate. -1 means not supported or disabled."""
        pass

    @abstractmethod
    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache."""
        pass

    class NoFreeBlocksError(ValueError):
        pass

@@ -297,6 +302,11 @@ class DeviceAwareBlockAllocator(ABC):
        """Prefix cache hit rate. -1 means not supported or disabled."""
        pass

    @abstractmethod
    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache."""
        pass

    @abstractmethod
    def find_cached_blocks_prefix(
        self,

View File

@@ -1,5 +1,5 @@
from collections import deque
-from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
+from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union

from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
                                    get_all_blocks_recursively)

@@ -136,16 +136,18 @@ class NaiveBlockAllocator(BlockAllocator):
            self._refcounter.incr(block_id)
        return block_id

-    def _free_block_id(self, block: Block) -> None:
-        block_id = block.block_id
+    def _free_block_id(self, block: Union[Block, BlockId]) -> None:
+        if isinstance(block, Block):
+            block_id = block.block_id
+            block.block_id = None
+        else:
+            block_id = block
        assert block_id is not None

        refcount = self._refcounter.decr(block_id)
        if refcount == 0:
            self._free_block_indices.appendleft(block_id)

-        block.block_id = None

    def free(self, block: Block, keep_block_object: bool = False) -> None:
        # Release the physical block id
        self._free_block_id(block)

@@ -154,6 +156,9 @@ class NaiveBlockAllocator(BlockAllocator):
        if not keep_block_object:
            self._block_pool.free_block(block)

    def free_block_id(self, block_id: BlockId) -> None:
        self._free_block_id(block_id)

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

@@ -325,6 +330,10 @@ class NaiveBlockAllocator(BlockAllocator):
    def get_prefix_cache_hit_rate(self) -> float:
        return -1

    def reset_prefix_cache(self) -> bool:
        """No prefix cache for naive block allocator."""
        return True

    def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
        # Not applicable for naive block allocator.
        return []

View File

@@ -12,6 +12,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
                                         NaiveBlockAllocator)
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.logger import init_logger
from vllm.sequence import Sequence

PrefixHash = int

@@ -21,6 +22,8 @@ PrefixHash = int
# then we know this block hasn't been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1

logger = init_logger(__name__)


class BlockTracker:
    """Used to track the status of a block inside the prefix caching allocator

@@ -105,7 +108,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        # Evictor used to maintain how we want to handle those computed blocks
        # if we find memory pressure is high.
-        self.evictor: Evictor = make_evictor(eviction_policy)
+        self.eviction_policy = eviction_policy
+        self.evictor: Evictor = make_evictor(self.eviction_policy)

        # We share the refcounter between allocators. This allows us to promote
        # blocks originally allocated in the hashless allocator to immutable

@@ -428,6 +432,44 @@ class PrefixCachingBlockAllocator(BlockAllocator):
    def get_prefix_cache_hit_rate(self) -> float:
        return self.metric_data.get_hit_rate()

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        num_used_blocks = (self.get_num_total_blocks() -
                           self.get_num_free_blocks())
        if num_used_blocks > 0:
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet", num_used_blocks)
            return False

        # Free all blocks in the evictor.
        while (block_id :=
               self._maybe_allocate_evicted_block_id()) is not None:
            self._hashless_allocator.free_block_id(block_id)

        # Should not have any cached blocks because all blocks are evicted.
        assert not self._cached_blocks

        # Reset the evictor.
        self.evictor = make_evictor(self.eviction_policy)

        # Reset the block tracker.
        for block_id in self._block_tracker:
            self._block_tracker[block_id] = BlockTracker()

        # Reset the metrics.
        self.metric_data = CacheMetricData()

        logger.info("Successfully reset prefix cache")
        return True

    def is_block_cached(self, block: Block) -> bool:
        assert block.content_hash is not None
        return block.content_hash in self._cached_blocks

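The allocator-level contract, condensed: the reset is refused while any block is still held, and a successful reset zeroes the hit-rate metrics along with the cache, which is what makes back-to-back benchmark runs comparable. A minimal sketch mirroring the unit test earlier in this commit; allocate_immutable_block is the pre-existing vLLM allocator API (not part of this diff), and the block counts are arbitrary:

from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator

allocator = PrefixCachingBlockAllocator(num_blocks=10, block_size=16)

block = allocator.allocate_immutable_block(prev_block=None,
                                           token_ids=list(range(16)))
assert not allocator.reset_prefix_cache()  # refused: a block is still held

allocator.free(block)
assert allocator.reset_prefix_cache()      # all blocks free -> cache cleared
assert allocator.get_prefix_cache_hit_rate() == 0.0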
View File

@@ -455,6 +455,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return self.block_allocator.get_prefix_cache_hit_rate(device)

    def reset_prefix_cache(self) -> bool:
        return self.block_allocator.reset_prefix_cache()

    def _can_swap(self,
                  seq_group: SequenceGroup,
                  device: Device,

View File

@@ -122,6 +122,11 @@ class BlockSpaceManager(ABC):
        """Prefix cache hit rate. -1 means not supported or disabled."""
        pass

    @abstractmethod
    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache for all devices."""
        pass

    @abstractmethod
    def get_num_cached_tokens(self, seq: Sequence) -> int:
        pass

View File

@@ -90,5 +90,8 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return -1

    def reset_prefix_cache(self) -> bool:
        return True

    def get_num_cached_tokens(self, seq: Sequence) -> int:
        return 0

View File

@@ -504,6 +504,9 @@ class Scheduler:
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return self.block_manager.get_prefix_cache_hit_rate(device)

    def reset_prefix_cache(self) -> bool:
        return self.block_manager.reset_prefix_cache()

    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

View File

@@ -1182,6 +1182,9 @@ class AsyncLLMEngine(EngineClient):
    async def stop_profile(self) -> None:
        self.engine.stop_profile()

    async def reset_prefix_cache(self) -> None:
        self.engine.reset_prefix_cache()

    async def add_lora(self, lora_request: LoRARequest) -> None:
        self.engine.add_lora(lora_request)

View File

@@ -914,6 +914,14 @@ class LLMEngine:
        """
        return self.scheduler[virtual_engine].has_unfinished_seqs()

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache for all devices."""
        success = True
        for scheduler in self.scheduler:
            success = success and scheduler.reset_prefix_cache()
        return success

    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,

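Because the engine-level reset fails fast rather than waiting, a caller that wants a guaranteed reset has to drain in-flight requests first. A sketch of that polling pattern; reset_when_idle is a hypothetical helper, not part of this commit:

import time

def reset_when_idle(engine, timeout_s: float = 30.0) -> bool:
    """Hypothetical helper: retry until every scheduler reports success.

    engine.reset_prefix_cache() returns False while any blocks are still
    held by unfinished requests, so poll until they drain or we time out.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if engine.reset_prefix_cache():
            return True
        time.sleep(0.1)
    return False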
View File

@@ -121,6 +121,11 @@ class RPCUProfileRequest(Enum):
    STOP_PROFILE = 2


class RPCResetPrefixCacheRequest(Enum):
    RESET_PREFIX_CACHE = 1


@dataclass
class RPCLoadAdapterRequest:
    lora_request: LoRARequest

@@ -134,7 +138,8 @@ class RPCAdapterLoadedResponse:
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
-                      RPCUProfileRequest, RPCLoadAdapterRequest]
+                      RPCUProfileRequest, RPCLoadAdapterRequest,
+                      RPCResetPrefixCacheRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
                          RPCError]

View File

@@ -27,8 +27,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCAdapterLoadedResponse, RPCError,
                                         RPCLoadAdapterRequest,
-                                        RPCProcessRequest, RPCStartupRequest,
-                                        RPCStartupResponse,
+                                        RPCProcessRequest,
+                                        RPCResetPrefixCacheRequest,
+                                        RPCStartupRequest, RPCStartupResponse,
                                         RPCUProfileRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable

@@ -675,6 +676,13 @@ class MQLLMEngineClient(EngineClient):
        await self._send_one_way_rpc_request(
            request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)

    async def reset_prefix_cache(self) -> None:
        """Reset the prefix cache."""
        await self._send_one_way_rpc_request(
            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
            socket=self.input_socket)

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""
        # Uses the same I/O as generate requests

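Note the client call is one-way: it enqueues the RPCResetPrefixCacheRequest on the input socket and returns immediately, so the boolean result of the engine-side reset is only visible in the engine's logs. A rough async usage sketch (client construction elided; any EngineClient exposing this method behaves the same way):

import asyncio

async def refresh_prefix_cache(client) -> None:
    # `client` is an already-connected MQLLMEngineClient (or any other
    # EngineClient). Fire-and-forget: no success flag comes back.
    await client.reset_prefix_cache()

# asyncio.run(refresh_prefix_cache(client))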
View File

@@ -16,8 +16,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCAdapterLoadedResponse, RPCError,
                                         RPCLoadAdapterRequest,
-                                        RPCProcessRequest, RPCStartupRequest,
-                                        RPCStartupResponse,
+                                        RPCProcessRequest,
+                                        RPCResetPrefixCacheRequest,
+                                        RPCStartupRequest, RPCStartupResponse,
                                         RPCUProfileRequest)
# yapf: enable
from vllm.logger import init_logger

@@ -237,6 +238,8 @@ class MQLLMEngine:
            self.stop_profile()
        elif isinstance(request, RPCLoadAdapterRequest):
            self._handle_load_adapter_request(request)
        elif isinstance(request, RPCResetPrefixCacheRequest):
            self.reset_prefix_cache()
        else:
            raise ValueError("Unknown RPCRequest Type: "
                             f"{type(request)}")

@@ -361,6 +364,9 @@ class MQLLMEngine:
    def stop_profile(self) -> None:
        self.engine.stop_profile()

    def reset_prefix_cache(self) -> bool:
        return self.engine.reset_prefix_cache()


def signal_handler(*_) -> None:
    raise KeyboardInterrupt("MQLLMEngine terminated")

View File

@@ -271,6 +271,11 @@ class EngineClient(ABC):
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def reset_prefix_cache(self) -> None:
        """Reset the prefix cache"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""

View File

@@ -1132,6 +1132,9 @@ class LLM:
    def stop_profile(self) -> None:
        self.llm_engine.stop_profile()

    def reset_prefix_cache(self) -> bool:
        return self.llm_engine.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.

@@ -1150,6 +1153,7 @@ class LLM:
            where previous model weights are not needed. It reduces CPU memory
            pressure.
        """
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self):

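Taken together, the offline API now supports the RLHF-style flow the docstrings describe: generate, swap weights, invalidate the cache, generate again. It is also why sleep() can clear the prefix cache itself, making the executor-side "cannot sleep with prefix caching" guard (removed later in this commit) unnecessary. A sketch under assumptions; the model name and the in-place weight update are placeholders:

from vllm import LLM

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
print(llm.generate(["The capital of France is"]))

# ... update model weights in place (RLHF step, details elided) ...

# Cached KV blocks were computed with the old weights; drop them.
# Returns False if blocks are still held by unfinished requests.
assert llm.reset_prefix_cache()

print(llm.generate(["The capital of France is"]))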
View File

@@ -518,6 +518,18 @@ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
    },
}

if envs.VLLM_SERVER_DEV_MODE:

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(raw_request: Request):
        """
        Reset the prefix cache. Note that we currently do not check if the
        prefix cache is successfully reset in the API server.
        """
        logger.info("Resetting prefix cache...")
        await engine_client(raw_request).reset_prefix_cache()
        return Response(status_code=200)


@router.post("/invocations")
async def invocations(raw_request: Request):

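The HTTP route is only registered when the server runs in dev mode, i.e. when it was launched with VLLM_SERVER_DEV_MODE=1 (the variable is defined in the next file). A minimal sketch of exercising it from Python; the host and port assume a default local deployment:

import requests

# Assumes the OpenAI-compatible server was started with
# VLLM_SERVER_DEV_MODE=1 and is listening on localhost:8000.
resp = requests.post("http://localhost:8000/reset_prefix_cache")
assert resp.status_code == 200  # the reset itself is not verified server-side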
View File

@@ -72,6 +72,7 @@ if TYPE_CHECKING:
    VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
    VLLM_DISABLE_COMPILE_CACHE: bool = False
    VLLM_SERVER_DEV_MODE: bool = False


def get_default_cache_root():

@@ -467,6 +468,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
    "VLLM_DISABLE_COMPILE_CACHE":
    lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),

    # If set, vllm will run in development mode, which will enable
    # some additional endpoints for developing and debugging,
    # e.g. `/reset_prefix_cache`
    "VLLM_SERVER_DEV_MODE":
    lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))),
}

# end-env-vars-definition

View File

@@ -194,11 +194,6 @@ class ExecutorBase(ABC):
        self.collective_rpc("stop_profile")

    def sleep(self, level: int = 1):
-        if self.cache_config.enable_prefix_caching:
-            # TODO: support sleep with prefix caching
-            # by resetting the prefix cache state,
-            # after https://github.com/vllm-project/vllm/pull/12284
-            raise ValueError("Cannot sleep when prefix caching is enabled.")
        self.collective_rpc("sleep", kwargs=dict(level=level))

    def wake_up(self):

View File

@@ -285,6 +285,33 @@ class KVCacheManager:
            if block.ref_cnt == 0:
                self.free_block_queue.append(block)

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        num_used_blocks = (self.num_gpu_blocks -
                           self.free_block_queue.num_free_blocks)
        if num_used_blocks > 0:
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet", num_used_blocks)
            return False

        # Remove all hashes so that no new blocks will hit.
        self.cached_block_hash_to_block = defaultdict(dict)

        # Remove all hashes from all blocks.
        for block in self.block_pool:
            block.reset_hash()

        logger.info("Successfully reset prefix cache")
        return True

    def get_num_common_prefix_blocks(
        self,
        request: Request,

View File

@@ -529,6 +529,9 @@ class Scheduler:
    def has_unfinished_requests(self) -> bool:
        return self.get_num_unfinished_requests() > 0

    def reset_prefix_cache(self) -> bool:
        return self.kv_cache_manager.reset_prefix_cache()

    def make_stats(self) -> SchedulerStats:
        return SchedulerStats(
            num_running_reqs=len(self.running),

View File

@@ -66,6 +66,11 @@ class EngineCoreProfile:
    is_start: bool


@dataclass
class EngineCoreResetPrefixCache:
    pass


class EngineCoreRequestType(enum.Enum):
    """
    Request types defined as hex byte strings, so it can be sent over sockets

@@ -74,6 +79,8 @@ class EngineCoreRequestType(enum.Enum):
    ADD = b'\x00'
    ABORT = b'\x01'
    PROFILE = b'\x02'
    RESET_PREFIX_CACHE = b'\x03'

-EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
+EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
+                               EngineCoreResetPrefixCache, List[str]]

View File

@@ -321,6 +321,9 @@ class AsyncLLM(EngineClient):
    async def stop_profile(self) -> None:
        await self.engine_core.profile_async(False)

    async def reset_prefix_cache(self) -> None:
        await self.engine_core.reset_prefix_cache_async()

    @property
    def is_running(self) -> bool:
        return True

View File

@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                            EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion)
+                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus

@@ -135,6 +135,9 @@ class EngineCore:
    def profile(self, is_start: bool = True):
        self.model_executor.profile(is_start)

    def reset_prefix_cache(self):
        self.scheduler.reset_prefix_cache()


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

@@ -247,6 +250,8 @@ class EngineCoreProc(EngineCore):
            self.add_request(request)
        elif isinstance(request, EngineCoreProfile):
            self.model_executor.profile(request.is_start)
        elif isinstance(request, EngineCoreResetPrefixCache):
            self.reset_prefix_cache()
        else:
            # TODO: make an EngineCoreAbort wrapper
            assert isinstance(request, list)

@@ -271,7 +276,9 @@ class EngineCoreProc(EngineCore):
            request = decoder_add_req.decode(request_data)
        elif request_type == EngineCoreRequestType.ABORT.value:
            request = decoder_abort_req.decode(request_data)
-        elif request_type == EngineCoreRequestType.PROFILE.value:
+        elif request_type in (
+                EngineCoreRequestType.PROFILE.value,
+                EngineCoreRequestType.RESET_PREFIX_CACHE.value):
            request = pickle.loads(request_data)
        else:
            raise ValueError(f"Unknown RequestType: {request_type}")

View File

@@ -14,7 +14,7 @@ from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                        make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                            EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion)
+                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import PickleEncoder

@@ -69,6 +69,9 @@ class EngineCoreClient(ABC):
    def profile(self, is_start: bool = True) -> None:
        raise NotImplementedError

    def reset_prefix_cache(self) -> None:
        raise NotImplementedError

    def abort_requests(self, request_ids: List[str]) -> None:
        raise NotImplementedError

@@ -81,6 +84,9 @@ class EngineCoreClient(ABC):
    async def profile_async(self, is_start: bool = True) -> None:
        raise NotImplementedError

    async def reset_prefix_cache_async(self) -> None:
        raise NotImplementedError

    async def abort_requests_async(self, request_ids: List[str]) -> None:
        raise NotImplementedError

@@ -108,12 +114,15 @@ class InprocClient(EngineCoreClient):
        if len(request_ids) > 0:
            self.engine_core.abort_requests(request_ids)

-    def shutdown(self):
+    def shutdown(self) -> None:
        self.engine_core.shutdown()

    def profile(self, is_start: bool = True) -> None:
        self.engine_core.profile(is_start)

    def reset_prefix_cache(self) -> None:
        self.engine_core.reset_prefix_cache()


class MPClient(EngineCoreClient):
    """

@@ -229,6 +238,10 @@ class SyncMPClient(MPClient):
        self._send_input(EngineCoreRequestType.PROFILE,
                         EngineCoreProfile(is_start))

    def reset_prefix_cache(self) -> None:
        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
                         EngineCoreResetPrefixCache())


class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""

@@ -266,3 +279,7 @@ class AsyncMPClient(MPClient):
    async def profile_async(self, is_start: bool = True) -> None:
        await self._send_input(EngineCoreRequestType.PROFILE,
                               EngineCoreProfile(is_start))

    async def reset_prefix_cache_async(self) -> None:
        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
                               EngineCoreResetPrefixCache())

View File

@@ -162,6 +162,9 @@ class LLMEngine:
    def stop_profile(self):
        self.engine_core.profile(False)

    def reset_prefix_cache(self):
        self.engine_core.reset_prefix_cache()

    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,