diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py
index df8031cba687..2b421bfd9eb8 100644
--- a/tests/lora/test_add_lora.py
+++ b/tests/lora/test_add_lora.py
@@ -41,7 +41,7 @@ def download_and_prepare_lora_module():
     ]
     for tokenizer_file in tokenizer_files:
         del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
-        del_path.unlink()
+        del_path.unlink(missing_ok=True)
 
 
 @pytest.fixture(autouse=True)
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 45080be8e8ce..828d7eed309f 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -3,7 +3,8 @@
 import asyncio
 import time
 import uuid
-from typing import Dict, List
+from contextlib import ExitStack
+from typing import Dict, List, Optional
 
 import pytest
 from transformers import AutoTokenizer
@@ -14,7 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.engine.core import EngineCore
+from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
+                                        SyncMPClient)
 from vllm.v1.executor.abstract import Executor
 
 if not current_platform.is_cuda():
@@ -63,7 +66,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
 
 async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
     while True:
-        engine_core_outputs = await client.get_output_async().outputs
+        engine_core_outputs = (await client.get_output_async()).outputs
 
         if len(engine_core_outputs) == 0:
             break
@@ -78,6 +81,14 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
             break
 
 
+# Dummy utility function to monkey-patch into engine core.
+def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
+    print(f"echo util function called: {msg}, {err_msg}")
+    if err_msg is not None:
+        raise ValueError(err_msg)
+    return msg
+
+
 @fork_new_process_for_each_test
 @pytest.mark.parametrize("multiprocessing_mode", [True, False])
 def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
@@ -85,7 +96,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo", echo, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
@@ -147,15 +161,30 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
 
         client.abort_requests([request.request_id])
 
+        if multiprocessing_mode:
+            """Utility method invocation"""
 
-@fork_new_process_for_each_test
-@pytest.mark.asyncio
+            core_client: SyncMPClient = client
+
+            result = core_client._call_utility("echo", "testarg")
+            assert result == "testarg"
+
+            with pytest.raises(Exception) as e_info:
+                core_client._call_utility("echo", None, "help!")
+
+            assert str(e_info.value) == "Call to echo method failed: help!"
+
+
+@pytest.mark.asyncio(loop_scope="function")
 async def test_engine_core_client_asyncio(monkeypatch):
 
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine_args = EngineArgs(model=MODEL_NAME)
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo", echo, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
@@ -166,6 +195,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
             executor_class=executor_class,
             log_stats=True,
         )
+        after.callback(client.shutdown)
 
         MAX_TOKENS = 20
         params = SamplingParams(max_tokens=MAX_TOKENS)
@@ -204,3 +234,14 @@ async def test_engine_core_client_asyncio(monkeypatch):
         else:
             assert len(outputs[req_id]) == MAX_TOKENS, (
                 f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+        """Utility method invocation"""
+
+        core_client: AsyncMPClient = client
+
+        result = await core_client._call_utility_async("echo", "testarg")
+        assert result == "testarg"
+
+        with pytest.raises(Exception) as e_info:
+            await core_client._call_utility_async("echo", None, "help!")
+
+        assert str(e_info.value) == "Call to echo method failed: help!"
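The tests above graft a brand-new `echo` method onto `EngineCore` via pytest's monkeypatch; `raising=False` is what permits setting an attribute that does not already exist on the class. A minimal self-contained sketch of that pattern (the `Core` class here is a hypothetical stand-in, not vLLM's `EngineCore`):

```python
import pytest


class Core:
    """Hypothetical stand-in for EngineCore."""


def echo(self, msg, err_msg=None):
    if err_msg is not None:
        raise ValueError(err_msg)
    return msg


def test_patched_method(monkeypatch):
    # raising=False lets setattr add an attribute that isn't there yet,
    # instead of failing because Core.echo doesn't exist.
    monkeypatch.setattr(Core, "echo", echo, raising=False)
    assert Core().echo("hi") == "hi"
    with pytest.raises(ValueError, match="boom"):
        Core().echo("hi", "boom")
```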
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index dee7102bb47b..7420dde1f7e4 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -2,7 +2,7 @@
 
 import enum
 import time
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import msgspec
 
@@ -106,6 +106,18 @@ class EngineCoreOutput(
         return self.finish_reason is not None
 
 
+class UtilityOutput(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]
+
+    call_id: int
+
+    # Non-None implies the call failed, result should be None.
+    failure_message: Optional[str] = None
+    result: Any = None
+
+
 class EngineCoreOutputs(
         msgspec.Struct,
         array_like=True,  # type: ignore[call-arg]
@@ -116,10 +128,12 @@ class EngineCoreOutputs(
     # e.g. columnwise layout
 
     # [num_reqs]
-    outputs: List[EngineCoreOutput]
-    scheduler_stats: Optional[SchedulerStats]
+    outputs: List[EngineCoreOutput] = []
+    scheduler_stats: Optional[SchedulerStats] = None
     timestamp: float = 0.0
 
+    utility_output: Optional[UtilityOutput] = None
+
    def __post_init__(self):
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()
@@ -132,6 +146,4 @@ class EngineCoreRequestType(enum.Enum):
     """
     ADD = b'\x00'
     ABORT = b'\x01'
-    PROFILE = b'\x02'
-    RESET_PREFIX_CACHE = b'\x03'
-    ADD_LORA = b'\x04'
+    UTILITY = b'\x02'
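For context on the new struct: msgspec structs declared with `array_like=True` serialize positionally as msgpack arrays, which keeps the wire format compact. A standalone round-trip sketch (re-declaring a simplified `UtilityOutput` rather than importing vLLM's):

```python
from typing import Any, Optional

import msgspec


class UtilityOutput(msgspec.Struct, array_like=True, gc=False):
    call_id: int
    # Non-None implies the call failed; result should then be None.
    failure_message: Optional[str] = None
    result: Any = None


# array_like=True encodes the fields positionally, as a msgpack array.
buf = msgspec.msgpack.encode(UtilityOutput(call_id=42, result="done"))
out = msgspec.msgpack.decode(buf, type=UtilityOutput)
assert (out.call_id, out.failure_message, out.result) == (42, None, "done")
```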
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 6718a5f7b02d..66e252b7ccb0 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -5,9 +5,11 @@ import signal
 import threading
 import time
 from concurrent.futures import Future
+from inspect import isclass, signature
 from multiprocessing.connection import Connection
 from typing import Any, List, Optional, Tuple, Type
 
+import msgspec
 import psutil
 import zmq
 import zmq.asyncio
@@ -21,7 +23,7 @@ from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
 from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
@@ -330,19 +332,39 @@ class EngineCoreProc(EngineCore):
             self.add_request(request)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
-        elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
-            self.reset_prefix_cache()
-        elif request_type == EngineCoreRequestType.PROFILE:
-            self.model_executor.profile(request)
-        elif request_type == EngineCoreRequestType.ADD_LORA:
-            self.model_executor.add_lora(request)
+        elif request_type == EngineCoreRequestType.UTILITY:
+            call_id, method_name, args = request
+            output = UtilityOutput(call_id)
+            try:
+                method = getattr(self, method_name)
+                output.result = method(
+                    *self._convert_msgspec_args(method, args))
+            except BaseException as e:
+                logger.exception("Invocation of %s method failed", method_name)
+                output.failure_message = (f"Call to {method_name} method"
+                                          f" failed: {str(e)}")
+            self.output_queue.put_nowait(
+                EngineCoreOutputs(utility_output=output))
+
+    @staticmethod
+    def _convert_msgspec_args(method, args):
+        """If a provided arg type doesn't match corresponding target method
+        arg type, try converting to msgspec object."""
+        if not args:
+            return args
+        arg_types = signature(method).parameters.values()
+        assert len(args) <= len(arg_types)
+        return tuple(
+            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
+            and issubclass(p.annotation, msgspec.Struct)
+            and not isinstance(v, p.annotation) else v
+            for v, p in zip(args, arg_types))
 
     def process_input_socket(self, input_path: str):
         """Input socket IO thread."""
 
         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
-        add_lora_decoder = MsgpackDecoder(LoRARequest)
         generic_decoder = MsgpackDecoder()
 
         with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
@@ -352,14 +374,9 @@ class EngineCoreProc(EngineCore):
                 request_type = EngineCoreRequestType(bytes(type_frame.buffer))
 
                 # Deserialize the request data.
-                decoder = None
-                if request_type == EngineCoreRequestType.ADD:
-                    decoder = add_request_decoder
-                elif request_type == EngineCoreRequestType.ADD_LORA:
-                    decoder = add_lora_decoder
-                else:
-                    decoder = generic_decoder
-
+                decoder = add_request_decoder if (
+                    request_type
+                    == EngineCoreRequestType.ADD) else generic_decoder
                 request = decoder.decode(data_frame.buffer)
 
                 # Push to input queue for core busy loop.
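The `_convert_msgspec_args` helper exists because a request decoded with the generic `MsgpackDecoder` loses struct typing: a struct-valued argument such as a `LoRARequest` arrives as a plain dict/list, and `msgspec.convert` rebuilds it from the target method's annotations. A standalone sketch of the same idea (`LoadConfig` and `load` are made-up stand-ins, not vLLM APIs):

```python
from inspect import isclass, signature

import msgspec


class LoadConfig(msgspec.Struct):
    path: str
    strict: bool = True


def load(config: LoadConfig, retries: int = 0) -> str:
    return f"{config.path} strict={config.strict} retries={retries}"


def convert_args(method, args):
    # Coerce each positional arg to the annotated msgspec.Struct type,
    # but only when it isn't already an instance of that type.
    params = list(signature(method).parameters.values())
    return tuple(
        msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
        and issubclass(p.annotation, msgspec.Struct)
        and not isinstance(v, p.annotation) else v
        for v, p in zip(args, params))


# A generically-decoded request arrives with the struct flattened to a dict.
raw_args = ({"path": "/tmp/weights"}, 3)
print(load(*convert_args(load, raw_args)))  # /tmp/weights strict=True retries=3
```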
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 07176629e949..8641833e438b 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -2,10 +2,14 @@
 
 import asyncio
 import os
+import queue
 import signal
+import uuid
 import weakref
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Type
+from concurrent.futures import Future
+from threading import Thread
+from typing import Any, Dict, List, Optional, Type, Union
 
 import zmq
 import zmq.asyncio
@@ -16,7 +20,7 @@ from vllm.lora.request import LoRARequest
 from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                         make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -24,6 +28,8 @@ from vllm.v1.utils import BackgroundProcHandle
 
 logger = init_logger(__name__)
 
+AnyFuture = Union[asyncio.Future[Any], Future[Any]]
+
 
 class EngineCoreClient(ABC):
     """
@@ -204,6 +210,8 @@ class MPClient(EngineCoreClient):
                 "log_stats": log_stats,
             })
 
+        self.utility_results: Dict[int, AnyFuture] = {}
+
     def shutdown(self):
         """Clean up background resources."""
         if hasattr(self, "proc_handle"):
@@ -212,6 +220,16 @@ class MPClient(EngineCoreClient):
 
         self._finalizer()
 
 
+def _process_utility_output(output: UtilityOutput,
+                            utility_results: Dict[int, AnyFuture]):
+    """Set the result from a utility method in the waiting future"""
+    future = utility_results.pop(output.call_id)
+    if output.failure_message is not None:
+        future.set_exception(Exception(output.failure_message))
+    else:
+        future.set_result(output.result)
+
+
 class SyncMPClient(MPClient):
     """Synchronous client for multi-proc EngineCore."""
@@ -224,10 +242,30 @@ class SyncMPClient(MPClient):
             log_stats=log_stats,
         )
 
-    def get_output(self) -> EngineCoreOutputs:
+        self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
 
-        (frame, ) = self.output_socket.recv_multipart(copy=False)
-        return self.decoder.decode(frame.buffer)
+        # Ensure that the outputs socket processing thread does not have
+        # a ref to the client which prevents gc.
+        output_socket = self.output_socket
+        decoder = self.decoder
+        utility_results = self.utility_results
+        outputs_queue = self.outputs_queue
+
+        def process_outputs_socket():
+            while True:
+                (frame, ) = output_socket.recv_multipart(copy=False)
+                outputs = decoder.decode(frame.buffer)
+                if outputs.utility_output:
+                    _process_utility_output(outputs.utility_output,
+                                            utility_results)
+                else:
+                    outputs_queue.put_nowait(outputs)
+
+        # Process outputs from engine in separate thread.
+        Thread(target=process_outputs_socket, daemon=True).start()
+
+    def get_output(self) -> EngineCoreOutputs:
+        return self.outputs_queue.get()
 
     def _send_input(self, request_type: EngineCoreRequestType,
                     request: Any) -> None:
@@ -236,6 +274,16 @@ class SyncMPClient(MPClient):
         msg = (request_type.value, self.encoder.encode(request))
         self.input_socket.send_multipart(msg, copy=False)
 
+    def _call_utility(self, method: str, *args) -> Any:
+        call_id = uuid.uuid1().int >> 64
+        future: Future[Any] = Future()
+        self.utility_results[call_id] = future
+
+        self._send_input(EngineCoreRequestType.UTILITY,
+                         (call_id, method, args))
+
+        return future.result()
+
     def add_request(self, request: EngineCoreRequest) -> None:
         # NOTE: text prompt is not needed in the core engine as it has been
         # tokenized.
@@ -247,13 +295,13 @@ class SyncMPClient(MPClient):
         self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
     def profile(self, is_start: bool = True) -> None:
-        self._send_input(EngineCoreRequestType.PROFILE, is_start)
+        self._call_utility("profile", is_start)
 
     def reset_prefix_cache(self) -> None:
-        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
+        self._call_utility("reset_prefix_cache")
 
     def add_lora(self, lora_request: LoRARequest) -> None:
-        self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
+        self._call_utility("add_lora", lora_request)
 
 
 class AsyncMPClient(MPClient):
@@ -268,24 +316,35 @@ class AsyncMPClient(MPClient):
             log_stats=log_stats,
         )
 
-        self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
+        self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
         self.queue_task: Optional[asyncio.Task] = None
 
+    async def _start_output_queue_task(self):
+        # Perform IO in separate task to parallelize as much as possible.
+        # Avoid task having direct reference back to the client.
+        self.outputs_queue = asyncio.Queue()
+        output_socket = self.output_socket
+        decoder = self.decoder
+        utility_results = self.utility_results
+        outputs_queue = self.outputs_queue
+
+        async def process_outputs_socket():
+            while True:
+                (frame, ) = await output_socket.recv_multipart(copy=False)
+                outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
+                if outputs.utility_output:
+                    _process_utility_output(outputs.utility_output,
+                                            utility_results)
+                else:
+                    outputs_queue.put_nowait(outputs)
+
+        self.queue_task = asyncio.create_task(process_outputs_socket())
+
     async def get_output_async(self) -> EngineCoreOutputs:
         if self.outputs_queue is None:
-            # Perform IO in separate task to parallelize as much as possible
-            self.outputs_queue = asyncio.Queue()
-
-            async def process_outputs_socket():
-                assert self.outputs_queue is not None
-                while True:
-                    (frame, ) = await self.output_socket.recv_multipart(
-                        copy=False)
-                    self.outputs_queue.put_nowait(frame.buffer)
-
-            self.queue_task = asyncio.create_task(process_outputs_socket())
-
-        return self.decoder.decode(await self.outputs_queue.get())
+            await self._start_output_queue_task()
+        assert self.outputs_queue is not None
+        return await self.outputs_queue.get()
 
     async def _send_input(self, request_type: EngineCoreRequestType,
                           request: Any) -> None:
@@ -293,6 +352,18 @@ class AsyncMPClient(MPClient):
         msg = (request_type.value, self.encoder.encode(request))
         await self.input_socket.send_multipart(msg, copy=False)
 
+        if self.outputs_queue is None:
+            await self._start_output_queue_task()
+
+    async def _call_utility_async(self, method: str, *args) -> Any:
+        call_id = uuid.uuid1().int >> 64
+        future = asyncio.get_running_loop().create_future()
+        self.utility_results[call_id] = future
+        await self._send_input(EngineCoreRequestType.UTILITY,
+                               (call_id, method, args))
+
+        return await future
+
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         # NOTE: text prompt is not needed in the core engine as it has been
         # tokenized.
@@ -304,10 +375,10 @@ class AsyncMPClient(MPClient):
         await self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
     async def profile_async(self, is_start: bool = True) -> None:
-        await self._send_input(EngineCoreRequestType.PROFILE, is_start)
+        await self._call_utility_async("profile", is_start)
 
     async def reset_prefix_cache_async(self) -> None:
-        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
+        await self._call_utility_async("reset_prefix_cache")
 
     async def add_lora_async(self, lora_request: LoRARequest) -> None:
-        await self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
+        await self._call_utility_async("add_lora", lora_request)
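Putting the client side together: both `SyncMPClient._call_utility` and `AsyncMPClient._call_utility_async` park a future under a unique 64-bit `call_id`, ship `(call_id, method, args)` as a `UTILITY` request, and let the output-processing thread/task complete the future when the matching `UtilityOutput` comes back. A toy single-process sketch of that bookkeeping, purely for illustration (a plain `Queue` stands in for the ZMQ transport, and server and client share one process here, unlike the real multi-proc setup):

```python
import uuid
from concurrent.futures import Future
from queue import Queue
from threading import Thread
from typing import Any, Dict

utility_results: Dict[int, Future] = {}
requests: Queue = Queue()


def echo(msg, err_msg=None):
    if err_msg is not None:
        raise ValueError(err_msg)
    return msg


def serve():
    # Stand-in for the engine-core busy loop plus the client's output
    # thread: execute the named method and complete the parked future.
    methods = {"echo": echo}
    while True:
        call_id, method, args = requests.get()
        future = utility_results.pop(call_id)
        try:
            future.set_result(methods[method](*args))
        except Exception as e:
            future.set_exception(
                Exception(f"Call to {method} method failed: {e}"))


Thread(target=serve, daemon=True).start()


def call_utility(method: str, *args) -> Any:
    call_id = uuid.uuid1().int >> 64  # cheap 64-bit unique id
    future: Future = Future()
    utility_results[call_id] = future
    requests.put((call_id, method, args))
    return future.result()  # block until the response completes the future


assert call_utility("echo", "testarg") == "testarg"
try:
    call_utility("echo", None, "help!")
except Exception as e:
    assert str(e) == "Call to echo method failed: help!"
```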