[FEAT] Support reset prefix cache by specified device (#15003)

parent 61c7a1b856
commit 26dd972adb
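In short: every `reset_prefix_cache` entry point, from the block allocator up through the engines and the API server, now accepts an optional `Device`, so callers can flush the prefix cache of a single device instead of all of them. A minimal usage sketch against the offline API (the model name is illustrative, not part of this commit):

from vllm import LLM
from vllm.utils import Device

llm = LLM(model="facebook/opt-125m")  # any model; name is illustrative

# Reset only the GPU allocator's prefix cache.
llm.reset_prefix_cache(device=Device.GPU)

# Omitting the device keeps the old behavior: reset every device.
llm.reset_prefix_cache()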
@@ -341,8 +341,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         assert device in self._allocators
         return self._allocators[device].get_prefix_cache_hit_rate()
 
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache for all devices."""
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
+        if device:
+            return self._allocators[device].reset_prefix_cache()
         success = True
         for allocator in self._allocators.values():
             success = success and allocator.reset_prefix_cache()
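The heart of the change is this allocator method: given a device it resets only that device's allocator, otherwise it folds the result over all allocators as before. A self-contained sketch of the same dispatch-or-fold pattern (the `_FakeAllocator` class is a hypothetical stand-in, not vllm code):

from enum import Enum, auto
from typing import Dict, Optional

class Device(Enum):  # mirrors vllm.utils.Device
    GPU = auto()
    CPU = auto()

class _FakeAllocator:  # hypothetical per-device allocator
    def reset_prefix_cache(self) -> bool:
        return True

def reset_prefix_cache(allocators: Dict[Device, _FakeAllocator],
                       device: Optional[Device] = None) -> bool:
    if device:  # targeted reset: touch only the requested device
        return allocators[device].reset_prefix_cache()
    success = True  # otherwise reset every device and AND the results
    for allocator in allocators.values():
        success = success and allocator.reset_prefix_cache()
    return success

allocators = {Device.GPU: _FakeAllocator(), Device.CPU: _FakeAllocator()}
assert reset_prefix_cache(allocators, Device.GPU)  # GPU only
assert reset_prefix_cache(allocators)              # both devices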
@@ -305,7 +305,7 @@ class DeviceAwareBlockAllocator(ABC):
         pass
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache."""
         pass
 
@@ -456,8 +456,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_allocator.get_prefix_cache_hit_rate(device)
 
-    def reset_prefix_cache(self) -> bool:
-        return self.block_allocator.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.block_allocator.reset_prefix_cache(device)
 
     def _can_swap(self,
                   seq_group: SequenceGroup,
@@ -2,7 +2,7 @@
 
 import enum
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple
 
@@ -125,8 +125,8 @@ class BlockSpaceManager(ABC):
         pass
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache for all devices."""
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
        pass
 
     @abstractmethod
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
 from vllm.sequence import Sequence, SequenceGroup
@@ -92,7 +92,7 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return -1
 
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         return True
 
     def get_num_cached_tokens(self, seq: Sequence) -> int:
@@ -634,8 +634,8 @@ class Scheduler:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)
 
-    def reset_prefix_cache(self) -> bool:
-        return self.block_manager.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.block_manager.reset_prefix_cache(device)
 
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -35,7 +35,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import Device, deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1216,8 +1216,9 @@ class AsyncLLMEngine(EngineClient):
     async def stop_profile(self) -> None:
         self.engine.stop_profile()
 
-    async def reset_prefix_cache(self) -> None:
-        self.engine.reset_prefix_cache()
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
+        self.engine.reset_prefix_cache(device)
 
     async def sleep(self, level: int = 1) -> None:
         self.engine.sleep(level)
@@ -955,12 +955,12 @@ class LLMEngine:
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""
 
         success = True
         for scheduler in self.scheduler:
-            success = success and scheduler.reset_prefix_cache()
+            success = success and scheduler.reset_prefix_cache(device)
         return success
 
     @staticmethod
@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.utils import deprecate_kwargs
+from vllm.utils import Device, deprecate_kwargs
 
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
 
@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
     STOP_PROFILE = 2
 
 
-class RPCResetPrefixCacheRequest(Enum):
-    RESET_PREFIX_CACHE = 1
+@dataclass
+class RPCResetPrefixCacheRequest:
+    device: Device
 
 
 class RPCSleepRequest(Enum):
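The RPC message type changes from a payload-less `Enum` to a `@dataclass` because the request must now carry the target device across the process boundary to the engine. A sketch of the difference (with a stand-in `Device` enum so the snippet runs on its own):

from dataclasses import dataclass
from enum import Enum, auto

class Device(Enum):  # stand-in for vllm.utils.Device
    GPU = auto()
    CPU = auto()

@dataclass
class RPCResetPrefixCacheRequest:
    device: Device

# The old Enum could only signal "reset everything"; the dataclass
# lets the client name a specific device (or pass None for all).
request = RPCResetPrefixCacheRequest(Device.GPU)
assert request.device is Device.GPU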
@@ -47,7 +47,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import deprecate_kwargs
+from vllm.utils import Device, deprecate_kwargs
 
 logger = init_logger(__name__)
 
@@ -684,11 +684,12 @@ class MQLLMEngineClient(EngineClient):
         await self._send_one_way_rpc_request(
             request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
 
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
 
         await self._send_one_way_rpc_request(
-            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
+            request=RPCResetPrefixCacheRequest(device),
             socket=self.input_socket)
 
     async def sleep(self, level: int = 1) -> None:
@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import collect_from_async_generator, random_uuid
+from vllm.utils import Device, collect_from_async_generator, random_uuid
 
 logger = init_logger(__name__)
 
@@ -274,7 +274,8 @@ class EngineClient(ABC):
         ...
 
     @abstractmethod
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
         ...
 
@@ -42,7 +42,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                 get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
+from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
+                        is_list_of)
 
 logger = init_logger(__name__)
 
@@ -1187,8 +1188,8 @@ class LLM:
     def stop_profile(self) -> None:
         self.llm_engine.stop_profile()
 
-    def reset_prefix_cache(self) -> bool:
-        return self.llm_engine.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.llm_engine.reset_prefix_cache(device)
 
     def sleep(self, level: int = 1):
         """
@@ -85,7 +85,7 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
+from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
                         is_valid_ipv6_address, set_ulimit)
 from vllm.version import __version__ as VLLM_VERSION
 
@@ -677,8 +677,12 @@ if envs.VLLM_SERVER_DEV_MODE:
         Reset the prefix cache. Note that we currently do not check if the
         prefix cache is successfully reset in the API server.
         """
-        logger.info("Resetting prefix cache...")
-        await engine_client(raw_request).reset_prefix_cache()
+        device = None
+        device_str = raw_request.query_params.get("device")
+        if device_str is not None:
+            device = Device[device_str.upper()]
+        logger.info("Resetting prefix cache with specific %s...", str(device))
+        await engine_client(raw_request).reset_prefix_cache(device)
         return Response(status_code=200)
 
     @router.post("/sleep")
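Server-side, the device is read from a query parameter and mapped onto the `Device` enum by name, so the new behavior is reachable over HTTP on the dev-mode endpoint. A hedged client sketch; the route path and port below are assumptions, since neither is shown in this hunk:

import urllib.request

# "gpu" is upper-cased server-side via Device[device_str.upper()],
# so both "gpu" and "GPU" select Device.GPU.
req = urllib.request.Request(
    "http://localhost:8000/reset_prefix_cache?device=gpu", method="POST")
with urllib.request.urlopen(req) as resp:
    assert resp.status == 200  # the handler always returns 200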
@@ -24,7 +24,7 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import cdiv, kill_process_tree
+from vllm.utils import Device, cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -398,7 +398,10 @@ class AsyncLLM(EngineClient):
     async def stop_profile(self) -> None:
         await self.engine_core.profile_async(False)
 
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
+        if device == Device.CPU:
+            raise ValueError("Not supported on CPU.")
         await self.engine_core.reset_prefix_cache_async()
 
     async def sleep(self, level: int = 1) -> None:
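In the v1 engine there is no CPU-side prefix cache to flush, so `AsyncLLM` rejects `Device.CPU` explicitly and otherwise resets the engine-core cache regardless of the argument. A hedged sketch of the resulting behavior (`engine` is assumed to be a v1 `AsyncLLM` instance created elsewhere):

from vllm.utils import Device

async def reset_caches(engine) -> None:
    await engine.reset_prefix_cache(Device.GPU)  # resets the engine-core cache
    await engine.reset_prefix_cache()            # same: device defaults to None
    try:
        await engine.reset_prefix_cache(Device.CPU)
    except ValueError:
        pass  # v1 raises "Not supported on CPU."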
@@ -20,6 +20,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import Device
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -226,7 +227,7 @@ class LLMEngine:
     def stop_profile(self):
         self.engine_core.profile(False)
 
-    def reset_prefix_cache(self):
+    def reset_prefix_cache(self, device: Optional[Device] = None):
         self.engine_core.reset_prefix_cache()
 
     def sleep(self, level: int = 1):