[FEAT] Support resetting the prefix cache for a specified device (#15003)

maobaolong authored 2025-03-20 01:54:41 +08:00, committed by GitHub
parent 61c7a1b856
commit 26dd972adb
15 changed files with 49 additions and 34 deletions

View File

@@ -341,8 +341,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         assert device in self._allocators
         return self._allocators[device].get_prefix_cache_hit_rate()
 
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache for all devices."""
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
+        if device:
+            return self._allocators[device].reset_prefix_cache()
         success = True
         for allocator in self._allocators.values():
             success = success and allocator.reset_prefix_cache()
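
The allocator-level hunk above is the heart of the feature: a concrete device argument short-circuits to that device's allocator, while None preserves the old fan-out over every device. A minimal self-contained sketch of that dispatch pattern (toy class names; only Device mirrors vllm.utils.Device):

from enum import Enum, auto
from typing import Dict, Optional

class Device(Enum):  # stand-in for vllm.utils.Device
    GPU = auto()
    CPU = auto()

class ToyAllocator:
    def reset_prefix_cache(self) -> bool:
        return True  # pretend the reset always succeeds

def reset_prefix_cache(allocators: Dict[Device, ToyAllocator],
                       device: Optional[Device] = None) -> bool:
    # A single device resets only its own allocator; None fans out to all.
    if device is not None:
        return allocators[device].reset_prefix_cache()
    return all(a.reset_prefix_cache() for a in allocators.values())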

View File

@@ -305,7 +305,7 @@ class DeviceAwareBlockAllocator(ABC):
         pass
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset prefix cache."""
        pass

View File

@@ -456,8 +456,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_allocator.get_prefix_cache_hit_rate(device)
 
-    def reset_prefix_cache(self) -> bool:
-        return self.block_allocator.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.block_allocator.reset_prefix_cache(device)
 
     def _can_swap(self,
                   seq_group: SequenceGroup,

View File

@@ -2,7 +2,7 @@
 
 import enum
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple
@@ -125,8 +125,8 @@ class BlockSpaceManager(ABC):
         pass
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache for all devices."""
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
         pass
 
     @abstractmethod
@abstractmethod

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
 from vllm.sequence import Sequence, SequenceGroup
@@ -92,7 +92,7 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return -1
 
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         return True
 
     def get_num_cached_tokens(self, seq: Sequence) -> int:

View File

@@ -634,8 +634,8 @@ class Scheduler:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)
 
-    def reset_prefix_cache(self) -> bool:
-        return self.block_manager.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.block_manager.reset_prefix_cache(device)
 
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)

View File

@@ -35,7 +35,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import Device, deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1216,8 +1216,9 @@ class AsyncLLMEngine(EngineClient):
     async def stop_profile(self) -> None:
         self.engine.stop_profile()
 
-    async def reset_prefix_cache(self) -> None:
-        self.engine.reset_prefix_cache()
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
+        self.engine.reset_prefix_cache(device)
 
     async def sleep(self, level: int = 1) -> None:
         self.engine.sleep(level)

View File

@@ -955,12 +955,12 @@ class LLMEngine:
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""
 
         success = True
         for scheduler in self.scheduler:
-            success = success and scheduler.reset_prefix_cache()
+            success = success and scheduler.reset_prefix_cache(device)
         return success
 
     @staticmethod
@staticmethod

View File

@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.utils import deprecate_kwargs
+from vllm.utils import Device, deprecate_kwargs
 
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
     STOP_PROFILE = 2
 
 
-class RPCResetPrefixCacheRequest(Enum):
-    RESET_PREFIX_CACHE = 1
+@dataclass
+class RPCResetPrefixCacheRequest:
+    device: Device
 
 
 class RPCSleepRequest(Enum):
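
Because the reset request must now carry data (the target device), the bare Enum sentinel becomes a dataclass message. A hedged sketch of the client-side construction, mirroring the hunk above (note the annotation is Device, as in the commit, although the RPC client below passes its possibly-None device straight through):

from dataclasses import dataclass

from vllm.utils import Device  # same import the commit adds to this module

@dataclass
class RPCResetPrefixCacheRequest:  # shape matches the hunk above
    device: Device

# The target device travels inside the message instead of a bare enum value;
# None (meaning "all devices") is forwarded as-is by the engine client.
request = RPCResetPrefixCacheRequest(Device.GPU)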

View File

@@ -47,7 +47,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import deprecate_kwargs
+from vllm.utils import Device, deprecate_kwargs
 
 logger = init_logger(__name__)
@@ -684,11 +684,12 @@ class MQLLMEngineClient(EngineClient):
         await self._send_one_way_rpc_request(
             request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
 
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
 
         await self._send_one_way_rpc_request(
-            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
+            request=RPCResetPrefixCacheRequest(device),
             socket=self.input_socket)
 
     async def sleep(self, level: int = 1) -> None:

View File

@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import collect_from_async_generator, random_uuid
+from vllm.utils import Device, collect_from_async_generator, random_uuid
 
 logger = init_logger(__name__)
@@ -274,7 +274,8 @@ class EngineClient(ABC):
         ...
 
     @abstractmethod
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
         ...

View File

@@ -42,7 +42,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
+from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
+                        is_list_of)
 
 logger = init_logger(__name__)
@@ -1187,8 +1188,8 @@ class LLM:
     def stop_profile(self) -> None:
         self.llm_engine.stop_profile()
 
-    def reset_prefix_cache(self) -> bool:
-        return self.llm_engine.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.llm_engine.reset_prefix_cache(device)
 
     def sleep(self, level: int = 1):
         """

View File

@@ -85,7 +85,7 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
+from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
                         is_valid_ipv6_address, set_ulimit)
 from vllm.version import __version__ as VLLM_VERSION
@@ -677,8 +677,12 @@ if envs.VLLM_SERVER_DEV_MODE:
         Reset the prefix cache. Note that we currently do not check if the
         prefix cache is successfully reset in the API server.
         """
-        logger.info("Resetting prefix cache...")
-        await engine_client(raw_request).reset_prefix_cache()
+        device = None
+        device_str = raw_request.query_params.get("device")
+        if device_str is not None:
+            device = Device[device_str.upper()]
+        logger.info("Resetting prefix cache with specific %s...", str(device))
+        await engine_client(raw_request).reset_prefix_cache(device)
         return Response(status_code=200)
 
     @router.post("/sleep")
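
On the serving side, the handler reads an optional "device" query parameter and upper-cases it into the Device enum, so "gpu" and "GPU" are equally accepted. A hedged client-side sketch: the host and port are assumptions, the route path is inferred since the decorator is not shown in the hunk, and the endpoint is only mounted when the server runs with VLLM_SERVER_DEV_MODE enabled (per the if envs.VLLM_SERVER_DEV_MODE: guard above):

import requests  # third-party HTTP client, used here purely for illustration

# Reset only the GPU prefix cache; omit the "device" param to reset everything.
resp = requests.post("http://localhost:8000/reset_prefix_cache",
                     params={"device": "gpu"})
assert resp.status_code == 200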

View File

@@ -24,7 +24,7 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import cdiv, kill_process_tree
+from vllm.utils import Device, cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -398,7 +398,10 @@ class AsyncLLM(EngineClient):
     async def stop_profile(self) -> None:
         await self.engine_core.profile_async(False)
 
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
+        if device == Device.CPU:
+            raise ValueError("Not supported on CPU.")
         await self.engine_core.reset_prefix_cache_async()
 
     async def sleep(self, level: int = 1) -> None:
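
In the V1 engine the prefix cache lives on the GPU, so the new argument is validated rather than dispatched: Device.CPU is rejected up front, and anything else falls through to the existing GPU reset. A small illustration, assuming an already-constructed AsyncLLM instance named "engine":

from vllm.utils import Device

async def demo(engine) -> None:
    await engine.reset_prefix_cache()            # resets the (GPU) prefix cache
    await engine.reset_prefix_cache(Device.GPU)  # same effect as above
    try:
        await engine.reset_prefix_cache(Device.CPU)
    except ValueError:
        pass  # V1 raises ValueError("Not supported on CPU.")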

View File

@@ -20,6 +20,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import Device
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -226,7 +227,7 @@ class LLMEngine:
     def stop_profile(self):
         self.engine_core.profile(False)
 
-    def reset_prefix_cache(self):
+    def reset_prefix_cache(self, device: Optional[Device] = None):
         self.engine_core.reset_prefix_cache()
 
     def sleep(self, level: int = 1):