[V1][Core] Generic mechanism for handling engine utility (#13060)

Signed-off-by: Nick Hill <nhill@redhat.com>
Authored by Nick Hill on 2025-02-19 01:09:22 -08:00; committed by GitHub
parent f525c0be8b
commit caf7ff4456
5 changed files with 197 additions and 56 deletions
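
In short: the commit replaces the special-cased PROFILE, RESET_PREFIX_CACHE, and ADD_LORA request types with a single UTILITY request carrying (call_id, method_name, args); the engine process looks the method up on EngineCore by name and routes the result (or a failure message) back to a per-call future held by the client. Below is a minimal, self-contained sketch of that request/response pattern, with hypothetical toy classes standing in for EngineCore and SyncMPClient, and an in-process queue in place of the ZMQ sockets:

import queue
import uuid
from concurrent.futures import Future
from threading import Thread
from typing import Any, Dict, Tuple


class ToyEngineCore:
    # Stands in for EngineCore: any method can be invoked by name.
    def echo(self, msg: str) -> str:
        return msg


class ToyClient:
    # Stands in for SyncMPClient: sends (call_id, method, args) and blocks
    # on a per-call Future until the engine loop posts the result back.
    def __init__(self) -> None:
        self.core = ToyEngineCore()
        self.utility_results: Dict[int, Future] = {}
        self.requests: "queue.Queue[Tuple[int, str, tuple]]" = queue.Queue()
        Thread(target=self._engine_loop, daemon=True).start()

    def _engine_loop(self) -> None:
        # Engine side of the protocol: resolve the method by name, call it,
        # and complete the Future registered under call_id.
        while True:
            call_id, method, args = self.requests.get()
            future = self.utility_results.pop(call_id)
            try:
                future.set_result(getattr(self.core, method)(*args))
            except BaseException as e:
                future.set_exception(
                    Exception(f"Call to {method} method failed: {e}"))

    def call_utility(self, method: str, *args: Any) -> Any:
        call_id = uuid.uuid1().int >> 64
        future: Future = Future()
        self.utility_results[call_id] = future
        self.requests.put((call_id, method, args))
        return future.result()


client = ToyClient()
assert client.call_utility("echo", "hello") == "hello"

The call_id (64 bits taken from uuid1()) is what lets utility replies share the engine's one output stream with ordinary generation outputs.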


@@ -41,7 +41,7 @@ def download_and_prepare_lora_module():
     ]
     for tokenizer_file in tokenizer_files:
         del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
-        del_path.unlink()
+        del_path.unlink(missing_ok=True)


 @pytest.fixture(autouse=True)
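
Context for the one-line test fix above: Path.unlink raises FileNotFoundError if the target is already gone, so cleanup that may run twice (or against a file that was never downloaded) needs missing_ok=True, available since Python 3.8:

from pathlib import Path

p = Path("tokenizer.json")  # illustrative name, not the test's real path
p.unlink(missing_ok=True)   # no-op when the file does not exist
# p.unlink() would raise FileNotFoundError in the same situation

The next file holds the EngineCoreClient tests that exercise the new mechanism.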


@@ -3,7 +3,8 @@
 import asyncio
 import time
 import uuid
-from typing import Dict, List
+from contextlib import ExitStack
+from typing import Dict, List, Optional

 import pytest
 from transformers import AutoTokenizer
@@ -14,7 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.engine.core import EngineCore
+from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
+                                        SyncMPClient)
 from vllm.v1.executor.abstract import Executor

 if not current_platform.is_cuda():
@@ -63,7 +66,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):

 async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
     while True:
-        engine_core_outputs = await client.get_output_async().outputs
+        engine_core_outputs = (await client.get_output_async()).outputs

         if len(engine_core_outputs) == 0:
             break
@@ -78,6 +81,14 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
             break


+# Dummy utility function to monkey-patch into engine core.
+def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
+    print(f"echo util function called: {msg}, {err_msg}")
+    if err_msg is not None:
+        raise ValueError(err_msg)
+    return msg
+
+
 @fork_new_process_for_each_test
 @pytest.mark.parametrize("multiprocessing_mode", [True, False])
 def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
@@ -85,7 +96,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

-        engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo", echo, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
@@ -147,15 +161,30 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
         client.abort_requests([request.request_id])

+        if multiprocessing_mode:
+            """Utility method invocation"""
+
+            core_client: SyncMPClient = client
+
+            result = core_client._call_utility("echo", "testarg")
+            assert result == "testarg"
+
+            with pytest.raises(Exception) as e_info:
+                core_client._call_utility("echo", None, "help!")
+            assert str(e_info.value) == "Call to echo method failed: help!"


-@fork_new_process_for_each_test
-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_engine_core_client_asyncio(monkeypatch):
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

-        engine_args = EngineArgs(model=MODEL_NAME)
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo", echo, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
@@ -166,6 +195,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
             executor_class=executor_class,
             log_stats=True,
         )
+        after.callback(client.shutdown)

         MAX_TOKENS = 20
         params = SamplingParams(max_tokens=MAX_TOKENS)
@@ -204,3 +234,14 @@ async def test_engine_core_client_asyncio(monkeypatch):
         else:
             assert len(outputs[req_id]) == MAX_TOKENS, (
                 f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+
+        """Utility method invocation"""
+
+        core_client: AsyncMPClient = client
+
+        result = await core_client._call_utility_async("echo", "testarg")
+        assert result == "testarg"
+
+        with pytest.raises(Exception) as e_info:
+            await core_client._call_utility_async("echo", None, "help!")
+        assert str(e_info.value) == "Call to echo method failed: help!"
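
One fix in this file is easy to miss among the utility-test additions: in loop_until_done_async, attribute access binds tighter than await, so the old form read .outputs off the not-yet-awaited coroutine. A standalone demonstration, with a hypothetical stand-in for the client:

import asyncio


class FakeOutputs:
    outputs = ["tok"]


async def get_output_async() -> FakeOutputs:
    return FakeOutputs()


async def main() -> None:
    try:
        # Old form: parsed as await (get_output_async().outputs); the
        # coroutine object has no .outputs, so this raises immediately
        # (and leaves a "never awaited" RuntimeWarning behind).
        await get_output_async().outputs
    except AttributeError as e:
        print(f"old form fails: {e}")
    # Fixed form: parentheses force the await before attribute access.
    print((await get_output_async()).outputs)


asyncio.run(main())

Next come the wire-format changes to the engine-core structs.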


@@ -2,7 +2,7 @@

 import enum
 import time
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union

 import msgspec
@@ -106,6 +106,18 @@ class EngineCoreOutput(
         return self.finish_reason is not None


+class UtilityOutput(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]
+
+    call_id: int
+
+    # Non-None implies the call failed, result should be None.
+    failure_message: Optional[str] = None
+    result: Any = None
+
+
 class EngineCoreOutputs(
         msgspec.Struct,
         array_like=True,  # type: ignore[call-arg]
@@ -116,10 +128,12 @@ class EngineCoreOutputs(
     # e.g. columnwise layout

     # [num_reqs]
-    outputs: List[EngineCoreOutput]
-    scheduler_stats: Optional[SchedulerStats]
+    outputs: List[EngineCoreOutput] = []
+    scheduler_stats: Optional[SchedulerStats] = None
     timestamp: float = 0.0

+    utility_output: Optional[UtilityOutput] = None
+
     def __post_init__(self):
         if self.timestamp == 0.0:
             self.timestamp = time.monotonic()
@@ -132,6 +146,4 @@ class EngineCoreRequestType(enum.Enum):
     """
     ADD = b'\x00'
     ABORT = b'\x01'
-    PROFILE = b'\x02'
-    RESET_PREFIX_CACHE = b'\x03'
-    ADD_LORA = b'\x04'
+    UTILITY = b'\x02'
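
The new UtilityOutput struct travels inside EngineCoreOutputs, which is why outputs and scheduler_stats gain defaults above: a utility reply is sent as EngineCoreOutputs(utility_output=...) with no per-request outputs attached. A msgpack round-trip sketch of the struct as declared, assuming a current msgspec release:

from typing import Any, Optional

import msgspec


class UtilityOutput(msgspec.Struct, array_like=True, gc=False):
    call_id: int
    # Non-None implies the call failed, result should be None.
    failure_message: Optional[str] = None
    result: Any = None


data = msgspec.msgpack.encode(UtilityOutput(call_id=1, result="ok"))
reply = msgspec.msgpack.decode(data, type=UtilityOutput)
assert reply.result == "ok" and reply.failure_message is None

The next hunks are the engine-process side of the dispatch, in EngineCoreProc.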


@@ -5,9 +5,11 @@ import signal
 import threading
 import time
 from concurrent.futures import Future
+from inspect import isclass, signature
 from multiprocessing.connection import Connection
 from typing import Any, List, Optional, Tuple, Type

+import msgspec
 import psutil
 import zmq
 import zmq.asyncio
@@ -21,7 +23,7 @@ from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
 from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
@@ -330,19 +332,39 @@ class EngineCoreProc(EngineCore):
             self.add_request(request)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
-        elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
-            self.reset_prefix_cache()
-        elif request_type == EngineCoreRequestType.PROFILE:
-            self.model_executor.profile(request)
-        elif request_type == EngineCoreRequestType.ADD_LORA:
-            self.model_executor.add_lora(request)
+        elif request_type == EngineCoreRequestType.UTILITY:
+            call_id, method_name, args = request
+            output = UtilityOutput(call_id)
+            try:
+                method = getattr(self, method_name)
+                output.result = method(
+                    *self._convert_msgspec_args(method, args))
+            except BaseException as e:
+                logger.exception("Invocation of %s method failed", method_name)
+                output.failure_message = (f"Call to {method_name} method"
+                                          f" failed: {str(e)}")
+            self.output_queue.put_nowait(
+                EngineCoreOutputs(utility_output=output))
+
+    @staticmethod
+    def _convert_msgspec_args(method, args):
+        """If a provided arg type doesn't match corresponding target method
+        arg type, try converting to msgspec object."""
+        if not args:
+            return args
+        arg_types = signature(method).parameters.values()
+        assert len(args) <= len(arg_types)
+        return tuple(
+            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
+            and issubclass(p.annotation, msgspec.Struct)
+            and not isinstance(v, p.annotation) else v
+            for v, p in zip(args, arg_types))

     def process_input_socket(self, input_path: str):
         """Input socket IO thread."""

         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
-        add_lora_decoder = MsgpackDecoder(LoRARequest)
         generic_decoder = MsgpackDecoder()

         with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
@@ -352,14 +374,9 @@ class EngineCoreProc(EngineCore):
                 request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                 # Deserialize the request data.
-                decoder = None
-                if request_type == EngineCoreRequestType.ADD:
-                    decoder = add_request_decoder
-                elif request_type == EngineCoreRequestType.ADD_LORA:
-                    decoder = add_lora_decoder
-                else:
-                    decoder = generic_decoder
+                decoder = add_request_decoder if (
+                    request_type
+                    == EngineCoreRequestType.ADD) else generic_decoder
                 request = decoder.decode(data_frame.buffer)

                 # Push to input queue for core busy loop.
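
_convert_msgspec_args exists because utility arguments cross the input socket as generic msgpack, so a msgspec.Struct argument (a LoRARequest, say) arrives at the engine as a plain dict; the helper inspects the target method's annotations and converts such args back. A sketch of that behavior, with a hypothetical struct in place of LoRARequest:

from inspect import isclass, signature
from typing import Any

import msgspec


class LoRAReq(msgspec.Struct):
    # Hypothetical stand-in for LoRARequest.
    name: str
    id: int


def add_lora(req: LoRAReq) -> str:
    return f"added {req.name} ({req.id})"


def convert_args(method: Any, args: tuple) -> tuple:
    # Mirrors _convert_msgspec_args: rebuild Struct args from plain dicts
    # using the annotations on the target method's parameters.
    params = signature(method).parameters.values()
    return tuple(
        msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
        and issubclass(p.annotation, msgspec.Struct)
        and not isinstance(v, p.annotation) else v
        for v, p in zip(args, params))


raw_args = ({"name": "my-adapter", "id": 7}, )  # as decoded from the wire
print(add_lora(*convert_args(add_lora, raw_args)))  # added my-adapter (7)

Finally, the client side.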


@@ -2,10 +2,14 @@

 import asyncio
 import os
+import queue
 import signal
+import uuid
 import weakref
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Type
+from concurrent.futures import Future
+from threading import Thread
+from typing import Any, Dict, List, Optional, Type, Union

 import zmq
 import zmq.asyncio
@@ -16,7 +20,7 @@ from vllm.lora.request import LoRARequest
 from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                         make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -24,6 +28,8 @@ from vllm.v1.utils import BackgroundProcHandle

 logger = init_logger(__name__)

+AnyFuture = Union[asyncio.Future[Any], Future[Any]]
+

 class EngineCoreClient(ABC):
     """
@@ -204,6 +210,8 @@ class MPClient(EngineCoreClient):
                 "log_stats": log_stats,
             })

+        self.utility_results: Dict[int, AnyFuture] = {}
+
     def shutdown(self):
         """Clean up background resources."""
         if hasattr(self, "proc_handle"):
@@ -212,6 +220,16 @@ class MPClient(EngineCoreClient):
         self._finalizer()


+def _process_utility_output(output: UtilityOutput,
+                            utility_results: Dict[int, AnyFuture]):
+    """Set the result from a utility method in the waiting future"""
+    future = utility_results.pop(output.call_id)
+    if output.failure_message is not None:
+        future.set_exception(Exception(output.failure_message))
+    else:
+        future.set_result(output.result)
+
+
 class SyncMPClient(MPClient):
     """Synchronous client for multi-proc EngineCore."""
@@ -224,10 +242,30 @@ class SyncMPClient(MPClient):
             log_stats=log_stats,
         )

-    def get_output(self) -> EngineCoreOutputs:
-        (frame, ) = self.output_socket.recv_multipart(copy=False)
-        return self.decoder.decode(frame.buffer)
+        self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
+
+        # Ensure that the outputs socket processing thread does not have
+        # a ref to the client which prevents gc.
+        output_socket = self.output_socket
+        decoder = self.decoder
+        utility_results = self.utility_results
+        outputs_queue = self.outputs_queue
+
+        def process_outputs_socket():
+            while True:
+                (frame, ) = output_socket.recv_multipart(copy=False)
+                outputs = decoder.decode(frame.buffer)
+                if outputs.utility_output:
+                    _process_utility_output(outputs.utility_output,
+                                            utility_results)
+                else:
+                    outputs_queue.put_nowait(outputs)
+
+        # Process outputs from engine in separate thread.
+        Thread(target=process_outputs_socket, daemon=True).start()
+
+    def get_output(self) -> EngineCoreOutputs:
+        return self.outputs_queue.get()

     def _send_input(self, request_type: EngineCoreRequestType,
                     request: Any) -> None:
@@ -236,6 +274,16 @@ class SyncMPClient(MPClient):
         msg = (request_type.value, self.encoder.encode(request))
         self.input_socket.send_multipart(msg, copy=False)

+    def _call_utility(self, method: str, *args) -> Any:
+        call_id = uuid.uuid1().int >> 64
+        future: Future[Any] = Future()
+        self.utility_results[call_id] = future
+        self._send_input(EngineCoreRequestType.UTILITY,
+                         (call_id, method, args))
+
+        return future.result()
+
     def add_request(self, request: EngineCoreRequest) -> None:
         # NOTE: text prompt is not needed in the core engine as it has been
         # tokenized.
@@ -247,13 +295,13 @@ class SyncMPClient(MPClient):
         self._send_input(EngineCoreRequestType.ABORT, request_ids)

     def profile(self, is_start: bool = True) -> None:
-        self._send_input(EngineCoreRequestType.PROFILE, is_start)
+        self._call_utility("profile", is_start)

     def reset_prefix_cache(self) -> None:
-        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
+        self._call_utility("reset_prefix_cache")

     def add_lora(self, lora_request: LoRARequest) -> None:
-        self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
+        self._call_utility("add_lora", lora_request)


 class AsyncMPClient(MPClient):
@@ -268,24 +316,35 @@ class AsyncMPClient(MPClient):
             log_stats=log_stats,
         )

-        self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
+        self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
         self.queue_task: Optional[asyncio.Task] = None

+    async def _start_output_queue_task(self):
+        # Perform IO in separate task to parallelize as much as possible.
+        # Avoid task having direct reference back to the client.
+        self.outputs_queue = asyncio.Queue()
+        output_socket = self.output_socket
+        decoder = self.decoder
+        utility_results = self.utility_results
+        outputs_queue = self.outputs_queue
+
+        async def process_outputs_socket():
+            while True:
+                (frame, ) = await output_socket.recv_multipart(copy=False)
+                outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
+                if outputs.utility_output:
+                    _process_utility_output(outputs.utility_output,
+                                            utility_results)
+                else:
+                    outputs_queue.put_nowait(outputs)
+
+        self.queue_task = asyncio.create_task(process_outputs_socket())
+
     async def get_output_async(self) -> EngineCoreOutputs:
         if self.outputs_queue is None:
-            # Perform IO in separate task to parallelize as much as possible
-            self.outputs_queue = asyncio.Queue()
-
-            async def process_outputs_socket():
-                assert self.outputs_queue is not None
-                while True:
-                    (frame, ) = await self.output_socket.recv_multipart(
-                        copy=False)
-                    self.outputs_queue.put_nowait(frame.buffer)
-
-            self.queue_task = asyncio.create_task(process_outputs_socket())
-
-        return self.decoder.decode(await self.outputs_queue.get())
+            await self._start_output_queue_task()
+            assert self.outputs_queue is not None
+        return await self.outputs_queue.get()

     async def _send_input(self, request_type: EngineCoreRequestType,
                           request: Any) -> None:
@@ -293,6 +352,18 @@ class AsyncMPClient(MPClient):
         msg = (request_type.value, self.encoder.encode(request))
         await self.input_socket.send_multipart(msg, copy=False)

+        if self.outputs_queue is None:
+            await self._start_output_queue_task()
+
+    async def _call_utility_async(self, method: str, *args) -> Any:
+        call_id = uuid.uuid1().int >> 64
+        future = asyncio.get_running_loop().create_future()
+        self.utility_results[call_id] = future
+        await self._send_input(EngineCoreRequestType.UTILITY,
+                               (call_id, method, args))
+        return await future
+
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         # NOTE: text prompt is not needed in the core engine as it has been
         # tokenized.
@@ -304,10 +375,10 @@ class AsyncMPClient(MPClient):
         await self._send_input(EngineCoreRequestType.ABORT, request_ids)

     async def profile_async(self, is_start: bool = True) -> None:
-        await self._send_input(EngineCoreRequestType.PROFILE, is_start)
+        await self._call_utility_async("profile", is_start)

     async def reset_prefix_cache_async(self) -> None:
-        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
+        await self._call_utility_async("reset_prefix_cache")

     async def add_lora_async(self, lora_request: LoRARequest) -> None:
-        await self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
+        await self._call_utility_async("add_lora", lora_request)
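
AsyncMPClient mirrors SyncMPClient's per-call-future pattern, with loop-bound asyncio futures and a reader task in place of the daemon thread; note that _send_input now lazily starts that task, so a utility reply can never arrive before anything is listening for it. A toy asyncio rendition of the pattern (hypothetical names, an asyncio.Queue in place of the output socket):

import asyncio
import uuid
from typing import Any, Dict, Tuple


async def main() -> None:
    utility_results: Dict[int, "asyncio.Future[Any]"] = {}
    wire: "asyncio.Queue[Tuple[int, Any]]" = asyncio.Queue()

    async def reader_task() -> None:
        # Stands in for process_outputs_socket: completes the waiting future.
        while True:
            call_id, result = await wire.get()
            utility_results.pop(call_id).set_result(result)

    async def call_utility_async(method: str, *args: Any) -> Any:
        call_id = uuid.uuid1().int >> 64
        future = asyncio.get_running_loop().create_future()
        utility_results[call_id] = future
        # A real engine would run `method` remotely; here the wire echoes.
        await wire.put((call_id, (method, args)))
        return await future

    task = asyncio.create_task(reader_task())
    print(await call_utility_async("reset_prefix_cache"))
    task.cancel()


asyncio.run(main())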