[FEAT] Support reset prefix cache by specified device (#15003)

This commit is contained in:
maobaolong 2025-03-20 01:54:41 +08:00 committed by GitHub
parent 61c7a1b856
commit 26dd972adb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 49 additions and 34 deletions

View File

@@ -341,8 +341,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
assert device in self._allocators assert device in self._allocators
return self._allocators[device].get_prefix_cache_hit_rate() return self._allocators[device].get_prefix_cache_hit_rate()
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices.""" """Reset prefix cache for specified or all devices."""
if device:
return self._allocators[device].reset_prefix_cache()
success = True success = True
for allocator in self._allocators.values(): for allocator in self._allocators.values():
success = success and allocator.reset_prefix_cache() success = success and allocator.reset_prefix_cache()

View File

@@ -305,7 +305,7 @@ class DeviceAwareBlockAllocator(ABC):
pass pass
@abstractmethod @abstractmethod
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache.""" """Reset prefix cache."""
pass pass

View File

@@ -456,8 +456,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_allocator.get_prefix_cache_hit_rate(device) return self.block_allocator.get_prefix_cache_hit_rate(device)
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return self.block_allocator.reset_prefix_cache() return self.block_allocator.reset_prefix_cache(device)
def _can_swap(self, def _can_swap(self,
seq_group: SequenceGroup, seq_group: SequenceGroup,

View File

@@ -2,7 +2,7 @@
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List from typing import List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Tuple from typing import Tuple
@@ -125,8 +125,8 @@ class BlockSpaceManager(ABC):
pass pass
@abstractmethod @abstractmethod
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices.""" """Reset prefix cache for specified or all devices."""
pass pass
@abstractmethod @abstractmethod

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Tuple from typing import List, Optional, Tuple
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup from vllm.sequence import Sequence, SequenceGroup
@@ -92,7 +92,7 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
return -1 return -1
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return True return True
def get_num_cached_tokens(self, seq: Sequence) -> int: def get_num_cached_tokens(self, seq: Sequence) -> int:

View File

@@ -634,8 +634,8 @@ class Scheduler:
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device) return self.block_manager.get_prefix_cache_hit_rate(device)
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return self.block_manager.reset_prefix_cache() return self.block_manager.reset_prefix_cache(device)
def get_num_unfinished_seq_groups(self) -> int: def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped) return len(self.waiting) + len(self.running) + len(self.swapped)

View File

@@ -35,7 +35,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import deprecate_kwargs, weak_bind from vllm.utils import Device, deprecate_kwargs, weak_bind
logger = init_logger(__name__) logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1216,8 +1216,9 @@ class AsyncLLMEngine(EngineClient):
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
self.engine.stop_profile() self.engine.stop_profile()
async def reset_prefix_cache(self) -> None: async def reset_prefix_cache(self,
self.engine.reset_prefix_cache() device: Optional[Device] = None) -> None:
self.engine.reset_prefix_cache(device)
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
self.engine.sleep(level) self.engine.sleep(level)

View File

@@ -955,12 +955,12 @@ class LLMEngine:
""" """
return self.scheduler[virtual_engine].has_unfinished_seqs() return self.scheduler[virtual_engine].has_unfinished_seqs()
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices.""" """Reset prefix cache for all devices."""
success = True success = True
for scheduler in self.scheduler: for scheduler in self.scheduler:
success = success and scheduler.reset_prefix_cache() success = success and scheduler.reset_prefix_cache(device)
return success return success
@staticmethod @staticmethod

View File

@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import deprecate_kwargs from vllm.utils import Device, deprecate_kwargs
VLLM_RPC_SUCCESS_STR = "SUCCESS" VLLM_RPC_SUCCESS_STR = "SUCCESS"
@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE = 2 STOP_PROFILE = 2
class RPCResetPrefixCacheRequest(Enum): @dataclass
RESET_PREFIX_CACHE = 1 class RPCResetPrefixCacheRequest:
device: Device
class RPCSleepRequest(Enum): class RPCSleepRequest(Enum):

View File

@@ -47,7 +47,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs from vllm.utils import Device, deprecate_kwargs
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -684,11 +684,12 @@ class MQLLMEngineClient(EngineClient):
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
async def reset_prefix_cache(self) -> None: async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
"""Reset the prefix cache""" """Reset the prefix cache"""
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE, request=RPCResetPrefixCacheRequest(device),
socket=self.input_socket) socket=self.input_socket)
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:

View File

@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import collect_from_async_generator, random_uuid from vllm.utils import Device, collect_from_async_generator, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -274,7 +274,8 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def reset_prefix_cache(self) -> None: async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
"""Reset the prefix cache""" """Reset the prefix cache"""
... ...

View File

@@ -42,7 +42,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of)
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -1187,8 +1188,8 @@ class LLM:
def stop_profile(self) -> None: def stop_profile(self) -> None:
self.llm_engine.stop_profile() self.llm_engine.stop_profile()
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return self.llm_engine.reset_prefix_cache() return self.llm_engine.reset_prefix_cache(device)
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
""" """

View File

@@ -85,7 +85,7 @@ from vllm.logger import init_logger
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value) maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit) is_valid_ipv6_address, set_ulimit)
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
@@ -677,8 +677,12 @@ if envs.VLLM_SERVER_DEV_MODE:
Reset the prefix cache. Note that we currently do not check if the Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server. prefix cache is successfully reset in the API server.
""" """
logger.info("Resetting prefix cache...") device = None
await engine_client(raw_request).reset_prefix_cache() device_str = raw_request.query_params.get("device")
if device_str is not None:
device = Device[device_str.upper()]
logger.info("Resetting prefix cache with specific %s...", str(device))
await engine_client(raw_request).reset_prefix_cache(device)
return Response(status_code=200) return Response(status_code=200)
@router.post("/sleep") @router.post("/sleep")

View File

@@ -24,7 +24,7 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import cdiv, kill_process_tree from vllm.utils import Device, cdiv, kill_process_tree
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -398,7 +398,10 @@ class AsyncLLM(EngineClient):
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
await self.engine_core.profile_async(False) await self.engine_core.profile_async(False)
async def reset_prefix_cache(self) -> None: async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
if device == Device.CPU:
raise ValueError("Not supported on CPU.")
await self.engine_core.reset_prefix_cache_async() await self.engine_core.reset_prefix_cache_async()
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:

View File

@@ -20,6 +20,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import ( from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs) BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -226,7 +227,7 @@ class LLMEngine:
def stop_profile(self): def stop_profile(self):
self.engine_core.profile(False) self.engine_core.profile(False)
def reset_prefix_cache(self): def reset_prefix_cache(self, device: Optional[Device] = None):
self.engine_core.reset_prefix_cache() self.engine_core.reset_prefix_cache()
def sleep(self, level: int = 1): def sleep(self, level: int = 1):