[Misc] Support more collective_rpc return types (#21845)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-07-30 18:20:20 +01:00 committed by GitHub
parent 8f0d516715
commit 56bd537dde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 121 additions and 6 deletions

View File

@ -6,8 +6,9 @@ import os
import signal
import time
import uuid
from dataclasses import dataclass
from threading import Thread
from typing import Optional
from typing import Optional, Union
from unittest.mock import MagicMock
import pytest
@ -292,6 +293,68 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
client.shutdown()
@dataclass
class MyDataclass:
    """Minimal custom (non-native) type carrying a single string payload."""

    # Text echoed back by the dummy utility function.
    message: str
# Dummy utility function to monkey-patch into engine core.
def echo_dc(
    self,
    msg: str,
    return_list: bool = False,
) -> Union[MyDataclass, list[MyDataclass]]:
    """Echo ``msg`` back wrapped in ``MyDataclass``.

    Args:
        self: Receiver object; unused here, present because the function
            is installed as a method.
        msg: Text to wrap.
        return_list: When true, return three wrapped copies in a list
            instead of a single instance.

    Returns:
        One ``MyDataclass`` or a list of three, depending on
        ``return_list``.
    """
    print(f"echo dc util function called: {msg}")
    # A dataclass is returned on purpose: it is a custom type, which needs
    # special handling to make it work with msgspec.
    if return_list:
        return [MyDataclass(msg) for _ in range(3)]
    return MyDataclass(msg)
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_return(
        monkeypatch: pytest.MonkeyPatch):
    """Check that utility calls can return custom (non-native) types.

    Patches ``echo_dc`` onto ``EngineCore``, starts an async multiprocess
    client, and verifies both the single-object and the list return shapes
    come back as ``MyDataclass`` instances.
    """
    with monkeypatch.context() as patcher:
        patcher.setenv("VLLM_USE_V1", "1")
        # Must set insecure serialization to allow returning custom types.
        patcher.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        # Monkey-patch core engine utility function to test.
        patcher.setattr(EngineCore, "echo_dc", echo_dc, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            core_client: AsyncMPClient = client

            # Single custom object.
            single = await core_client.call_utility_async(
                "echo_dc", "testarg2", False)
            assert isinstance(single, MyDataclass)
            assert single.message == "testarg2"

            # List of custom objects.
            many = await core_client.call_utility_async(
                "echo_dc", "testarg2", True)
            assert isinstance(many, list)
            for item in many:
                assert isinstance(item, MyDataclass)
                assert item.message == "testarg2"
        finally:
            # Always tear the client down, even if an assertion failed.
            client.shutdown()
@pytest.mark.parametrize(
"multiprocessing_mode,publisher_config",
[(True, "tcp"), (False, "inproc")],

View File

@ -123,6 +123,13 @@ class EngineCoreOutput(
return self.finish_reason is not None
class UtilityResult:
    """Box around a utility call's return value.

    Wrapping the value in a dedicated type lets the serialization layer
    recognise it and apply special handling when encoding and decoding.
    """

    def __init__(self, r: Any = None):
        # The wrapped value; ``None`` when nothing was supplied.
        self.result = r
class UtilityOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
@ -132,7 +139,7 @@ class UtilityOutput(
# Non-None implies the call failed, result should be None.
failure_message: Optional[str] = None
result: Any = None
result: Optional[UtilityResult] = None
class EngineCoreOutputs(

View File

@ -36,7 +36,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput)
UtilityOutput, UtilityResult)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
@ -715,8 +715,8 @@ class EngineCoreProc(EngineCore):
output = UtilityOutput(call_id)
try:
method = getattr(self, method_name)
output.result = method(
*self._convert_msgspec_args(method, args))
result = method(*self._convert_msgspec_args(method, args))
output.result = UtilityResult(result)
except BaseException as e:
logger.exception("Invocation of %s method failed", method_name)
output.failure_message = (f"Call to {method_name} method"

View File

@ -552,7 +552,8 @@ def _process_utility_output(output: UtilityOutput,
if output.failure_message is not None:
future.set_exception(Exception(output.failure_message))
else:
future.set_result(output.result)
assert output.result is not None
future.set_result(output.result.result)
class SyncMPClient(MPClient):

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import importlib
import pickle
from collections.abc import Sequence
from inspect import isclass
@ -9,6 +10,7 @@ from types import FunctionType
from typing import Any, Optional, Union
import cloudpickle
import msgspec
import numpy as np
import torch
import zmq
@ -22,6 +24,7 @@ from vllm.multimodal.inputs import (BaseMultiModalField,
MultiModalFlatField, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField, NestedTensors)
from vllm.v1.engine import UtilityResult
logger = init_logger(__name__)
@ -46,6 +49,10 @@ def _log_insecure_serialization_warning():
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")
def _typestr(t: type):
return t.__module__, t.__qualname__
class MsgpackEncoder:
"""Encoder with custom torch tensor and numpy array serialization.
@ -122,6 +129,18 @@ class MsgpackEncoder:
for itemlist in mm._items_by_modality.values()
for item in itemlist]
if isinstance(obj, UtilityResult):
result = obj.result
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None:
return None, result
# Since utility results are not strongly typed, we also encode
# the type (or a list of types in the case it's a list) to
# help with correct msgspec deserialization.
cls = result.__class__
return _typestr(cls) if cls is not list else [
_typestr(type(v)) for v in result
], result
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
raise TypeError(f"Object of type {type(obj)} is not serializable"
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
@ -237,8 +256,33 @@ class MsgpackDecoder:
k: self._decode_nested_tensors(v)
for k, v in obj.items()
})
if t is UtilityResult:
return self._decode_utility_result(obj)
return obj
def _decode_utility_result(self, obj: Any) -> UtilityResult:
    """Rebuild a ``UtilityResult`` from its decoded ``(type info, value)`` pair.

    Args:
        obj: Two-item sequence. The first item is ``None`` (no conversion
            needed), a single ``[module, qualname]`` pair, or a list of
            such pairs (one per element of a list result). The second
            item is the decoded value.

    Returns:
        A ``UtilityResult`` wrapping the (possibly converted) value.

    Raises:
        TypeError: If type info is present but
            ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is not set, or if the
            type info / value do not have the expected list shape.
        ValueError: If a per-element type list and the value list differ
            in length.
    """
    result_type, result = obj
    if result_type is None:
        # No type info was sent: hand the decoded value back untouched.
        return UtilityResult(result)

    if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
        raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
                        "be set to use custom utility result types")

    # Explicit checks instead of ``assert``: this validates data that came
    # off the wire, and asserts are removed when Python runs with -O.
    if not isinstance(result_type, list):
        raise TypeError("Expected utility result type info to be a list, "
                        f"got {type(result_type)}")

    if len(result_type) == 2 and isinstance(result_type[0], str):
        # A single [module, qualname] pair describes a non-list result.
        return UtilityResult(self._convert_result(result_type, result))

    # Otherwise the type info is a list of pairs, one per result element.
    if not isinstance(result, list):
        raise TypeError("Expected utility result to be a list, "
                        f"got {type(result)}")
    if len(result_type) != len(result):
        # zip() would silently drop the unmatched tail.
        raise ValueError(
            f"Utility result type list has {len(result_type)} entries "
            f"but the result has {len(result)} elements")
    return UtilityResult([
        self._convert_result(rt, r) for rt, r in zip(result_type, result)
    ])
def _convert_result(self, result_type: Sequence[str], result: Any):
    """Convert a decoded value into the type named by ``result_type``.

    Args:
        result_type: ``(module name, qualified name)`` pair naming the
            target type.
        result: The decoded value to convert.

    Returns:
        ``result`` converted to the resolved type via ``msgspec.convert``,
        using this decoder's ``dec_hook``.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the qualified name cannot be resolved inside
            the module.
    """
    mod_name, name = result_type
    target: Any = importlib.import_module(mod_name)
    # The name is a ``__qualname__``, which is dotted for nested classes
    # ("Outer.Inner"); a single getattr() on the module cannot resolve
    # those, so walk the path one attribute at a time.
    for part in name.split("."):
        target = getattr(target, part)
    return msgspec.convert(result, target, dec_hook=self.dec_hook)
def _decode_ndarray(self, arr: Any) -> np.ndarray:
dtype, shape, data = arr
# zero-copy decode. We assume the ndarray will not be kept around,