[RL] [V1] Remove unused device argument from reset_prefix_cache (#28766)

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Authored by Zhuohan Li on 2025-11-14 23:59:42 -08:00 (committed by GitHub)
parent 98b4d389ed
commit dd6ac1c2bb
5 changed files with 9 additions and 17 deletions


@@ -125,7 +125,7 @@ class EngineClient(ABC):
     ...

     @abstractmethod
-    async def reset_prefix_cache(self, device: Device | None = None) -> None:
+    async def reset_prefix_cache(self) -> None:
         """Reset the prefix cache"""

     ...
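With the device parameter gone from the abstract method, every concrete client now implements the zero-argument form. A minimal sketch of the updated contract; the DummyClient class is illustrative and not part of vLLM:

    from abc import ABC, abstractmethod

    class EngineClient(ABC):
        # Trimmed to the one method this commit touches.
        @abstractmethod
        async def reset_prefix_cache(self) -> None:
            """Reset the prefix cache"""

    class DummyClient(EngineClient):
        async def reset_prefix_cache(self) -> None:
            # Concrete clients only need the argument-free signature now.
            pass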


@@ -32,7 +32,6 @@ from vllm.config.model import (
     TokenizerMode,
 )
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.protocol import Device
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ChatTemplateContentFormatOption,
@@ -1499,8 +1498,8 @@ class LLM:
     def stop_profile(self) -> None:
         self.llm_engine.stop_profile()

-    def reset_prefix_cache(self, device: Device | None = None) -> None:
-        self.llm_engine.reset_prefix_cache(device)
+    def reset_prefix_cache(self) -> None:
+        self.llm_engine.reset_prefix_cache()

     def sleep(self, level: int = 1):
         """


@@ -39,7 +39,7 @@ from typing_extensions import assert_never
 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.protocol import Device, EngineClient
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.anthropic.protocol import (
     AnthropicError,
     AnthropicErrorResponse,
@@ -1069,12 +1069,8 @@ if envs.VLLM_SERVER_DEV_MODE:
         Reset the prefix cache. Note that we currently do not check if the
         prefix cache is successfully reset in the API server.
         """
-        device = None
-        device_str = raw_request.query_params.get("device")
-        if device_str is not None:
-            device = Device[device_str.upper()]
-        logger.info("Resetting prefix cache with specific %s...", str(device))
-        await engine_client(raw_request).reset_prefix_cache(device)
+        logger.info("Resetting prefix cache...")
+        await engine_client(raw_request).reset_prefix_cache()
         return Response(status_code=200)

     @router.post("/reset_mm_cache")
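Since the endpoint no longer parses a device query parameter, a client simply POSTs with no arguments. A hedged sketch using the requests library; the host and port are assumptions, and the route only exists when the server runs with VLLM_SERVER_DEV_MODE set:

    import requests

    # Host/port are assumptions; the route exists only in dev mode.
    resp = requests.post("http://localhost:8000/reset_prefix_cache")
    assert resp.status_code == 200
    # A ?device=... query string, previously read here, is now ignored.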


@@ -14,7 +14,7 @@ import torch
 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.protocol import Device, EngineClient
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
@@ -672,9 +672,7 @@ class AsyncLLM(EngineClient):
         self.processor.clear_mm_cache()
         await self.engine_core.reset_mm_cache_async()

-    async def reset_prefix_cache(self, device: Device | None = None) -> None:
-        if device == Device.CPU:
-            raise ValueError("Not supported on CPU.")
+    async def reset_prefix_cache(self) -> None:
         await self.engine_core.reset_prefix_cache_async()

     async def sleep(self, level: int = 1) -> None:
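The async client loses both the parameter and the CPU-specific ValueError guard. A minimal async usage sketch; the model name is an example, and from_engine_args is assumed available on AsyncLLM:

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.v1.engine.async_llm import AsyncLLM

    async def main() -> None:
        engine = AsyncLLM.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m")  # example model
        )
        # No device argument, and no "Not supported on CPU" path anymore.
        await engine.reset_prefix_cache()

    asyncio.run(main())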


@@ -14,7 +14,6 @@ from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import stateless_destroy_torch_distributed_process_group
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.protocol import Device
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -321,7 +320,7 @@ class LLMEngine:
         self.processor.clear_mm_cache()
         self.engine_core.reset_mm_cache()

-    def reset_prefix_cache(self, device: Device | None = None):
+    def reset_prefix_cache(self):
         self.engine_core.reset_prefix_cache()

     def sleep(self, level: int = 1):
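Because the parameter is removed outright rather than deprecated, any caller still passing a device will now fail loudly. A self-contained sketch of the caller-side effect, using a stand-in class rather than a real engine:

    class _Engine:
        # Mirrors the new LLMEngine signature, for illustration only.
        def reset_prefix_cache(self):
            pass

    engine = _Engine()
    engine.reset_prefix_cache()  # fine after this commit
    try:
        engine.reset_prefix_cache("cpu")  # old-style call with a device
    except TypeError as exc:
        print(exc)  # ...takes 1 positional argument but 2 were given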