diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 2ac6dc796bd10..f648c38a63f79 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -6,8 +6,9 @@ import os
 import signal
 import time
 import uuid
+from dataclasses import dataclass
 from threading import Thread
-from typing import Optional
+from typing import Optional, Union
 from unittest.mock import MagicMock
 
 import pytest
@@ -292,6 +293,68 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
             client.shutdown()
 
 
+@dataclass
+class MyDataclass:
+    message: str
+
+
+# Dummy utility function to monkey-patch into engine core.
+def echo_dc(
+    self,
+    msg: str,
+    return_list: bool = False,
+) -> Union[MyDataclass, list[MyDataclass]]:
+    print(f"echo dc util function called: {msg}")
+    # Return dataclass to verify support for returning custom types
+    # (for which there is special handling to make it work with msgspec).
+    return [MyDataclass(msg) for _ in range(3)] if return_list \
+        else MyDataclass(msg)
+
+
+@pytest.mark.asyncio(loop_scope="function")
+async def test_engine_core_client_util_method_custom_return(
+        monkeypatch: pytest.MonkeyPatch):
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        # Must set insecure serialization to allow returning custom types.
+        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo_dc", echo_dc, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
+        vllm_config = engine_args.create_engine_config(
+            usage_context=UsageContext.UNKNOWN_CONTEXT)
+        executor_class = Executor.get_class(vllm_config)
+
+        with set_default_torch_num_threads(1):
+            client = EngineCoreClient.make_client(
+                multiprocess_mode=True,
+                asyncio_mode=True,
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=True,
+            )
+
+        try:
+            # Test utility method returning custom / non-native data type.
+            core_client: AsyncMPClient = client
+
+            result = await core_client.call_utility_async(
+                "echo_dc", "testarg2", False)
+            assert isinstance(result,
+                              MyDataclass) and result.message == "testarg2"
+            result = await core_client.call_utility_async(
+                "echo_dc", "testarg2", True)
+            assert isinstance(result, list) and all(
+                isinstance(r, MyDataclass) and r.message == "testarg2"
+                for r in result)
+        finally:
+            client.shutdown()
+
+
 @pytest.mark.parametrize(
     "multiprocessing_mode,publisher_config",
     [(True, "tcp"), (False, "inproc")],
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 79dc80d8fc547..810d03f32d726 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -123,6 +123,13 @@ class EngineCoreOutput(
         return self.finish_reason is not None
 
 
+class UtilityResult:
+    """Wrapper for special handling when serializing/deserializing."""
+
+    def __init__(self, r: Any = None):
+        self.result = r
+
+
 class UtilityOutput(
         msgspec.Struct,
         array_like=True,  # type: ignore[call-arg]
@@ -132,7 +139,7 @@ class UtilityOutput(
 
     # Non-None implies the call failed, result should be None.
     failure_message: Optional[str] = None
-    result: Any = None
+    result: Optional[UtilityResult] = None
 
 
 class EngineCoreOutputs(
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 39fda521f36af..9f2fca6961388 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -36,7 +36,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType,
                             ReconfigureDistributedRequest, ReconfigureRankType,
-                            UtilityOutput)
+                            UtilityOutput, UtilityResult)
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
 from vllm.v1.executor.abstract import Executor
@@ -715,8 +715,8 @@ class EngineCoreProc(EngineCore):
             output = UtilityOutput(call_id)
             try:
                 method = getattr(self, method_name)
-                output.result = method(
-                    *self._convert_msgspec_args(method, args))
+                result = method(*self._convert_msgspec_args(method, args))
+                output.result = UtilityResult(result)
             except BaseException as e:
                 logger.exception("Invocation of %s method failed", method_name)
                 output.failure_message = (f"Call to {method_name} method"
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index acff5bf6823d9..fdf5a5de191c0 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -552,7 +552,8 @@ def _process_utility_output(output: UtilityOutput,
     if output.failure_message is not None:
         future.set_exception(Exception(output.failure_message))
     else:
-        future.set_result(output.result)
+        assert output.result is not None
+        future.set_result(output.result.result)
 
 
 class SyncMPClient(MPClient):
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 03200c2c2f8ec..4b6a983252b0e 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import dataclasses
+import importlib
 import pickle
 from collections.abc import Sequence
 from inspect import isclass
@@ -9,6 +10,7 @@ from types import FunctionType
 from typing import Any, Optional, Union
 
 import cloudpickle
+import msgspec
 import numpy as np
 import torch
 import zmq
@@ -22,6 +24,7 @@ from vllm.multimodal.inputs import (BaseMultiModalField,
                                     MultiModalFlatField, MultiModalKwargs,
                                     MultiModalKwargsItem,
                                     MultiModalSharedField, NestedTensors)
+from vllm.v1.engine import UtilityResult
 
 logger = init_logger(__name__)
 
@@ -46,6 +49,11 @@ def _log_insecure_serialization_warning():
                         "VLLM_ALLOW_INSECURE_SERIALIZATION=1")
 
 
+def _typestr(t: type) -> tuple[str, str]:
+    """Return the (module, qualified name) pair identifying type `t`."""
+    return t.__module__, t.__qualname__
+
+
 class MsgpackEncoder:
     """Encoder with custom torch tensor and numpy array serialization.
 
@@ -122,6 +130,18 @@ class MsgpackEncoder:
                     for itemlist in mm._items_by_modality.values()
                     for item in itemlist]
 
+        if isinstance(obj, UtilityResult):
+            result = obj.result
+            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None:
+                return None, result
+            # Since utility results are not strongly typed, we also encode
+            # the type (or a list of types in the case it's a list) to
+            # help with correct msgspec deserialization.
+            cls = result.__class__
+            return _typestr(cls) if cls is not list else [
+                _typestr(type(v)) for v in result
+            ], result
+
         if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
             raise TypeError(f"Object of type {type(obj)} is not serializable"
                             "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
@@ -237,8 +257,43 @@ class MsgpackDecoder:
                     k: self._decode_nested_tensors(v)
                     for k, v in obj.items()
                 })
+            if t is UtilityResult:
+                return self._decode_utility_result(obj)
         return obj
 
+    def _decode_utility_result(self, obj: Any) -> UtilityResult:
+        """Rebuild a UtilityResult from its encoded (type info, value) pair.
+
+        The type info is None when the encoder recorded no type, a
+        [module, qualname] pair for a single value, or a list of such
+        pairs (one per element) when the value is a list.
+        """
+        result_type, result = obj
+        if result_type is not None:
+            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
+                raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
+                                "be set to use custom utility result types")
+            assert isinstance(result_type, list)
+            if len(result_type) == 2 and isinstance(result_type[0], str):
+                result = self._convert_result(result_type, result)
+            else:
+                assert isinstance(result, list)
+                result = [
+                    self._convert_result(rt, r)
+                    for rt, r in zip(result_type, result)
+                ]
+        return UtilityResult(result)
+
+    def _convert_result(self, result_type: Sequence[str], result: Any):
+        """Convert a decoded value to the type named by [module, qualname]."""
+        mod_name, name = result_type
+        # __qualname__ is dotted for nested classes (e.g. "Outer.Inner"), so
+        # resolve it one attribute at a time rather than with one getattr.
+        target: Any = importlib.import_module(mod_name)
+        for attr in name.split("."):
+            target = getattr(target, attr)
+        return msgspec.convert(result, target, dec_hook=self.dec_hook)
+
     def _decode_ndarray(self, arr: Any) -> np.ndarray:
         dtype, shape, data = arr
         # zero-copy decode. We assume the ndarray will not be kept around,