[Core] Support reset_prefix_cache (#12284)

Cody Yu 2025-01-22 10:52:27 -08:00 committed by GitHub
parent 96f6a7596f
commit 7206ce4ce1
27 changed files with 300 additions and 21 deletions

View File

@@ -796,6 +796,44 @@ class TestPrefixCachingBlockAllocator:
            block_hashes=block_hashes_seq1)
        assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks

    # Test reset prefix cache
    @staticmethod
    @pytest.mark.parametrize("num_blocks", [10])
    @pytest.mark.parametrize("block_size", [16])
    def test_reset_prefix_cache(num_blocks: int, block_size: int):
        """This test case simulates the case of resetting the prefix cache."""
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        token_ids = list(range(3 * block_size))

        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )
        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        # Free each block in the first chain.
        for block in first_chain:
            allocator.free(block)

        # Failed to reset prefix cache because some blocks are not freed yet.
        assert not allocator.reset_prefix_cache()
        assert allocator.get_prefix_cache_hit_rate() > 0.0

        # Free each block in the second chain.
        for block in second_chain:
            allocator.free(block)

        # Reset prefix cache.
        assert allocator.reset_prefix_cache()
        assert allocator.get_prefix_cache_hit_rate() == 0.0

    @staticmethod
    def create_immutable_chain(
        block_size: int,
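The test above pins down the caller-facing contract of PrefixCachingBlockAllocator.reset_prefix_cache(): it refuses to reset while any block is still allocated and only succeeds once every block has been freed. A minimal usage sketch of that same contract, assuming the allocator's existing allocate_immutable_block API:

# Sketch (not part of this commit): the reset contract at the allocator level.
allocator = PrefixCachingBlockAllocator(num_blocks=10, block_size=16)
block = allocator.allocate_immutable_block(prev_block=None,
                                           token_ids=list(range(16)))
assert not allocator.reset_prefix_cache()  # a block is still in use, so reset is refused
allocator.free(block)
assert allocator.reset_prefix_cache()      # everything freed, reset succeeds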

View File

@@ -587,3 +587,42 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
    assert {block.ref_cnt for block in block_part1[:3]} == {1}
    # Block 3-5 are free.
    assert {block.ref_cnt for block in block_part1[3:]} == {0}


def test_reset_prefix_cache():
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    full_block_token_ids = [i for i in range(3) for _ in range(16)]
    unique_token_ids = [3] * 7
    all_token_ids = full_block_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids)
    blocks = manager.allocate_slots(req0, 55, [])
    assert [b.block_id for b in blocks] == [0, 1, 2, 3]

    unique_token_ids = [4] * 7
    all_token_ids = full_block_token_ids + unique_token_ids
    req1 = make_request("1", all_token_ids)
    computed_blocks, _ = manager.get_computed_blocks(req1)
    assert len(req1.kv_block_hashes) == 3
    assert len(computed_blocks) == 3
    blocks = manager.allocate_slots(req1, 7, computed_blocks)
    assert [b.block_id for b in blocks] == [4]

    # Failed to reset prefix cache because some blocks are not freed yet.
    assert not manager.reset_prefix_cache()
    assert manager.cached_block_hash_to_block

    # Free the blocks.
    manager.free(req0)
    manager.free(req1)

    assert manager.reset_prefix_cache()
    assert not manager.cached_block_hash_to_block
    assert all([blk.block_hash is None for blk in manager.block_pool])
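The expected block IDs in this test follow from simple arithmetic, sketched below for reference (all figures come straight from the test setup above):

# req0: 3 * 16 = 48 prefix tokens + 7 unique tokens = 55 tokens
#   -> ceil(55 / 16) = 4 blocks, allocated as block_ids [0, 1, 2, 3];
#      only the 3 full blocks are hashed and cached.
# req1: shares the 48-token prefix, so get_computed_blocks() returns those
#   3 cached blocks, and only 1 new block (block_id 4) is needed for its
#   7 unique tokens.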

View File

@@ -339,6 +339,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
        assert device in self._allocators
        return self._allocators[device].get_prefix_cache_hit_rate()

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache for all devices."""
        success = True
        for allocator in self._allocators.values():
            success = success and allocator.reset_prefix_cache()
        return success

    def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
        """Returns and clears the mapping of source to destination block IDs.
        Will be called after every swapping operations for now, and after every
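One subtlety in the loop above: `and` short-circuits, so after the first allocator reports failure the remaining allocators are never asked to reset. A sketch of a variant that always attempts every device (not part of this commit):

    def reset_prefix_cache(self) -> bool:
        """Attempt to reset every device's cache, even if an earlier one fails."""
        results = [allocator.reset_prefix_cache()
                   for allocator in self._allocators.values()]
        return all(results)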

View File

@@ -192,6 +192,11 @@ class BlockAllocator(ABC):
        """Prefix cache hit rate. -1 means not supported or disabled."""
        pass

    @abstractmethod
    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache."""
        pass

    class NoFreeBlocksError(ValueError):
        pass

@@ -297,6 +302,11 @@ class DeviceAwareBlockAllocator(ABC):
        """Prefix cache hit rate. -1 means not supported or disabled."""
        pass

    @abstractmethod
    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache."""
        pass

    @abstractmethod
    def find_cached_blocks_prefix(
        self,

View File

@ -1,5 +1,5 @@
from collections import deque from collections import deque
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively) get_all_blocks_recursively)
@ -136,16 +136,18 @@ class NaiveBlockAllocator(BlockAllocator):
self._refcounter.incr(block_id) self._refcounter.incr(block_id)
return block_id return block_id
def _free_block_id(self, block: Block) -> None: def _free_block_id(self, block: Union[Block, BlockId]) -> None:
if isinstance(block, Block):
block_id = block.block_id block_id = block.block_id
block.block_id = None
else:
block_id = block
assert block_id is not None assert block_id is not None
refcount = self._refcounter.decr(block_id) refcount = self._refcounter.decr(block_id)
if refcount == 0: if refcount == 0:
self._free_block_indices.appendleft(block_id) self._free_block_indices.appendleft(block_id)
block.block_id = None
def free(self, block: Block, keep_block_object: bool = False) -> None: def free(self, block: Block, keep_block_object: bool = False) -> None:
# Release the physical block id # Release the physical block id
self._free_block_id(block) self._free_block_id(block)
@ -154,6 +156,9 @@ class NaiveBlockAllocator(BlockAllocator):
if not keep_block_object: if not keep_block_object:
self._block_pool.free_block(block) self._block_pool.free_block(block)
def free_block_id(self, block_id: BlockId) -> None:
self._free_block_id(block_id)
def fork(self, last_block: Block) -> List[Block]: def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying """Creates a new sequence of blocks that shares the same underlying
memory as the original sequence. memory as the original sequence.
@ -325,6 +330,10 @@ class NaiveBlockAllocator(BlockAllocator):
def get_prefix_cache_hit_rate(self) -> float: def get_prefix_cache_hit_rate(self) -> float:
return -1 return -1
def reset_prefix_cache(self) -> bool:
"""No prefix cache for naive block allocator."""
return True
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
# Not applicable for naive block allocator. # Not applicable for naive block allocator.
return [] return []
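The Union[Block, BlockId] widening exists because the prefix-caching allocator needs to hand back raw block IDs drained from its evictor during reset_prefix_cache(), where no Block object is alive anymore. A small sketch of the two call paths (illustrative only):

# Freeing via a Block object: the id is detached from the object, then released.
allocator.free(block)              # internally: _free_block_id(block)

# Freeing via a raw id: used by PrefixCachingBlockAllocator.reset_prefix_cache()
# for ids pulled out of the evictor.
allocator.free_block_id(block_id)  # internally: _free_block_id(block_id)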

View File

@@ -12,6 +12,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
                                         NaiveBlockAllocator)
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.logger import init_logger
from vllm.sequence import Sequence

PrefixHash = int

@@ -21,6 +22,8 @@ PrefixHash = int
# then we know this block hasn't been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1

logger = init_logger(__name__)


class BlockTracker:
    """Used to track the status of a block inside the prefix caching allocator

@@ -105,7 +108,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        # Evictor used to maintain how we want to handle those computed blocks
        # if we find memory pressure is high.
-        self.evictor: Evictor = make_evictor(eviction_policy)
        self.eviction_policy = eviction_policy
        self.evictor: Evictor = make_evictor(self.eviction_policy)

        # We share the refcounter between allocators. This allows us to promote
        # blocks originally allocated in the hashless allocator to immutable

@@ -428,6 +432,44 @@ class PrefixCachingBlockAllocator(BlockAllocator):
    def get_prefix_cache_hit_rate(self) -> float:
        return self.metric_data.get_hit_rate()

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        num_used_blocks = (self.get_num_total_blocks() -
                           self.get_num_free_blocks())
        if num_used_blocks > 0:
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet", num_used_blocks)
            return False

        # Free all blocks in the evictor.
        while (block_id :=
               self._maybe_allocate_evicted_block_id()) is not None:
            self._hashless_allocator.free_block_id(block_id)

        # Should not have any cached blocks because all blocks are evicted.
        assert not self._cached_blocks

        # Reset the evictor.
        self.evictor = make_evictor(self.eviction_policy)

        # Reset the block tracker.
        for block_id in self._block_tracker:
            self._block_tracker[block_id] = BlockTracker()

        # Reset the metrics.
        self.metric_data = CacheMetricData()

        logger.info("Successfully reset prefix cache")
        return True

    def is_block_cached(self, block: Block) -> bool:
        assert block.content_hash is not None
        return block.content_hash in self._cached_blocks
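The walrus-operator drain loop above is the core of the reset: every block currently parked in the evictor is re-acquired and its ID handed back to the hashless (naive) allocator's free list. An equivalent, spelled-out sketch of the same loop:

        while True:
            block_id = self._maybe_allocate_evicted_block_id()
            if block_id is None:  # evictor is empty, nothing left to drain
                break
            self._hashless_allocator.free_block_id(block_id)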

View File

@@ -455,6 +455,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return self.block_allocator.get_prefix_cache_hit_rate(device)

    def reset_prefix_cache(self) -> bool:
        return self.block_allocator.reset_prefix_cache()

    def _can_swap(self,
                  seq_group: SequenceGroup,
                  device: Device,

View File

@@ -122,6 +122,11 @@ class BlockSpaceManager(ABC):
        """Prefix cache hit rate. -1 means not supported or disabled."""
        pass

    @abstractmethod
    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache for all devices."""
        pass

    @abstractmethod
    def get_num_cached_tokens(self, seq: Sequence) -> int:
        pass

View File

@@ -90,5 +90,8 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return -1

    def reset_prefix_cache(self) -> bool:
        return True

    def get_num_cached_tokens(self, seq: Sequence) -> int:
        return 0

View File

@@ -504,6 +504,9 @@ class Scheduler:
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return self.block_manager.get_prefix_cache_hit_rate(device)

    def reset_prefix_cache(self) -> bool:
        return self.block_manager.reset_prefix_cache()

    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

View File

@@ -1182,6 +1182,9 @@ class AsyncLLMEngine(EngineClient):
    async def stop_profile(self) -> None:
        self.engine.stop_profile()

    async def reset_prefix_cache(self) -> None:
        self.engine.reset_prefix_cache()

    async def add_lora(self, lora_request: LoRARequest) -> None:
        self.engine.add_lora(lora_request)

View File

@@ -914,6 +914,14 @@ class LLMEngine:
        """
        return self.scheduler[virtual_engine].has_unfinished_seqs()

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache for all devices."""

        success = True
        for scheduler in self.scheduler:
            success = success and scheduler.reset_prefix_cache()
        return success

    @staticmethod
    def _process_sequence_group_outputs(
        seq_group: SequenceGroup,
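As the allocator docstring notes, the main motivation is RLHF-style weight updates and benchmarking: once the weights change, every cached KV block is stale. A hedged usage sketch at the engine level (update_model_weights is an illustrative placeholder, not part of vLLM):

# All requests must have finished (and their blocks been freed) first; otherwise
# reset_prefix_cache() returns False and the cache is left untouched.
assert not engine.has_unfinished_requests()
update_model_weights(engine)        # hypothetical in-place weight-sync step
assert engine.reset_prefix_cache()  # drop stale prefixes before new requests arrive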

View File

@@ -121,6 +121,10 @@ class RPCUProfileRequest(Enum):
    STOP_PROFILE = 2


class RPCResetPrefixCacheRequest(Enum):
    RESET_PREFIX_CACHE = 1


@dataclass
class RPCLoadAdapterRequest:
    lora_request: LoRARequest

@@ -134,7 +138,8 @@ class RPCAdapterLoadedResponse:

RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
-                      RPCUProfileRequest, RPCLoadAdapterRequest]
                      RPCUProfileRequest, RPCLoadAdapterRequest,
                      RPCResetPrefixCacheRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
                          RPCError]

View File

@@ -27,8 +27,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCAdapterLoadedResponse, RPCError,
                                         RPCLoadAdapterRequest,
-                                         RPCProcessRequest, RPCStartupRequest,
-                                         RPCStartupResponse,
                                         RPCProcessRequest,
                                         RPCResetPrefixCacheRequest,
                                         RPCStartupRequest, RPCStartupResponse,
                                         RPCUProfileRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable

@@ -675,6 +676,13 @@ class MQLLMEngineClient(EngineClient):
        await self._send_one_way_rpc_request(
            request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)

    async def reset_prefix_cache(self) -> None:
        """Reset the prefix cache"""

        await self._send_one_way_rpc_request(
            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
            socket=self.input_socket)

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""
        # Uses the same I/O as generate requests

View File

@@ -16,8 +16,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCAdapterLoadedResponse, RPCError,
                                         RPCLoadAdapterRequest,
-                                         RPCProcessRequest, RPCStartupRequest,
-                                         RPCStartupResponse,
                                         RPCProcessRequest,
                                         RPCResetPrefixCacheRequest,
                                         RPCStartupRequest, RPCStartupResponse,
                                         RPCUProfileRequest)
# yapf: enable
from vllm.logger import init_logger

@@ -237,6 +238,8 @@ class MQLLMEngine:
                self.stop_profile()
            elif isinstance(request, RPCLoadAdapterRequest):
                self._handle_load_adapter_request(request)
            elif isinstance(request, RPCResetPrefixCacheRequest):
                self.reset_prefix_cache()
            else:
                raise ValueError("Unknown RPCRequest Type: "
                                 f"{type(request)}")

@@ -361,6 +364,9 @@ class MQLLMEngine:
    def stop_profile(self) -> None:
        self.engine.stop_profile()

    def reset_prefix_cache(self) -> bool:
        return self.engine.reset_prefix_cache()


def signal_handler(*_) -> None:
    raise KeyboardInterrupt("MQLLMEngine terminated")

View File

@@ -271,6 +271,11 @@ class EngineClient(ABC):
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def reset_prefix_cache(self) -> None:
        """Reset the prefix cache"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""

View File

@@ -1132,6 +1132,9 @@ class LLM:
    def stop_profile(self) -> None:
        self.llm_engine.stop_profile()

    def reset_prefix_cache(self) -> bool:
        return self.llm_engine.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.

@@ -1150,6 +1153,7 @@ class LLM:
            where previous model weights are not needed. It reduces CPU memory
            pressure.
        """
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self):
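With LLM.sleep() now clearing the prefix cache itself, the guard in ExecutorBase.sleep() that rejected sleeping while prefix caching is enabled (removed further down in this commit) is no longer needed. A usage sketch of the new public method, e.g. to get clean numbers between benchmark runs (model name is just an example):

from vllm import LLM

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
llm.generate(["The quick brown fox jumps over the lazy dog"])
llm.reset_prefix_cache()  # returns True once all blocks have been freed
llm.generate(["The quick brown fox jumps over the lazy dog"])  # no stale cache hits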

View File

@@ -518,6 +518,18 @@ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
    },
}

if envs.VLLM_SERVER_DEV_MODE:

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(raw_request: Request):
        """
        Reset the prefix cache. Note that we currently do not check if the
        prefix cache is successfully reset in the API server.
        """
        logger.info("Resetting prefix cache...")
        await engine_client(raw_request).reset_prefix_cache()
        return Response(status_code=200)


@router.post("/invocations")
async def invocations(raw_request: Request):
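A quick way to exercise the new endpoint, assuming a server launched with the dev flag (for example `VLLM_SERVER_DEV_MODE=1 vllm serve <model> --enable-prefix-caching`) and the third-party requests package:

import requests

# The endpoint takes no body; success is signalled only by the 200 status,
# since the engine-side reset result is not propagated back to the client.
resp = requests.post("http://localhost:8000/reset_prefix_cache")
assert resp.status_code == 200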

View File

@@ -72,6 +72,7 @@ if TYPE_CHECKING:
    VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
    VLLM_DISABLE_COMPILE_CACHE: bool = False
    VLLM_SERVER_DEV_MODE: bool = False


def get_default_cache_root():

@@ -467,6 +468,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
    "VLLM_DISABLE_COMPILE_CACHE":
    lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),

    # If set, vllm will run in development mode, which will enable
    # some additional endpoints for developing and debugging,
    # e.g. `/reset_prefix_cache`
    "VLLM_SERVER_DEV_MODE":
    lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))),
}

# end-env-vars-definition
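The flag is read lazily through vllm.envs, so what matters is the value of the variable when the server process starts and evaluates its routes. A minimal check, purely illustrative:

import os
os.environ["VLLM_SERVER_DEV_MODE"] = "1"  # set before the API server module is imported

import vllm.envs as envs
assert envs.VLLM_SERVER_DEV_MODE is True  # bool(int("1"))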

View File

@@ -194,11 +194,6 @@ class ExecutorBase(ABC):
        self.collective_rpc("stop_profile")

    def sleep(self, level: int = 1):
-        if self.cache_config.enable_prefix_caching:
-            # TODO: support sleep with prefix caching
-            # by resetting the prefix cache state,
-            # after https://github.com/vllm-project/vllm/pull/12284
-            raise ValueError("Cannot sleep when prefix caching is enabled.")
        self.collective_rpc("sleep", kwargs=dict(level=level))

    def wake_up(self):

View File

@@ -285,6 +285,33 @@ class KVCacheManager:
            if block.ref_cnt == 0:
                self.free_block_queue.append(block)

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        num_used_blocks = (self.num_gpu_blocks -
                           self.free_block_queue.num_free_blocks)
        if num_used_blocks > 0:
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet", num_used_blocks)
            return False

        # Remove all hashes so that no new blocks will hit.
        self.cached_block_hash_to_block = defaultdict(dict)

        # Remove all hashes from all blocks.
        for block in self.block_pool:
            block.reset_hash()

        logger.info("Successfully reset prefix cache")
        return True

    def get_num_common_prefix_blocks(
        self,
        request: Request,
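Note that the v1 reset never touches the free-block queue: blocks stay exactly where they are, and only the hash bookkeeping is wiped, so previously cached blocks simply stop being reusable as cache hits. The post-conditions, written out as a sketch of the invariants the test above checks:

# After a successful reset:
assert len(manager.cached_block_hash_to_block) == 0               # no hash -> block lookups left
assert all(blk.block_hash is None for blk in manager.block_pool)  # every block un-hashed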

View File

@@ -529,6 +529,9 @@ class Scheduler:
    def has_unfinished_requests(self) -> bool:
        return self.get_num_unfinished_requests() > 0

    def reset_prefix_cache(self) -> bool:
        return self.kv_cache_manager.reset_prefix_cache()

    def make_stats(self) -> SchedulerStats:
        return SchedulerStats(
            num_running_reqs=len(self.running),

View File

@@ -66,6 +66,11 @@ class EngineCoreProfile:
    is_start: bool


@dataclass
class EngineCoreResetPrefixCache:
    pass


class EngineCoreRequestType(enum.Enum):
    """
    Request types defined as hex byte strings, so it can be sent over sockets

@@ -74,6 +79,8 @@ class EngineCoreRequestType(enum.Enum):
    ADD = b'\x00'
    ABORT = b'\x01'
    PROFILE = b'\x02'
    RESET_PREFIX_CACHE = b'\x03'

-EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
                               EngineCoreResetPrefixCache, List[str]]

View File

@@ -321,6 +321,9 @@ class AsyncLLM(EngineClient):
    async def stop_profile(self) -> None:
        await self.engine_core.profile_async(False)

    async def reset_prefix_cache(self) -> None:
        await self.engine_core.reset_prefix_cache_async()

    @property
    def is_running(self) -> bool:
        return True

View File

@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                            EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion)
                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus

@@ -135,6 +135,9 @@ class EngineCore:
    def profile(self, is_start: bool = True):
        self.model_executor.profile(is_start)

    def reset_prefix_cache(self):
        self.scheduler.reset_prefix_cache()


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

@@ -247,6 +250,8 @@ class EngineCoreProc(EngineCore):
            self.add_request(request)
        elif isinstance(request, EngineCoreProfile):
            self.model_executor.profile(request.is_start)
        elif isinstance(request, EngineCoreResetPrefixCache):
            self.reset_prefix_cache()
        else:
            # TODO: make an EngineCoreAbort wrapper
            assert isinstance(request, list)

@@ -271,7 +276,9 @@ class EngineCoreProc(EngineCore):
            request = decoder_add_req.decode(request_data)
        elif request_type == EngineCoreRequestType.ABORT.value:
            request = decoder_abort_req.decode(request_data)
-        elif request_type == EngineCoreRequestType.PROFILE.value:
        elif request_type in (
                EngineCoreRequestType.PROFILE.value,
                EngineCoreRequestType.RESET_PREFIX_CACHE.value):
            request = pickle.loads(request_data)
        else:
            raise ValueError(f"Unknown RequestType: {request_type}")

View File

@@ -14,7 +14,7 @@ from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                        make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                            EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion)
                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import PickleEncoder

@@ -69,6 +69,9 @@ class EngineCoreClient(ABC):
    def profile(self, is_start: bool = True) -> None:
        raise NotImplementedError

    def reset_prefix_cache(self) -> None:
        raise NotImplementedError

    def abort_requests(self, request_ids: List[str]) -> None:
        raise NotImplementedError

@@ -81,6 +84,9 @@ class EngineCoreClient(ABC):
    async def profile_async(self, is_start: bool = True) -> None:
        raise NotImplementedError

    async def reset_prefix_cache_async(self) -> None:
        raise NotImplementedError

    async def abort_requests_async(self, request_ids: List[str]) -> None:
        raise NotImplementedError

@@ -108,12 +114,15 @@ class InprocClient(EngineCoreClient):
        if len(request_ids) > 0:
            self.engine_core.abort_requests(request_ids)

-    def shutdown(self):
    def shutdown(self) -> None:
        self.engine_core.shutdown()

    def profile(self, is_start: bool = True) -> None:
        self.engine_core.profile(is_start)

    def reset_prefix_cache(self) -> None:
        self.engine_core.reset_prefix_cache()


class MPClient(EngineCoreClient):
    """

@@ -229,6 +238,10 @@ class SyncMPClient(MPClient):
        self._send_input(EngineCoreRequestType.PROFILE,
                         EngineCoreProfile(is_start))

    def reset_prefix_cache(self) -> None:
        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
                         EngineCoreResetPrefixCache())


class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""

@@ -266,3 +279,7 @@ class AsyncMPClient(MPClient):
    async def profile_async(self, is_start: bool = True) -> None:
        await self._send_input(EngineCoreRequestType.PROFILE,
                               EngineCoreProfile(is_start))

    async def reset_prefix_cache_async(self) -> None:
        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
                               EngineCoreResetPrefixCache())
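All three v1 client flavors end up at the same EngineCore.reset_prefix_cache(); only the transport differs. A compressed sketch of the three paths, mirroring the methods added above (illustrative only):

# InprocClient: a direct call into the in-process EngineCore.
client.reset_prefix_cache()

# SyncMPClient: the RESET_PREFIX_CACHE type byte plus a pickled
# EngineCoreResetPrefixCache payload, sent over the ZMQ input socket.
client.reset_prefix_cache()

# AsyncMPClient: the same message, but awaitable.
await client.reset_prefix_cache_async()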

View File

@@ -162,6 +162,9 @@ class LLMEngine:
    def stop_profile(self):
        self.engine_core.profile(False)

    def reset_prefix_cache(self):
        self.engine_core.reset_prefix_cache()

    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,