mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Core] Support reset_prefix_cache (#12284)
This commit is contained in:
parent 96f6a7596f
commit 7206ce4ce1
@@ -796,6 +796,44 @@ class TestPrefixCachingBlockAllocator:
             block_hashes=block_hashes_seq1)
         assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks

+    # Test reset prefix cache
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [10])
+    @pytest.mark.parametrize("block_size", [16])
+    def test_reset_prefix_cache(num_blocks: int, block_size: int):
+        """This test case simulates the case of resetting the prefix cache."""
+
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        token_ids = list(range(3 * block_size))
+
+        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        # Free each block in the first chain.
+        for block in first_chain:
+            allocator.free(block)
+
+        # Failed to reset prefix cache because some blocks are not freed yet.
+        assert not allocator.reset_prefix_cache()
+        assert allocator.get_prefix_cache_hit_rate() > 0.0
+
+        # Free each block in the second chain.
+        for block in second_chain:
+            allocator.free(block)
+
+        # Reset prefix cache.
+        assert allocator.reset_prefix_cache()
+        assert allocator.get_prefix_cache_hit_rate() == 0.0
+
     @staticmethod
     def create_immutable_chain(
         block_size: int,
@@ -587,3 +587,42 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     assert {block.ref_cnt for block in block_part1[:3]} == {1}
     # Block 3-5 are free.
     assert {block.ref_cnt for block in block_part1[3:]} == {0}
+
+
+def test_reset_prefix_cache():
+    manager = KVCacheManager(
+        block_size=16,
+        num_gpu_blocks=10,
+        max_model_len=8192,
+        sliding_window=None,
+        enable_caching=True,
+        num_preallocate_tokens=0,
+    )
+
+    full_block_token_ids = [i for i in range(3) for _ in range(16)]
+    unique_token_ids = [3] * 7
+    all_token_ids = full_block_token_ids + unique_token_ids
+    req0 = make_request("0", all_token_ids)
+    blocks = manager.allocate_slots(req0, 55, [])
+    assert [b.block_id for b in blocks] == [0, 1, 2, 3]
+
+    unique_token_ids = [4] * 7
+    all_token_ids = full_block_token_ids + unique_token_ids
+    req1 = make_request("1", all_token_ids)
+    computed_blocks, _ = manager.get_computed_blocks(req1)
+    assert len(req1.kv_block_hashes) == 3
+    assert len(computed_blocks) == 3
+    blocks = manager.allocate_slots(req1, 7, computed_blocks)
+    assert [b.block_id for b in blocks] == [4]
+
+    # Failed to reset prefix cache because some blocks are not freed yet.
+    assert not manager.reset_prefix_cache()
+    assert manager.cached_block_hash_to_block
+
+    # Free the blocks.
+    manager.free(req0)
+    manager.free(req1)
+
+    assert manager.reset_prefix_cache()
+    assert not manager.cached_block_hash_to_block
+    assert all([blk.block_hash is None for blk in manager.block_pool])
@@ -339,6 +339,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         assert device in self._allocators
         return self._allocators[device].get_prefix_cache_hit_rate()

+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all devices."""
+        success = True
+        for allocator in self._allocators.values():
+            success = success and allocator.reset_prefix_cache()
+        return success
+
     def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
         """Returns and clears the mapping of source to destination block IDs.
         Will be called after every swapping operations for now, and after every
@@ -192,6 +192,11 @@ class BlockAllocator(ABC):
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass

+    @abstractmethod
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache."""
+        pass
+
     class NoFreeBlocksError(ValueError):
         pass

@@ -297,6 +302,11 @@ class DeviceAwareBlockAllocator(ABC):
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass

+    @abstractmethod
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache."""
+        pass
+
     @abstractmethod
     def find_cached_blocks_prefix(
         self,
@@ -1,5 +1,5 @@
 from collections import deque
-from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
+from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union

 from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
                                     get_all_blocks_recursively)
@@ -136,16 +136,18 @@ class NaiveBlockAllocator(BlockAllocator):
         self._refcounter.incr(block_id)
         return block_id

-    def _free_block_id(self, block: Block) -> None:
-        block_id = block.block_id
+    def _free_block_id(self, block: Union[Block, BlockId]) -> None:
+        if isinstance(block, Block):
+            block_id = block.block_id
+            block.block_id = None
+        else:
+            block_id = block
         assert block_id is not None

         refcount = self._refcounter.decr(block_id)
         if refcount == 0:
             self._free_block_indices.appendleft(block_id)

-        block.block_id = None
-
     def free(self, block: Block, keep_block_object: bool = False) -> None:
         # Release the physical block id
         self._free_block_id(block)
@@ -154,6 +156,9 @@ class NaiveBlockAllocator(BlockAllocator):
         if not keep_block_object:
             self._block_pool.free_block(block)

+    def free_block_id(self, block_id: BlockId) -> None:
+        self._free_block_id(block_id)
+
     def fork(self, last_block: Block) -> List[Block]:
         """Creates a new sequence of blocks that shares the same underlying
         memory as the original sequence.
@@ -325,6 +330,10 @@ class NaiveBlockAllocator(BlockAllocator):
     def get_prefix_cache_hit_rate(self) -> float:
         return -1

+    def reset_prefix_cache(self) -> bool:
+        """No prefix cache for naive block allocator."""
+        return True
+
     def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
         # Not applicable for naive block allocator.
         return []
@@ -12,6 +12,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
 from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
                                          NaiveBlockAllocator)
 from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
+from vllm.logger import init_logger
 from vllm.sequence import Sequence

 PrefixHash = int
@@ -21,6 +22,8 @@ PrefixHash = int
 # then we know this block hasn't been accessed yet.
 _DEFAULT_LAST_ACCESSED_TIME = -1

+logger = init_logger(__name__)
+

 class BlockTracker:
     """Used to track the status of a block inside the prefix caching allocator
@@ -105,7 +108,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):

         # Evitor used to maintain how we want to handle those computed blocks
         # if we find memory pressure is high.
-        self.evictor: Evictor = make_evictor(eviction_policy)
+        self.eviction_policy = eviction_policy
+        self.evictor: Evictor = make_evictor(self.eviction_policy)

         # We share the refcounter between allocators. This allows us to promote
         # blocks originally allocated in the hashless allocator to immutable
@@ -428,6 +432,44 @@ class PrefixCachingBlockAllocator(BlockAllocator):
     def get_prefix_cache_hit_rate(self) -> float:
         return self.metric_data.get_hit_rate()

+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache. This function may be used in RLHF
+        flows to invalidate prefix caching after the weights are updated,
+        or used for resetting prefix caching status for benchmarking.
+
+        Returns:
+            bool: True if the prefix cache is successfully reset,
+            False otherwise.
+        """
+        num_used_blocks = (self.get_num_total_blocks() -
+                           self.get_num_free_blocks())
+        if num_used_blocks > 0:
+            logger.warning(
+                "Failed to reset prefix cache because some "
+                "blocks (%d) are not freed yet", num_used_blocks)
+            return False
+
+        # Free all blocks in the evictor.
+        while (block_id :=
+               self._maybe_allocate_evicted_block_id()) is not None:
+            self._hashless_allocator.free_block_id(block_id)
+
+        # Should not have any cached blocks because all blocks are evicted.
+        assert not self._cached_blocks
+
+        # Reset the evictor.
+        self.evictor = make_evictor(self.eviction_policy)
+
+        # Reset the block tracker.
+        for block_id in self._block_tracker:
+            self._block_tracker[block_id] = BlockTracker()
+
+        # Reset the metrics.
+        self.metric_data = CacheMetricData()
+
+        logger.info("Successfully reset prefix cache")
+        return True
+
     def is_block_cached(self, block: Block) -> bool:
         assert block.content_hash is not None
         return block.content_hash in self._cached_blocks
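The allocator-level contract is easiest to see in isolation. Below is a minimal usage sketch (illustrative only; allocate_immutable_block and its signature are assumptions inferred from the test hunks above, not part of this diff):

    from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator

    allocator = PrefixCachingBlockAllocator(num_blocks=4, block_size=16)
    block = allocator.allocate_immutable_block(prev_block=None,
                                               token_ids=list(range(16)))

    assert not allocator.reset_prefix_cache()  # refused: a block is still in use
    allocator.free(block)
    assert allocator.reset_prefix_cache()      # succeeds: all blocks are free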
@@ -455,6 +455,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_allocator.get_prefix_cache_hit_rate(device)

+    def reset_prefix_cache(self) -> bool:
+        return self.block_allocator.reset_prefix_cache()
+
     def _can_swap(self,
                   seq_group: SequenceGroup,
                   device: Device,
@@ -122,6 +122,11 @@ class BlockSpaceManager(ABC):
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass

+    @abstractmethod
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all devices."""
+        pass
+
     @abstractmethod
     def get_num_cached_tokens(self, seq: Sequence) -> int:
         pass
@@ -90,5 +90,8 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return -1

+    def reset_prefix_cache(self) -> bool:
+        return True
+
     def get_num_cached_tokens(self, seq: Sequence) -> int:
         return 0
@@ -504,6 +504,9 @@ class Scheduler:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)

+    def reset_prefix_cache(self) -> bool:
+        return self.block_manager.reset_prefix_cache()
+
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -1182,6 +1182,9 @@ class AsyncLLMEngine(EngineClient):
     async def stop_profile(self) -> None:
         self.engine.stop_profile()

+    async def reset_prefix_cache(self) -> None:
+        self.engine.reset_prefix_cache()
+
     async def add_lora(self, lora_request: LoRARequest) -> None:
         self.engine.add_lora(lora_request)
@@ -914,6 +914,14 @@ class LLMEngine:
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()

+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all devices."""
+
+        success = True
+        for scheduler in self.scheduler:
+            success = success and scheduler.reset_prefix_cache()
+        return success
+
     @staticmethod
     def _process_sequence_group_outputs(
         seq_group: SequenceGroup,
@@ -121,6 +121,10 @@ class RPCUProfileRequest(Enum):
     STOP_PROFILE = 2


+class RPCResetPrefixCacheRequest(Enum):
+    RESET_PREFIX_CACHE = 1
+
+
 @dataclass
 class RPCLoadAdapterRequest:
     lora_request: LoRARequest
@@ -134,7 +138,8 @@ class RPCAdapterLoadedResponse:


 RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
-                      RPCUProfileRequest, RPCLoadAdapterRequest]
+                      RPCUProfileRequest, RPCLoadAdapterRequest,
+                      RPCResetPrefixCacheRequest]

 REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
                           RPCError]
@@ -27,8 +27,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                          RPCAdapterLoadedResponse, RPCError,
                                          RPCLoadAdapterRequest,
-                                         RPCProcessRequest, RPCStartupRequest,
-                                         RPCStartupResponse,
+                                         RPCProcessRequest,
+                                         RPCResetPrefixCacheRequest,
+                                         RPCStartupRequest, RPCStartupResponse,
                                          RPCUProfileRequest)
 from vllm.engine.protocol import EngineClient
 # yapf: enable
@@ -675,6 +676,13 @@ class MQLLMEngineClient(EngineClient):
         await self._send_one_way_rpc_request(
             request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)

+    async def reset_prefix_cache(self) -> None:
+        """Reset the prefix cache"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
+            socket=self.input_socket)
+
     async def add_lora(self, lora_request: LoRARequest) -> None:
         """Load a new LoRA adapter into the engine for future requests."""
         # Uses the same I/O as generate requests
@@ -16,8 +16,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                          RPCAdapterLoadedResponse, RPCError,
                                          RPCLoadAdapterRequest,
-                                         RPCProcessRequest, RPCStartupRequest,
-                                         RPCStartupResponse,
+                                         RPCProcessRequest,
+                                         RPCResetPrefixCacheRequest,
+                                         RPCStartupRequest, RPCStartupResponse,
                                          RPCUProfileRequest)
 # yapf: enable
 from vllm.logger import init_logger
@@ -237,6 +238,8 @@ class MQLLMEngine:
                     self.stop_profile()
             elif isinstance(request, RPCLoadAdapterRequest):
                 self._handle_load_adapter_request(request)
+            elif isinstance(request, RPCResetPrefixCacheRequest):
+                self.reset_prefix_cache()
             else:
                 raise ValueError("Unknown RPCRequest Type: "
                                  f"{type(request)}")
@@ -361,6 +364,9 @@ class MQLLMEngine:
     def stop_profile(self) -> None:
         self.engine.stop_profile()

+    def reset_prefix_cache(self) -> bool:
+        return self.engine.reset_prefix_cache()
+

 def signal_handler(*_) -> None:
     raise KeyboardInterrupt("MQLLMEngine terminated")
@@ -271,6 +271,11 @@ class EngineClient(ABC):
         """Start profiling the engine"""
         ...

+    @abstractmethod
+    async def reset_prefix_cache(self) -> None:
+        """Reset the prefix cache"""
+        ...
+
     @abstractmethod
     async def add_lora(self, lora_request: LoRARequest) -> None:
         """Load a new LoRA adapter into the engine for future requests."""
@@ -1132,6 +1132,9 @@ class LLM:
     def stop_profile(self) -> None:
         self.llm_engine.stop_profile()

+    def reset_prefix_cache(self) -> bool:
+        return self.llm_engine.reset_prefix_cache()
+
     def sleep(self, level: int = 1):
         """
         Put the engine to sleep. The engine should not process any requests.
@@ -1150,6 +1153,7 @@ class LLM:
         where previous model weights are not needed. It reduces CPU memory
         pressure.
         """
+        self.reset_prefix_cache()
         self.llm_engine.sleep(level=level)

     def wake_up(self):
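With the LLM entrypoint wired up, a typical offline flow looks like the sketch below (model name and prompt are arbitrary; assumes prefix caching is enabled):

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
    llm.generate("Shared prompt prefix, first pass")

    # E.g. after an RLHF weight update the cached prefixes are stale, so
    # drop them. The reset only succeeds once all in-flight requests have
    # finished and their blocks have been freed.
    assert llm.reset_prefix_cache()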
@@ -518,6 +518,18 @@ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
     },
 }

+if envs.VLLM_SERVER_DEV_MODE:
+
+    @router.post("/reset_prefix_cache")
+    async def reset_prefix_cache(raw_request: Request):
+        """
+        Reset the prefix cache. Note that we currently do not check if the
+        prefix cache is successfully reset in the API server.
+        """
+        logger.info("Resetting prefix cache...")
+        await engine_client(raw_request).reset_prefix_cache()
+        return Response(status_code=200)
+

 @router.post("/invocations")
 async def invocations(raw_request: Request):
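For the online server, the new dev-only route can be exercised as in this sketch (assumes a server launched with VLLM_SERVER_DEV_MODE=1 on the default port):

    import requests

    resp = requests.post("http://localhost:8000/reset_prefix_cache")
    # The handler always returns 200; per the docstring above, the API
    # server does not propagate whether the reset actually succeeded.
    assert resp.status_code == 200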
@@ -72,6 +72,7 @@ if TYPE_CHECKING:
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
+    VLLM_SERVER_DEV_MODE: bool = False


 def get_default_cache_root():
@@ -467,6 +468,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
     "VLLM_DISABLE_COMPILE_CACHE":
     lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
+
+    # If set, vllm will run in development mode, which will enable
+    # some additional endpoints for developing and debugging,
+    # e.g. `/reset_prefix_cache`
+    "VLLM_SERVER_DEV_MODE":
+    lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))),
 }

 # end-env-vars-definition
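Note the parsing convention above treats the variable as an integer flag, so only numeric strings are valid, e.g.:

    import os

    os.environ["VLLM_SERVER_DEV_MODE"] = "1"   # enables the dev endpoints
    assert bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0")))
    # A value like "true" would raise ValueError inside bool(int(...)).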
@@ -194,11 +194,6 @@ class ExecutorBase(ABC):
         self.collective_rpc("stop_profile")

     def sleep(self, level: int = 1):
-        if self.cache_config.enable_prefix_caching:
-            # TODO: support sleep with prefix caching
-            # by resetting the prefix cache state,
-            # after https://github.com/vllm-project/vllm/pull/12284
-            raise ValueError("Cannot sleep when prefix caching is enabled.")
         self.collective_rpc("sleep", kwargs=dict(level=level))

     def wake_up(self):
@@ -285,6 +285,33 @@ class KVCacheManager:
             if block.ref_cnt == 0:
                 self.free_block_queue.append(block)

+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache. This function may be used in RLHF
+        flows to invalidate prefix caching after the weights are updated,
+        or used for resetting prefix caching status for benchmarking.
+
+        Returns:
+            bool: True if the prefix cache is successfully reset,
+            False otherwise.
+        """
+        num_used_blocks = (self.num_gpu_blocks -
+                           self.free_block_queue.num_free_blocks)
+        if num_used_blocks > 0:
+            logger.warning(
+                "Failed to reset prefix cache because some "
+                "blocks (%d) are not freed yet", num_used_blocks)
+            return False
+
+        # Remove all hashes so that no new blocks will hit.
+        self.cached_block_hash_to_block = defaultdict(dict)
+
+        # Remove all hashes from all blocks.
+        for block in self.block_pool:
+            block.reset_hash()
+
+        logger.info("Successfully reset prefix cache")
+        return True
+
     def get_num_common_prefix_blocks(
         self,
         request: Request,
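The two clearing steps are independent; this reduced sketch with a toy block type (all names local to the snippet) shows what each one prevents:

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToyBlock:
        block_hash: Optional[int] = None

        def reset_hash(self) -> None:
            self.block_hash = None

    block_pool = [ToyBlock(block_hash=i) for i in range(4)]
    cached_block_hash_to_block = defaultdict(dict)
    for b in block_pool:
        cached_block_hash_to_block[b.block_hash][id(b)] = b

    # Step 1: drop the hash -> block map so future lookups cannot hit.
    cached_block_hash_to_block = defaultdict(dict)
    # Step 2: wipe per-block hashes so blocks cannot be re-registered
    # under their old contents when they are reused.
    for b in block_pool:
        b.reset_hash()
    assert all(b.block_hash is None for b in block_pool)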
@@ -529,6 +529,9 @@ class Scheduler:
     def has_unfinished_requests(self) -> bool:
         return self.get_num_unfinished_requests() > 0

+    def reset_prefix_cache(self) -> bool:
+        return self.kv_cache_manager.reset_prefix_cache()
+
     def make_stats(self) -> SchedulerStats:
         return SchedulerStats(
             num_running_reqs=len(self.running),
@@ -66,6 +66,11 @@ class EngineCoreProfile:
     is_start: bool


+@dataclass
+class EngineCoreResetPrefixCache:
+    pass
+
+
 class EngineCoreRequestType(enum.Enum):
     """
     Request types defined as hex byte strings, so it can be sent over sockets
@@ -74,6 +79,8 @@ class EngineCoreRequestType(enum.Enum):
     ADD = b'\x00'
     ABORT = b'\x01'
     PROFILE = b'\x02'
+    RESET_PREFIX_CACHE = b'\x03'


-EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
+EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
+                               EngineCoreResetPrefixCache, List[str]]
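On the wire, a v1 request is a multipart message of the request-type byte followed by a pickled payload. A standalone sketch of how a reset message could be framed (the dataclass is a local stand-in for the one defined above):

    import pickle
    from dataclasses import dataclass

    @dataclass
    class EngineCoreResetPrefixCache:  # local stand-in
        pass

    # Frame 0: the type byte; frame 1: the pickled (empty) payload.
    frames = (b'\x03', pickle.dumps(EngineCoreResetPrefixCache()))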
@@ -321,6 +321,9 @@ class AsyncLLM(EngineClient):
     async def stop_profile(self) -> None:
         await self.engine_core.profile_async(False)

+    async def reset_prefix_cache(self) -> None:
+        await self.engine_core.reset_prefix_cache_async()
+
     @property
     def is_running(self) -> bool:
         return True
@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                             EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion)
+                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
@@ -135,6 +135,9 @@ class EngineCore:
     def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)

+    def reset_prefix_cache(self):
+        self.scheduler.reset_prefix_cache()
+

 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -247,6 +250,8 @@ class EngineCoreProc(EngineCore):
             self.add_request(request)
         elif isinstance(request, EngineCoreProfile):
             self.model_executor.profile(request.is_start)
+        elif isinstance(request, EngineCoreResetPrefixCache):
+            self.reset_prefix_cache()
         else:
             # TODO: make an EngineCoreAbort wrapper
             assert isinstance(request, list)
@@ -271,7 +276,9 @@ class EngineCoreProc(EngineCore):
             request = decoder_add_req.decode(request_data)
         elif request_type == EngineCoreRequestType.ABORT.value:
             request = decoder_abort_req.decode(request_data)
-        elif request_type == EngineCoreRequestType.PROFILE.value:
+        elif request_type in (
+                EngineCoreRequestType.PROFILE.value,
+                EngineCoreRequestType.RESET_PREFIX_CACHE.value):
             request = pickle.loads(request_data)
         else:
             raise ValueError(f"Unknown RequestType: {request_type}")
@@ -14,7 +14,7 @@ from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                         make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                             EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion)
+                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.serial_utils import PickleEncoder
@@ -69,6 +69,9 @@ class EngineCoreClient(ABC):
     def profile(self, is_start: bool = True) -> None:
         raise NotImplementedError

+    def reset_prefix_cache(self) -> None:
+        raise NotImplementedError
+
     def abort_requests(self, request_ids: List[str]) -> None:
         raise NotImplementedError

@@ -81,6 +84,9 @@ class EngineCoreClient(ABC):
     async def profile_async(self, is_start: bool = True) -> None:
         raise NotImplementedError

+    async def reset_prefix_cache_async(self) -> None:
+        raise NotImplementedError
+
     async def abort_requests_async(self, request_ids: List[str]) -> None:
         raise NotImplementedError

@@ -108,12 +114,15 @@ class InprocClient(EngineCoreClient):
         if len(request_ids) > 0:
             self.engine_core.abort_requests(request_ids)

-    def shutdown(self):
+    def shutdown(self) -> None:
         self.engine_core.shutdown()

     def profile(self, is_start: bool = True) -> None:
         self.engine_core.profile(is_start)

+    def reset_prefix_cache(self) -> None:
+        self.engine_core.reset_prefix_cache()
+

 class MPClient(EngineCoreClient):
     """
@@ -229,6 +238,10 @@ class SyncMPClient(MPClient):
         self._send_input(EngineCoreRequestType.PROFILE,
                          EngineCoreProfile(is_start))

+    def reset_prefix_cache(self) -> None:
+        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
+                         EngineCoreResetPrefixCache())
+

 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -266,3 +279,7 @@ class AsyncMPClient(MPClient):
     async def profile_async(self, is_start: bool = True) -> None:
         await self._send_input(EngineCoreRequestType.PROFILE,
                                EngineCoreProfile(is_start))
+
+    async def reset_prefix_cache_async(self) -> None:
+        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
+                               EngineCoreResetPrefixCache())
@@ -162,6 +162,9 @@ class LLMEngine:
     def stop_profile(self):
         self.engine_core.profile(False)

+    def reset_prefix_cache(self):
+        self.engine_core.reset_prefix_cache()
+
     def get_tokenizer_group(
         self,
         group_type: Type[_G] = BaseTokenizerGroup,