mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:55:51 +08:00
[Core] Support reset_prefix_cache (#12284)
This commit is contained in:
parent
96f6a7596f
commit
7206ce4ce1
@ -796,6 +796,44 @@ class TestPrefixCachingBlockAllocator:
|
||||
block_hashes=block_hashes_seq1)
|
||||
assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
|
||||
|
||||
# Test reset prefix cache
|
||||
@staticmethod
@pytest.mark.parametrize("num_blocks", [10])
@pytest.mark.parametrize("block_size", [16])
def test_reset_prefix_cache(num_blocks: int, block_size: int):
    """Exercise reset_prefix_cache() before and after all blocks are freed.

    Two immutable chains are built from the same token ids. With only the
    first chain freed the reset must be refused; once the second chain is
    freed as well the reset must succeed and the hit rate must read 0.0.
    """
    allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                            block_size=block_size)
    token_ids = list(range(3 * block_size))

    # Build both chains with identical arguments, first chain first.
    first_chain, second_chain = (
        TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        ) for _ in range(2))

    # Release only the first chain: the second one still holds blocks.
    for held_block in first_chain:
        allocator.free(held_block)

    # The reset is refused while blocks are outstanding, and the metrics
    # are left untouched.
    reset_ok = allocator.reset_prefix_cache()
    assert not reset_ok
    assert allocator.get_prefix_cache_hit_rate() > 0.0

    # Release the remaining chain.
    for held_block in second_chain:
        allocator.free(held_block)

    # Now nothing is in use, so the reset goes through and clears metrics.
    reset_ok = allocator.reset_prefix_cache()
    assert reset_ok
    assert allocator.get_prefix_cache_hit_rate() == 0.0
|
||||
|
||||
@staticmethod
|
||||
def create_immutable_chain(
|
||||
block_size: int,
|
||||
|
||||
@ -587,3 +587,42 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
assert {block.ref_cnt for block in block_part1[:3]} == {1}
|
||||
# Block 3-5 are free.
|
||||
assert {block.ref_cnt for block in block_part1[3:]} == {0}
|
||||
|
||||
|
||||
def test_reset_prefix_cache():
    """Check KVCacheManager.reset_prefix_cache() with and without live requests.

    Two requests sharing three full blocks of tokens are allocated. The
    reset must fail while they hold blocks, and must succeed (clearing the
    hash table and every block hash) after both requests are freed.
    """
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # Three full blocks (16 tokens each) shared by both requests.
    shared_prefix = [token for token in range(3) for _ in range(16)]

    # First request: shared prefix plus a 7-token tail of its own.
    req0 = make_request("0", shared_prefix + [3] * 7)
    blocks = manager.allocate_slots(req0, 55, [])
    assert [b.block_id for b in blocks] == [0, 1, 2, 3]

    # Second request: same prefix, different tail.
    req1 = make_request("1", shared_prefix + [4] * 7)
    computed_blocks, _ = manager.get_computed_blocks(req1)
    assert len(req1.kv_block_hashes) == 3
    assert len(computed_blocks) == 3
    blocks = manager.allocate_slots(req1, 7, computed_blocks)
    assert [b.block_id for b in blocks] == [4]

    # Failed to reset prefix cache because some blocks are not freed yet.
    assert not manager.reset_prefix_cache()
    assert manager.cached_block_hash_to_block

    # Free the blocks of both requests.
    for finished in (req0, req1):
        manager.free(finished)

    assert manager.reset_prefix_cache()
    assert not manager.cached_block_hash_to_block
    assert all(blk.block_hash is None for blk in manager.block_pool)
|
||||
|
||||
@ -339,6 +339,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
||||
assert device in self._allocators
|
||||
return self._allocators[device].get_prefix_cache_hit_rate()
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Reset prefix cache for all devices.

    Every per-device allocator is asked to reset, even if an earlier one
    reports failure, so that a failure on one device does not leave the
    caches of the remaining devices untouched.

    Returns:
        bool: True only if every allocator reported a successful reset
            (vacuously True when there are no allocators).
    """
    # Collect all results first: chaining with `and` (or feeding a
    # generator to `all`) would short-circuit and skip later allocators.
    results = [
        allocator.reset_prefix_cache()
        for allocator in self._allocators.values()
    ]
    return all(results)
|
||||
|
||||
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
|
||||
"""Returns and clears the mapping of source to destination block IDs.
|
||||
Will be called after every swapping operations for now, and after every
|
||||
|
||||
@ -192,6 +192,11 @@ class BlockAllocator(ABC):
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
pass
|
||||
|
||||
@abstractmethod
def reset_prefix_cache(self) -> bool:
    """Reset the prefix cache of this allocator.

    Returns:
        bool: Whether the reset succeeded, as reported by the
            concrete implementation.
    """
    ...
|
||||
|
||||
class NoFreeBlocksError(ValueError):
|
||||
pass
|
||||
|
||||
@ -297,6 +302,11 @@ class DeviceAwareBlockAllocator(ABC):
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
pass
|
||||
|
||||
@abstractmethod
def reset_prefix_cache(self) -> bool:
    """Reset the prefix cache held by this device-aware allocator.

    Returns:
        bool: Whether the reset succeeded, as reported by the
            concrete implementation.
    """
    ...
|
||||
|
||||
@abstractmethod
|
||||
def find_cached_blocks_prefix(
|
||||
self,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections import deque
|
||||
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
|
||||
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
|
||||
get_all_blocks_recursively)
|
||||
@ -136,16 +136,18 @@ class NaiveBlockAllocator(BlockAllocator):
|
||||
self._refcounter.incr(block_id)
|
||||
return block_id
|
||||
|
||||
def _free_block_id(self, block: Union[Block, BlockId]) -> None:
    """Drop one reference to a physical block and recycle it if unused.

    Args:
        block: Either a ``Block`` object or a bare block id. When a
            ``Block`` is given, its ``block_id`` is cleared as part of the
            release; a bare id is used as-is.
    """
    if isinstance(block, Block):
        block_id = block.block_id
        # Detach the physical id from the Block object being released.
        block.block_id = None
    else:
        block_id = block
    assert block_id is not None

    # Only return the id to the free list once nobody references it.
    refcount = self._refcounter.decr(block_id)
    if refcount == 0:
        self._free_block_indices.appendleft(block_id)
|
||||
|
||||
def free(self, block: Block, keep_block_object: bool = False) -> None:
|
||||
# Release the physical block id
|
||||
self._free_block_id(block)
|
||||
@ -154,6 +156,9 @@ class NaiveBlockAllocator(BlockAllocator):
|
||||
if not keep_block_object:
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
def free_block_id(self, block_id: BlockId) -> None:
    """Release a physical block given only its id.

    Thin public wrapper that forwards the id to ``_free_block_id``.
    """
    self._free_block_id(block_id)
|
||||
|
||||
def fork(self, last_block: Block) -> List[Block]:
|
||||
"""Creates a new sequence of blocks that shares the same underlying
|
||||
memory as the original sequence.
|
||||
@ -325,6 +330,10 @@ class NaiveBlockAllocator(BlockAllocator):
|
||||
def get_prefix_cache_hit_rate(self) -> float:
|
||||
return -1
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Report success unconditionally.

    The naive block allocator has no prefix cache, so there is nothing
    to reset.
    """
    return True
|
||||
|
||||
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
|
||||
# Not applicable for naive block allocator.
|
||||
return []
|
||||
|
||||
@ -12,6 +12,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
|
||||
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
|
||||
NaiveBlockAllocator)
|
||||
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import Sequence
|
||||
|
||||
PrefixHash = int
|
||||
@ -21,6 +22,8 @@ PrefixHash = int
|
||||
# then we know this block hasn't been accessed yet.
|
||||
_DEFAULT_LAST_ACCESSED_TIME = -1
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockTracker:
|
||||
"""Used to track the status of a block inside the prefix caching allocator
|
||||
@ -105,7 +108,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
|
||||
# Evictor used to maintain how we want to handle those computed blocks
|
||||
# if we find memory pressure is high.
|
||||
self.evictor: Evictor = make_evictor(eviction_policy)
|
||||
self.eviction_policy = eviction_policy
|
||||
self.evictor: Evictor = make_evictor(self.eviction_policy)
|
||||
|
||||
# We share the refcounter between allocators. This allows us to promote
|
||||
# blocks originally allocated in the hashless allocator to immutable
|
||||
@ -428,6 +432,44 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
def get_prefix_cache_hit_rate(self) -> float:
|
||||
return self.metric_data.get_hit_rate()
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Reset the prefix cache.

    This may be used in RLHF flows to invalidate prefix caching after the
    weights are updated, or to reset the prefix caching status for
    benchmarking. The reset is refused while any block is still in use.

    Returns:
        bool: True if the prefix cache is successfully reset,
            False otherwise.
    """
    total_blocks = self.get_num_total_blocks()
    blocks_in_use = total_blocks - self.get_num_free_blocks()
    if blocks_in_use > 0:
        logger.warning(
            "Failed to reset prefix cache because some "
            "blocks (%d) are not freed yet", blocks_in_use)
        return False

    # Free all blocks in the evictor: keep pulling evicted block ids until
    # none are left and hand each one back to the hashless allocator.
    while True:
        evicted_id = self._maybe_allocate_evicted_block_id()
        if evicted_id is None:
            break
        self._hashless_allocator.free_block_id(evicted_id)

    # Should not have any cached blocks because all blocks are evicted.
    assert not self._cached_blocks

    # Start over with a fresh evictor using the same policy.
    self.evictor = make_evictor(self.eviction_policy)

    # Replace every tracker entry with a pristine one.
    for tracked_id in self._block_tracker:
        self._block_tracker[tracked_id] = BlockTracker()

    # Forget all previously recorded cache metrics.
    self.metric_data = CacheMetricData()

    logger.info("Successfully reset prefix cache")
    return True
|
||||
|
||||
def is_block_cached(self, block: Block) -> bool:
|
||||
assert block.content_hash is not None
|
||||
return block.content_hash in self._cached_blocks
|
||||
|
||||
@ -455,6 +455,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
return self.block_allocator.get_prefix_cache_hit_rate(device)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Ask the block allocator to reset its prefix cache.

    Returns:
        bool: The result reported by the block allocator.
    """
    allocator = self.block_allocator
    return allocator.reset_prefix_cache()
|
||||
|
||||
def _can_swap(self,
|
||||
seq_group: SequenceGroup,
|
||||
device: Device,
|
||||
|
||||
@ -122,6 +122,11 @@ class BlockSpaceManager(ABC):
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
pass
|
||||
|
||||
@abstractmethod
def reset_prefix_cache(self) -> bool:
    """Reset the prefix cache for all devices.

    Returns:
        bool: Whether the reset succeeded, as reported by the
            concrete implementation.
    """
    ...
|
||||
|
||||
@abstractmethod
|
||||
def get_num_cached_tokens(self, seq: Sequence) -> int:
|
||||
pass
|
||||
|
||||
@ -90,5 +90,8 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
return -1
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Report success unconditionally; this placeholder does no work."""
    return True
|
||||
|
||||
def get_num_cached_tokens(self, seq: Sequence) -> int:
|
||||
return 0
|
||||
|
||||
@ -504,6 +504,9 @@ class Scheduler:
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
return self.block_manager.get_prefix_cache_hit_rate(device)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Ask the block manager to reset its prefix cache.

    Returns:
        bool: The result reported by the block manager.
    """
    block_manager = self.block_manager
    return block_manager.reset_prefix_cache()
|
||||
|
||||
def get_num_unfinished_seq_groups(self) -> int:
|
||||
return len(self.waiting) + len(self.running) + len(self.swapped)
|
||||
|
||||
|
||||
@ -1182,6 +1182,9 @@ class AsyncLLMEngine(EngineClient):
|
||||
async def stop_profile(self) -> None:
|
||||
self.engine.stop_profile()
|
||||
|
||||
async def reset_prefix_cache(self) -> None:
    """Reset the prefix cache of the wrapped engine.

    The engine's result is not propagated to the caller.
    """
    engine = self.engine
    engine.reset_prefix_cache()
|
||||
|
||||
async def add_lora(self, lora_request: LoRARequest) -> None:
|
||||
self.engine.add_lora(lora_request)
|
||||
|
||||
|
||||
@ -914,6 +914,14 @@ class LLMEngine:
|
||||
"""
|
||||
return self.scheduler[virtual_engine].has_unfinished_seqs()
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Reset prefix cache for all devices.

    Every scheduler is asked to reset, even if an earlier one reports
    failure, so that a failure in one scheduler does not leave the
    caches of the remaining schedulers untouched.

    Returns:
        bool: True only if every scheduler reported a successful reset
            (vacuously True when there are no schedulers).
    """
    # Collect all results first: chaining with `and` (or feeding a
    # generator to `all`) would short-circuit and skip later schedulers.
    results = [scheduler.reset_prefix_cache() for scheduler in self.scheduler]
    return all(results)
|
||||
|
||||
@staticmethod
|
||||
def _process_sequence_group_outputs(
|
||||
seq_group: SequenceGroup,
|
||||
|
||||
@ -121,6 +121,10 @@ class RPCUProfileRequest(Enum):
|
||||
STOP_PROFILE = 2
|
||||
|
||||
|
||||
class RPCResetPrefixCacheRequest(Enum):
    """RPC request kinds concerning the prefix cache."""

    # Ask the engine to reset its prefix cache.
    RESET_PREFIX_CACHE = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCLoadAdapterRequest:
|
||||
lora_request: LoRARequest
|
||||
@ -134,7 +138,8 @@ class RPCAdapterLoadedResponse:
|
||||
|
||||
|
||||
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
|
||||
RPCUProfileRequest, RPCLoadAdapterRequest]
|
||||
RPCUProfileRequest, RPCLoadAdapterRequest,
|
||||
RPCResetPrefixCacheRequest]
|
||||
|
||||
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
|
||||
RPCError]
|
||||
|
||||
@ -27,8 +27,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
RPCAdapterLoadedResponse, RPCError,
|
||||
RPCLoadAdapterRequest,
|
||||
RPCProcessRequest, RPCStartupRequest,
|
||||
RPCStartupResponse,
|
||||
RPCProcessRequest,
|
||||
RPCResetPrefixCacheRequest,
|
||||
RPCStartupRequest, RPCStartupResponse,
|
||||
RPCUProfileRequest)
|
||||
from vllm.engine.protocol import EngineClient
|
||||
# yapf: enable
|
||||
@ -675,6 +676,13 @@ class MQLLMEngineClient(EngineClient):
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
|
||||
|
||||
async def reset_prefix_cache(self) -> None:
    """Reset the prefix cache.

    Sends a one-way RESET_PREFIX_CACHE request over the input socket;
    nothing is returned to the caller.
    """
    reset_request = RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE
    await self._send_one_way_rpc_request(request=reset_request,
                                         socket=self.input_socket)
|
||||
|
||||
async def add_lora(self, lora_request: LoRARequest) -> None:
|
||||
"""Load a new LoRA adapter into the engine for future requests."""
|
||||
# Uses the same I/O as generate requests
|
||||
|
||||
@ -16,8 +16,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
RPCAdapterLoadedResponse, RPCError,
|
||||
RPCLoadAdapterRequest,
|
||||
RPCProcessRequest, RPCStartupRequest,
|
||||
RPCStartupResponse,
|
||||
RPCProcessRequest,
|
||||
RPCResetPrefixCacheRequest,
|
||||
RPCStartupRequest, RPCStartupResponse,
|
||||
RPCUProfileRequest)
|
||||
# yapf: enable
|
||||
from vllm.logger import init_logger
|
||||
@ -237,6 +238,8 @@ class MQLLMEngine:
|
||||
self.stop_profile()
|
||||
elif isinstance(request, RPCLoadAdapterRequest):
|
||||
self._handle_load_adapter_request(request)
|
||||
elif isinstance(request, RPCResetPrefixCacheRequest):
|
||||
self.reset_prefix_cache()
|
||||
else:
|
||||
raise ValueError("Unknown RPCRequest Type: "
|
||||
f"{type(request)}")
|
||||
@ -361,6 +364,9 @@ class MQLLMEngine:
|
||||
def stop_profile(self) -> None:
|
||||
self.engine.stop_profile()
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Ask the wrapped engine to reset its prefix cache.

    Returns:
        bool: The result reported by the engine.
    """
    engine = self.engine
    return engine.reset_prefix_cache()
|
||||
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
raise KeyboardInterrupt("MQLLMEngine terminated")
|
||||
|
||||
@ -271,6 +271,11 @@ class EngineClient(ABC):
|
||||
"""Start profiling the engine"""
|
||||
...
|
||||
|
||||
@abstractmethod
async def reset_prefix_cache(self) -> None:
    """Reset the engine's prefix cache; implemented by concrete clients."""
    pass
|
||||
|
||||
@abstractmethod
|
||||
async def add_lora(self, lora_request: LoRARequest) -> None:
|
||||
"""Load a new LoRA adapter into the engine for future requests."""
|
||||
|
||||
@ -1132,6 +1132,9 @@ class LLM:
|
||||
def stop_profile(self) -> None:
|
||||
self.llm_engine.stop_profile()
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Ask the underlying LLM engine to reset its prefix cache.

    Returns:
        Whatever the engine's ``reset_prefix_cache`` returns.
    """
    engine = self.llm_engine
    return engine.reset_prefix_cache()
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
"""
|
||||
Put the engine to sleep. The engine should not process any requests.
|
||||
@ -1150,6 +1153,7 @@ class LLM:
|
||||
where previous model weights are not needed. It reduces CPU memory
|
||||
pressure.
|
||||
"""
|
||||
self.reset_prefix_cache()
|
||||
self.llm_engine.sleep(level=level)
|
||||
|
||||
def wake_up(self):
|
||||
|
||||
@ -518,6 +518,18 @@ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
|
||||
},
|
||||
}
|
||||
|
||||
# Development-only endpoints: registered only when VLLM_SERVER_DEV_MODE is set.
if envs.VLLM_SERVER_DEV_MODE:

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(raw_request: Request):
        """
        Reset the prefix cache through the engine client.

        The API server currently does not check whether the reset
        actually succeeded; it always answers with status 200.
        """
        logger.info("Resetting prefix cache...")
        client = engine_client(raw_request)
        await client.reset_prefix_cache()
        return Response(status_code=200)
|
||||
|
||||
|
||||
@router.post("/invocations")
|
||||
async def invocations(raw_request: Request):
|
||||
|
||||
@ -72,6 +72,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
|
||||
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
|
||||
VLLM_DISABLE_COMPILE_CACHE: bool = False
|
||||
VLLM_SERVER_DEV_MODE: bool = False
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -467,6 +468,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
|
||||
"VLLM_DISABLE_COMPILE_CACHE":
|
||||
lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
|
||||
|
||||
# If set, vllm will run in development mode, which will enable
|
||||
# some additional endpoints for developing and debugging,
|
||||
# e.g. `/reset_prefix_cache`
|
||||
"VLLM_SERVER_DEV_MODE":
|
||||
lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
@ -194,11 +194,6 @@ class ExecutorBase(ABC):
|
||||
self.collective_rpc("stop_profile")
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
if self.cache_config.enable_prefix_caching:
|
||||
# TODO: support sleep with prefix caching
|
||||
# by resetting the prefix cache state,
|
||||
# after https://github.com/vllm-project/vllm/pull/12284
|
||||
raise ValueError("Cannot sleep when prefix caching is enabled.")
|
||||
self.collective_rpc("sleep", kwargs=dict(level=level))
|
||||
|
||||
def wake_up(self):
|
||||
|
||||
@ -285,6 +285,33 @@ class KVCacheManager:
|
||||
if block.ref_cnt == 0:
|
||||
self.free_block_queue.append(block)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Reset the prefix cache.

    This may be used in RLHF flows to invalidate prefix caching after the
    weights are updated, or to reset the prefix caching status for
    benchmarking. The reset is refused while any block is still in use,
    i.e. while the free-block queue holds fewer than ``num_gpu_blocks``.

    Returns:
        bool: True if the prefix cache is successfully reset,
            False otherwise.
    """
    free_blocks = self.free_block_queue.num_free_blocks
    blocks_in_use = self.num_gpu_blocks - free_blocks
    if blocks_in_use > 0:
        logger.warning(
            "Failed to reset prefix cache because some "
            "blocks (%d) are not freed yet", blocks_in_use)
        return False

    # Start from an empty hash table so that no new request can hit.
    self.cached_block_hash_to_block = defaultdict(dict)

    # Clear the hash stored on each block in the pool.
    for pooled_block in self.block_pool:
        pooled_block.reset_hash()

    logger.info("Successfully reset prefix cache")
    return True
|
||||
|
||||
def get_num_common_prefix_blocks(
|
||||
self,
|
||||
request: Request,
|
||||
|
||||
@ -529,6 +529,9 @@ class Scheduler:
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
return self.get_num_unfinished_requests() > 0
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Ask the KV cache manager to reset its prefix cache.

    Returns:
        bool: The result reported by the KV cache manager.
    """
    kv_cache_manager = self.kv_cache_manager
    return kv_cache_manager.reset_prefix_cache()
|
||||
|
||||
def make_stats(self) -> SchedulerStats:
|
||||
return SchedulerStats(
|
||||
num_running_reqs=len(self.running),
|
||||
|
||||
@ -66,6 +66,11 @@ class EngineCoreProfile:
|
||||
is_start: bool
|
||||
|
||||
|
||||
@dataclass
class EngineCoreResetPrefixCache:
    """Field-less request asking the engine core to reset its prefix cache."""
|
||||
|
||||
|
||||
class EngineCoreRequestType(enum.Enum):
|
||||
"""
|
||||
Request types defined as hex byte strings, so it can be sent over sockets
|
||||
@ -74,6 +79,8 @@ class EngineCoreRequestType(enum.Enum):
|
||||
ADD = b'\x00'
|
||||
ABORT = b'\x01'
|
||||
PROFILE = b'\x02'
|
||||
RESET_PREFIX_CACHE = b'\x03'
|
||||
|
||||
|
||||
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
|
||||
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
|
||||
EngineCoreResetPrefixCache, List[str]]
|
||||
|
||||
@ -321,6 +321,9 @@ class AsyncLLM(EngineClient):
|
||||
async def stop_profile(self) -> None:
|
||||
await self.engine_core.profile_async(False)
|
||||
|
||||
async def reset_prefix_cache(self) -> None:
    """Await the engine core client's asynchronous prefix cache reset."""
    core_client = self.engine_core
    await core_client.reset_prefix_cache_async()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
||||
EngineCoreRequest, EngineCoreRequestType,
|
||||
EngineCoreRequestUnion)
|
||||
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
@ -135,6 +135,9 @@ class EngineCore:
|
||||
def profile(self, is_start: bool = True):
|
||||
self.model_executor.profile(is_start)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Reset the scheduler's prefix cache.

    Returns:
        bool: The scheduler's result, so in-process callers can tell
            whether the reset actually happened. Callers that ignore
            the return value are unaffected.
    """
    return self.scheduler.reset_prefix_cache()
|
||||
|
||||
|
||||
class EngineCoreProc(EngineCore):
|
||||
"""ZMQ-wrapper for running EngineCore in background process."""
|
||||
@ -247,6 +250,8 @@ class EngineCoreProc(EngineCore):
|
||||
self.add_request(request)
|
||||
elif isinstance(request, EngineCoreProfile):
|
||||
self.model_executor.profile(request.is_start)
|
||||
elif isinstance(request, EngineCoreResetPrefixCache):
|
||||
self.reset_prefix_cache()
|
||||
else:
|
||||
# TODO: make an EngineCoreAbort wrapper
|
||||
assert isinstance(request, list)
|
||||
@ -271,7 +276,9 @@ class EngineCoreProc(EngineCore):
|
||||
request = decoder_add_req.decode(request_data)
|
||||
elif request_type == EngineCoreRequestType.ABORT.value:
|
||||
request = decoder_abort_req.decode(request_data)
|
||||
elif request_type == EngineCoreRequestType.PROFILE.value:
|
||||
elif request_type in (
|
||||
EngineCoreRequestType.PROFILE.value,
|
||||
EngineCoreRequestType.RESET_PREFIX_CACHE.value):
|
||||
request = pickle.loads(request_data)
|
||||
else:
|
||||
raise ValueError(f"Unknown RequestType: {request_type}")
|
||||
|
||||
@ -14,7 +14,7 @@ from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
|
||||
make_zmq_socket)
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
||||
EngineCoreRequest, EngineCoreRequestType,
|
||||
EngineCoreRequestUnion)
|
||||
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.serial_utils import PickleEncoder
|
||||
@ -69,6 +69,9 @@ class EngineCoreClient(ABC):
|
||||
def profile(self, is_start: bool = True) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_prefix_cache(self) -> None:
    """Synchronous prefix cache reset; not implemented on the base client.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError
|
||||
|
||||
def abort_requests(self, request_ids: List[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -81,6 +84,9 @@ class EngineCoreClient(ABC):
|
||||
async def profile_async(self, is_start: bool = True) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def reset_prefix_cache_async(self) -> None:
    """Asynchronous prefix cache reset; not implemented on the base client.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError
|
||||
|
||||
async def abort_requests_async(self, request_ids: List[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -108,12 +114,15 @@ class InprocClient(EngineCoreClient):
|
||||
if len(request_ids) > 0:
|
||||
self.engine_core.abort_requests(request_ids)
|
||||
|
||||
def shutdown(self) -> None:
    """Shut down the in-process engine core."""
    self.engine_core.shutdown()
|
||||
|
||||
def profile(self, is_start: bool = True) -> None:
|
||||
self.engine_core.profile(is_start)
|
||||
|
||||
def reset_prefix_cache(self) -> None:
    """Reset the prefix cache by calling the in-process engine core directly.

    The engine core's result is not propagated.
    """
    core = self.engine_core
    core.reset_prefix_cache()
|
||||
|
||||
|
||||
class MPClient(EngineCoreClient):
|
||||
"""
|
||||
@ -229,6 +238,10 @@ class SyncMPClient(MPClient):
|
||||
self._send_input(EngineCoreRequestType.PROFILE,
|
||||
EngineCoreProfile(is_start))
|
||||
|
||||
def reset_prefix_cache(self) -> None:
    """Send a RESET_PREFIX_CACHE request to the engine core process."""
    request_type = EngineCoreRequestType.RESET_PREFIX_CACHE
    payload = EngineCoreResetPrefixCache()
    self._send_input(request_type, payload)
|
||||
|
||||
|
||||
class AsyncMPClient(MPClient):
|
||||
"""Asyncio-compatible client for multi-proc EngineCore."""
|
||||
@ -266,3 +279,7 @@ class AsyncMPClient(MPClient):
|
||||
async def profile_async(self, is_start: bool = True) -> None:
|
||||
await self._send_input(EngineCoreRequestType.PROFILE,
|
||||
EngineCoreProfile(is_start))
|
||||
|
||||
async def reset_prefix_cache_async(self) -> None:
    """Send a RESET_PREFIX_CACHE request to the engine core and await the send."""
    request_type = EngineCoreRequestType.RESET_PREFIX_CACHE
    payload = EngineCoreResetPrefixCache()
    await self._send_input(request_type, payload)
|
||||
|
||||
@ -162,6 +162,9 @@ class LLMEngine:
|
||||
def stop_profile(self):
|
||||
self.engine_core.profile(False)
|
||||
|
||||
def reset_prefix_cache(self):
    """Ask the engine core client to reset the prefix cache.

    Nothing is returned to the caller.
    """
    core_client = self.engine_core
    core_client.reset_prefix_cache()
|
||||
|
||||
def get_tokenizer_group(
|
||||
self,
|
||||
group_type: Type[_G] = BaseTokenizerGroup,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user