mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:55:51 +08:00
[Misc] Support more collective_rpc return types (#21845)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
8f0d516715
commit
56bd537dde
@ -6,8 +6,9 @@ import os
|
||||
import signal
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -292,6 +293,68 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
client.shutdown()
|
||||
|
||||
|
||||
@dataclass
class MyDataclass:
    """Custom return type used to exercise utility calls in the tests."""

    # Text carried by the instance; the tests compare it to what was sent.
    message: str
|
||||
|
||||
|
||||
# Dummy utility function to monkey-patch into engine core.
|
||||
def echo_dc(
    self,
    msg: str,
    return_list: bool = False,
) -> Union[MyDataclass, list[MyDataclass]]:
    """Echo ``msg`` back wrapped in ``MyDataclass``.

    The return value is deliberately a dataclass so the test can verify
    that custom return types are supported (there is special handling to
    make that work with msgspec).

    Returns:
        A single ``MyDataclass`` by default, or a list of three of them
        when ``return_list`` is true.
    """
    print(f"echo dc util function called: {msg}")
    if not return_list:
        return MyDataclass(msg)
    return [MyDataclass(msg) for _ in range(3)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_return(
        monkeypatch: pytest.MonkeyPatch):
    """Utility calls can return a custom type, alone or inside a list.

    Patches ``echo_dc`` onto ``EngineCore``, creates a multiprocess asyncio
    client, and checks that ``call_utility_async`` hands back real
    ``MyDataclass`` instances for both the single-value and the list case.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Must set insecure serialization to allow returning custom types.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo_dc", echo_dc, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            # Test utility method returning custom / non-native data type.
            core_client: AsyncMPClient = client

            # Single-value case. Separate asserts so a failure shows
            # whether the type or the payload was wrong.
            result = await core_client.call_utility_async(
                "echo_dc", "testarg2", False)
            assert isinstance(result, MyDataclass)
            assert result.message == "testarg2"

            # List case.
            result = await core_client.call_utility_async(
                "echo_dc", "testarg2", True)
            assert isinstance(result, list)
            # echo_dc returns three items; without this check the all()
            # below would also pass on an empty or truncated list.
            assert len(result) == 3
            assert all(
                isinstance(r, MyDataclass) and r.message == "testarg2"
                for r in result)
        finally:
            client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"multiprocessing_mode,publisher_config",
|
||||
[(True, "tcp"), (False, "inproc")],
|
||||
|
||||
@ -123,6 +123,13 @@ class EngineCoreOutput(
|
||||
return self.finish_reason is not None
|
||||
|
||||
|
||||
class UtilityResult:
    """Wrapper for special handling when serializing/deserializing.

    Carries the value produced by a utility call in ``result``.
    """

    # NOTE(review): presumably kept as a plain class (not a dataclass or
    # msgspec.Struct) so that the custom encode/decode hooks get to handle
    # it -- confirm before changing the class kind.
    def __init__(self, r: Any = None):
        # Stored as-is; readers access it through ``.result``.
        self.result = r
|
||||
|
||||
|
||||
class UtilityOutput(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
@ -132,7 +139,7 @@ class UtilityOutput(
|
||||
|
||||
# Non-None implies the call failed, result should be None.
|
||||
failure_message: Optional[str] = None
|
||||
result: Any = None
|
||||
result: Optional[UtilityResult] = None
|
||||
|
||||
|
||||
class EngineCoreOutputs(
|
||||
|
||||
@ -36,7 +36,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType,
|
||||
ReconfigureDistributedRequest, ReconfigureRankType,
|
||||
UtilityOutput)
|
||||
UtilityOutput, UtilityResult)
|
||||
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
||||
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
@ -715,8 +715,8 @@ class EngineCoreProc(EngineCore):
|
||||
output = UtilityOutput(call_id)
|
||||
try:
|
||||
method = getattr(self, method_name)
|
||||
output.result = method(
|
||||
*self._convert_msgspec_args(method, args))
|
||||
result = method(*self._convert_msgspec_args(method, args))
|
||||
output.result = UtilityResult(result)
|
||||
except BaseException as e:
|
||||
logger.exception("Invocation of %s method failed", method_name)
|
||||
output.failure_message = (f"Call to {method_name} method"
|
||||
|
||||
@ -552,7 +552,8 @@ def _process_utility_output(output: UtilityOutput,
|
||||
if output.failure_message is not None:
|
||||
future.set_exception(Exception(output.failure_message))
|
||||
else:
|
||||
future.set_result(output.result)
|
||||
assert output.result is not None
|
||||
future.set_result(output.result.result)
|
||||
|
||||
|
||||
class SyncMPClient(MPClient):
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import importlib
|
||||
import pickle
|
||||
from collections.abc import Sequence
|
||||
from inspect import isclass
|
||||
@ -9,6 +10,7 @@ from types import FunctionType
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import cloudpickle
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
@ -22,6 +24,7 @@ from vllm.multimodal.inputs import (BaseMultiModalField,
|
||||
MultiModalFlatField, MultiModalKwargs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalSharedField, NestedTensors)
|
||||
from vllm.v1.engine import UtilityResult
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -46,6 +49,10 @@ def _log_insecure_serialization_warning():
|
||||
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")
|
||||
|
||||
|
||||
def _typestr(t: type):
|
||||
return t.__module__, t.__qualname__
|
||||
|
||||
|
||||
class MsgpackEncoder:
|
||||
"""Encoder with custom torch tensor and numpy array serialization.
|
||||
|
||||
@ -122,6 +129,18 @@ class MsgpackEncoder:
|
||||
for itemlist in mm._items_by_modality.values()
|
||||
for item in itemlist]
|
||||
|
||||
if isinstance(obj, UtilityResult):
|
||||
result = obj.result
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None:
|
||||
return None, result
|
||||
# Since utility results are not strongly typed, we also encode
|
||||
# the type (or a list of types in the case it's a list) to
|
||||
# help with correct msgspec deserialization.
|
||||
cls = result.__class__
|
||||
return _typestr(cls) if cls is not list else [
|
||||
_typestr(type(v)) for v in result
|
||||
], result
|
||||
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
raise TypeError(f"Object of type {type(obj)} is not serializable"
|
||||
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
|
||||
@ -237,8 +256,33 @@ class MsgpackDecoder:
|
||||
k: self._decode_nested_tensors(v)
|
||||
for k, v in obj.items()
|
||||
})
|
||||
if t is UtilityResult:
|
||||
return self._decode_utility_result(obj)
|
||||
return obj
|
||||
|
||||
def _decode_utility_result(self, obj: Any) -> UtilityResult:
    """Rebuild a ``UtilityResult`` from its decoded ``(type info, value)`` pair.

    The type info is ``None`` when no type was encoded. Otherwise it is
    either a single ``[module, qualified name]`` pair describing the value,
    or a list of such pairs, one per element, when the value is a list.

    Args:
        obj: Two-item sequence ``(result_type, result)``.

    Returns:
        A ``UtilityResult`` wrapping the value, converted to the named
        type(s) when type info is present.

    Raises:
        TypeError: type info is present but
            VLLM_ALLOW_INSECURE_SERIALIZATION is not set, or the type
            info / value does not have the expected list shape.
        ValueError: the per-element type list and the value list differ
            in length.
    """
    result_type, result = obj
    if result_type is None:
        # Nothing to convert: hand the decoded value back unchanged.
        return UtilityResult(result)

    if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
        raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
                        "be set to use custom utility result types")

    # Validate with explicit raises rather than ``assert``: asserts are
    # removed under ``python -O`` and this data comes off the wire.
    if not isinstance(result_type, list):
        raise TypeError("Utility result type info must be a list, got "
                        f"{type(result_type).__name__}")

    if len(result_type) == 2 and isinstance(result_type[0], str):
        # Single [module, qualname] pair describing the whole value.
        return UtilityResult(self._convert_result(result_type, result))

    # Otherwise: one type pair per element of a list value.
    if not isinstance(result, list):
        raise TypeError("Utility result must be a list when per-element "
                        f"types are given, got {type(result).__name__}")
    if len(result_type) != len(result):
        # zip() would silently drop the unmatched tail.
        raise ValueError(
            f"Utility result has {len(result)} elements but "
            f"{len(result_type)} element types were provided")
    converted = [
        self._convert_result(rt, r) for rt, r in zip(result_type, result)
    ]
    return UtilityResult(converted)
|
||||
|
||||
def _convert_result(self, result_type: Sequence[str], result: Any):
    """Convert a decoded value to the type named by ``result_type``.

    Args:
        result_type: ``(module name, qualified name)`` pair of the target
            type.
        result: The value as decoded from the message.

    Returns:
        ``result`` converted to the target type with ``msgspec.convert``,
        using this decoder's ``dec_hook`` for nested custom types.

    Raises:
        ImportError: the named module cannot be imported.
        AttributeError: the qualified name cannot be resolved in it.
    """
    mod_name, name = result_type
    # This imports a module named by the incoming message; the caller
    # only gets here when VLLM_ALLOW_INSECURE_SERIALIZATION is set.
    target: Any = importlib.import_module(mod_name)
    # A qualified name is dotted for nested classes ("Outer.Inner"), so
    # walk it one attribute at a time; a single getattr() on the dotted
    # string would raise AttributeError.
    for part in name.split("."):
        target = getattr(target, part)
    return msgspec.convert(result, target, dec_hook=self.dec_hook)
|
||||
|
||||
def _decode_ndarray(self, arr: Any) -> np.ndarray:
|
||||
dtype, shape, data = arr
|
||||
# zero-copy decode. We assume the ndarray will not be kept around,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user