mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-20 07:05:01 +08:00
[Misc] Support more collective_rpc return types (#21845)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
8f0d516715
commit
56bd537dde
@ -6,8 +6,9 @@ import os
|
|||||||
import signal
|
import signal
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -292,6 +293,68 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
|||||||
client.shutdown()
|
client.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MyDataclass:
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
# Dummy utility function to monkey-patch into engine core.
|
||||||
|
def echo_dc(
|
||||||
|
self,
|
||||||
|
msg: str,
|
||||||
|
return_list: bool = False,
|
||||||
|
) -> Union[MyDataclass, list[MyDataclass]]:
|
||||||
|
print(f"echo dc util function called: {msg}")
|
||||||
|
# Return dataclass to verify support for returning custom types
|
||||||
|
# (for which there is special handling to make it work with msgspec).
|
||||||
|
return [MyDataclass(msg) for _ in range(3)] if return_list \
|
||||||
|
else MyDataclass(msg)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
|
async def test_engine_core_client_util_method_custom_return(
|
||||||
|
monkeypatch: pytest.MonkeyPatch):
|
||||||
|
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
|
# Must set insecure serialization to allow returning custom types.
|
||||||
|
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||||
|
|
||||||
|
# Monkey-patch core engine utility function to test.
|
||||||
|
m.setattr(EngineCore, "echo_dc", echo_dc, raising=False)
|
||||||
|
|
||||||
|
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
|
||||||
|
vllm_config = engine_args.create_engine_config(
|
||||||
|
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||||
|
executor_class = Executor.get_class(vllm_config)
|
||||||
|
|
||||||
|
with set_default_torch_num_threads(1):
|
||||||
|
client = EngineCoreClient.make_client(
|
||||||
|
multiprocess_mode=True,
|
||||||
|
asyncio_mode=True,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
executor_class=executor_class,
|
||||||
|
log_stats=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test utility method returning custom / non-native data type.
|
||||||
|
core_client: AsyncMPClient = client
|
||||||
|
|
||||||
|
result = await core_client.call_utility_async(
|
||||||
|
"echo_dc", "testarg2", False)
|
||||||
|
assert isinstance(result,
|
||||||
|
MyDataclass) and result.message == "testarg2"
|
||||||
|
result = await core_client.call_utility_async(
|
||||||
|
"echo_dc", "testarg2", True)
|
||||||
|
assert isinstance(result, list) and all(
|
||||||
|
isinstance(r, MyDataclass) and r.message == "testarg2"
|
||||||
|
for r in result)
|
||||||
|
finally:
|
||||||
|
client.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"multiprocessing_mode,publisher_config",
|
"multiprocessing_mode,publisher_config",
|
||||||
[(True, "tcp"), (False, "inproc")],
|
[(True, "tcp"), (False, "inproc")],
|
||||||
|
|||||||
@ -123,6 +123,13 @@ class EngineCoreOutput(
|
|||||||
return self.finish_reason is not None
|
return self.finish_reason is not None
|
||||||
|
|
||||||
|
|
||||||
|
class UtilityResult:
|
||||||
|
"""Wrapper for special handling when serializing/deserializing."""
|
||||||
|
|
||||||
|
def __init__(self, r: Any = None):
|
||||||
|
self.result = r
|
||||||
|
|
||||||
|
|
||||||
class UtilityOutput(
|
class UtilityOutput(
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
array_like=True, # type: ignore[call-arg]
|
array_like=True, # type: ignore[call-arg]
|
||||||
@ -132,7 +139,7 @@ class UtilityOutput(
|
|||||||
|
|
||||||
# Non-None implies the call failed, result should be None.
|
# Non-None implies the call failed, result should be None.
|
||||||
failure_message: Optional[str] = None
|
failure_message: Optional[str] = None
|
||||||
result: Any = None
|
result: Optional[UtilityResult] = None
|
||||||
|
|
||||||
|
|
||||||
class EngineCoreOutputs(
|
class EngineCoreOutputs(
|
||||||
|
|||||||
@ -36,7 +36,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
|
|||||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||||
EngineCoreRequestType,
|
EngineCoreRequestType,
|
||||||
ReconfigureDistributedRequest, ReconfigureRankType,
|
ReconfigureDistributedRequest, ReconfigureRankType,
|
||||||
UtilityOutput)
|
UtilityOutput, UtilityResult)
|
||||||
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
||||||
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
|
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
@ -715,8 +715,8 @@ class EngineCoreProc(EngineCore):
|
|||||||
output = UtilityOutput(call_id)
|
output = UtilityOutput(call_id)
|
||||||
try:
|
try:
|
||||||
method = getattr(self, method_name)
|
method = getattr(self, method_name)
|
||||||
output.result = method(
|
result = method(*self._convert_msgspec_args(method, args))
|
||||||
*self._convert_msgspec_args(method, args))
|
output.result = UtilityResult(result)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.exception("Invocation of %s method failed", method_name)
|
logger.exception("Invocation of %s method failed", method_name)
|
||||||
output.failure_message = (f"Call to {method_name} method"
|
output.failure_message = (f"Call to {method_name} method"
|
||||||
|
|||||||
@ -552,7 +552,8 @@ def _process_utility_output(output: UtilityOutput,
|
|||||||
if output.failure_message is not None:
|
if output.failure_message is not None:
|
||||||
future.set_exception(Exception(output.failure_message))
|
future.set_exception(Exception(output.failure_message))
|
||||||
else:
|
else:
|
||||||
future.set_result(output.result)
|
assert output.result is not None
|
||||||
|
future.set_result(output.result.result)
|
||||||
|
|
||||||
|
|
||||||
class SyncMPClient(MPClient):
|
class SyncMPClient(MPClient):
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import importlib
|
||||||
import pickle
|
import pickle
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from inspect import isclass
|
from inspect import isclass
|
||||||
@ -9,6 +10,7 @@ from types import FunctionType
|
|||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
|
import msgspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import zmq
|
import zmq
|
||||||
@ -22,6 +24,7 @@ from vllm.multimodal.inputs import (BaseMultiModalField,
|
|||||||
MultiModalFlatField, MultiModalKwargs,
|
MultiModalFlatField, MultiModalKwargs,
|
||||||
MultiModalKwargsItem,
|
MultiModalKwargsItem,
|
||||||
MultiModalSharedField, NestedTensors)
|
MultiModalSharedField, NestedTensors)
|
||||||
|
from vllm.v1.engine import UtilityResult
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -46,6 +49,10 @@ def _log_insecure_serialization_warning():
|
|||||||
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")
|
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")
|
||||||
|
|
||||||
|
|
||||||
|
def _typestr(t: type):
|
||||||
|
return t.__module__, t.__qualname__
|
||||||
|
|
||||||
|
|
||||||
class MsgpackEncoder:
|
class MsgpackEncoder:
|
||||||
"""Encoder with custom torch tensor and numpy array serialization.
|
"""Encoder with custom torch tensor and numpy array serialization.
|
||||||
|
|
||||||
@ -122,6 +129,18 @@ class MsgpackEncoder:
|
|||||||
for itemlist in mm._items_by_modality.values()
|
for itemlist in mm._items_by_modality.values()
|
||||||
for item in itemlist]
|
for item in itemlist]
|
||||||
|
|
||||||
|
if isinstance(obj, UtilityResult):
|
||||||
|
result = obj.result
|
||||||
|
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None:
|
||||||
|
return None, result
|
||||||
|
# Since utility results are not strongly typed, we also encode
|
||||||
|
# the type (or a list of types in the case it's a list) to
|
||||||
|
# help with correct msgspec deserialization.
|
||||||
|
cls = result.__class__
|
||||||
|
return _typestr(cls) if cls is not list else [
|
||||||
|
_typestr(type(v)) for v in result
|
||||||
|
], result
|
||||||
|
|
||||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||||
raise TypeError(f"Object of type {type(obj)} is not serializable"
|
raise TypeError(f"Object of type {type(obj)} is not serializable"
|
||||||
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
|
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
|
||||||
@ -237,8 +256,33 @@ class MsgpackDecoder:
|
|||||||
k: self._decode_nested_tensors(v)
|
k: self._decode_nested_tensors(v)
|
||||||
for k, v in obj.items()
|
for k, v in obj.items()
|
||||||
})
|
})
|
||||||
|
if t is UtilityResult:
|
||||||
|
return self._decode_utility_result(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
def _decode_utility_result(self, obj: Any) -> UtilityResult:
|
||||||
|
result_type, result = obj
|
||||||
|
if result_type is not None:
|
||||||
|
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||||
|
raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
|
||||||
|
"be set to use custom utility result types")
|
||||||
|
assert isinstance(result_type, list)
|
||||||
|
if len(result_type) == 2 and isinstance(result_type[0], str):
|
||||||
|
result = self._convert_result(result_type, result)
|
||||||
|
else:
|
||||||
|
assert isinstance(result, list)
|
||||||
|
result = [
|
||||||
|
self._convert_result(rt, r)
|
||||||
|
for rt, r in zip(result_type, result)
|
||||||
|
]
|
||||||
|
return UtilityResult(result)
|
||||||
|
|
||||||
|
def _convert_result(self, result_type: Sequence[str], result: Any):
|
||||||
|
mod_name, name = result_type
|
||||||
|
mod = importlib.import_module(mod_name)
|
||||||
|
result_type = getattr(mod, name)
|
||||||
|
return msgspec.convert(result, result_type, dec_hook=self.dec_hook)
|
||||||
|
|
||||||
def _decode_ndarray(self, arr: Any) -> np.ndarray:
|
def _decode_ndarray(self, arr: Any) -> np.ndarray:
|
||||||
dtype, shape, data = arr
|
dtype, shape, data = arr
|
||||||
# zero-copy decode. We assume the ndarray will not be kept around,
|
# zero-copy decode. We assume the ndarray will not be kept around,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user