mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-25 03:45:01 +08:00
[Misc] Support more collective_rpc return types (#25294)
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
123e7ad492
commit
ea01b17b6f
@ -8,7 +8,7 @@ import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -331,6 +331,46 @@ def echo_dc(
|
||||
return [val for _ in range(3)] if return_list else val
|
||||
|
||||
|
||||
# Dummy utility function to test dict serialization with custom types.
|
||||
def echo_dc_dict(
    self,
    msg: Optional[str],
    return_dict: bool = False,
) -> Union[Optional[MyDataclass], dict[str, Optional[MyDataclass]]]:
    """Dummy utility used to test dict results holding custom types.

    Args:
        msg: Text to wrap in a ``MyDataclass``. ``None`` is accepted and
            yields ``None`` in place of a dataclass instance.
        return_dict: When truthy, return a three-entry dict whose values
            are all the same wrapped value; otherwise return the value
            on its own.

    Returns:
        Either the wrapped value (``MyDataclass`` or ``None``), or a dict
        with keys ``"key1"``, ``"key2"`` and ``"key3"`` mapping to it.
    """
    print(f"echo dc dict util function called: {msg}")
    val = None if msg is None else MyDataclass(msg)
    if not return_dict:
        return val
    # Return dict of dataclasses to verify support for returning dicts
    # with custom value types.
    return {"key1": val, "key2": val, "key3": val}
|
||||
|
||||
|
||||
# Dummy utility function to test nested structures with custom types.
|
||||
def echo_dc_nested(
    self,
    msg: Optional[str],
    structure_type: str = "list_of_dicts",
) -> Any:
    """Dummy utility used to test nested results holding custom types.

    Args:
        msg: Text to wrap in a ``MyDataclass``. ``None`` is accepted and
            yields ``None`` wherever the wrapped value would appear.
        structure_type: Name of the container shape to build around the
            wrapped value: ``"list_of_dicts"``, ``"dict_of_lists"`` or
            ``"deep_nested"``. Any other name returns the bare value.

    Returns:
        The wrapped value placed in the requested nested list/dict shape.
    """
    print(f"echo dc nested util function called: {msg}, "
          f"structure: {structure_type}")
    val = None if msg is None else MyDataclass(msg)

    if structure_type == "list_of_dicts":  # noqa
        # List of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
        return [{"a": val, "b": val}, {"c": val, "d": val}]
    if structure_type == "dict_of_lists":
        # Dict of lists: {"list1": [val, val], "list2": [val, val]}
        return {"list1": [val, val], "list2": [val, val]}
    if structure_type == "deep_nested":
        # Deeply nested: {"outer": [{"inner": [val, val]},
        #                           {"inner": [val]}]}
        return {"outer": [{"inner": [val, val]}, {"inner": [val]}]}
    # Unrecognised structure names fall back to the bare value.
    return val
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_engine_core_client_util_method_custom_return(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
@ -384,6 +424,167 @@ async def test_engine_core_client_util_method_custom_return(
|
||||
client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_dict_return(
        monkeypatch: pytest.MonkeyPatch):
    """Check utility calls that return dicts with custom value types.

    Patches ``echo_dc_dict`` onto ``EngineCore``, builds a client with
    ``multiprocess_mode=True`` and ``asyncio_mode=True``, and asserts on
    what ``call_utility_async`` hands back for a single object, a dict of
    dataclasses, and a dict of ``None`` values.
    """
    expected_keys = ["key1", "key2", "key3"]

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        # Must set insecure serialization to allow returning custom types.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo_dc_dict", echo_dc_dict, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            core_client: AsyncMPClient = client

            # A bare custom object comes back as that type.
            single = await core_client.call_utility_async(
                "echo_dc_dict", "testarg3", False)
            assert isinstance(single, MyDataclass)
            assert single.message == "testarg3"

            # Dict whose values are custom objects.
            mapping = await core_client.call_utility_async(
                "echo_dc_dict", "testarg3", True)
            assert isinstance(mapping, dict)
            assert len(mapping) == 3
            for name, entry in mapping.items():
                assert name in expected_keys
                assert isinstance(entry, MyDataclass)
                assert entry.message == "testarg3"

            # Dict whose values are all None.
            empties = await core_client.call_utility_async(
                "echo_dc_dict", None, True)
            assert isinstance(empties, dict)
            assert len(empties) == 3
            for name, entry in empties.items():
                assert name in expected_keys
                assert entry is None

        finally:
            client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_nested_structures(
        monkeypatch: pytest.MonkeyPatch):
    """Check utility calls that return nested lists/dicts of custom types.

    Patches ``echo_dc_nested`` onto ``EngineCore``, builds a client with
    ``multiprocess_mode=True`` and ``asyncio_mode=True``, and asserts on
    what ``call_utility_async`` hands back for a list of dicts, a dict of
    lists, a deeper nesting, and a list of dicts holding ``None`` values.
    """

    def is_dc(obj, message) -> bool:
        # True when obj is a MyDataclass carrying the expected message.
        return isinstance(obj, MyDataclass) and obj.message == message

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        # Must set insecure serialization to allow returning custom types.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo_dc_nested", echo_dc_nested, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            core_client: AsyncMPClient = client

            # List of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
            result = await core_client.call_utility_async(
                "echo_dc_nested", "nested1", "list_of_dicts")
            assert isinstance(result, list) and len(result) == 2
            for entry, names in zip(result, (("a", "b"), ("c", "d"))):
                assert isinstance(entry, dict)
                for name in names:
                    assert name in entry
                    assert is_dc(entry[name], "nested1")

            # Dict of lists: {"list1": [val, val], "list2": [val, val]}
            result = await core_client.call_utility_async(
                "echo_dc_nested", "nested2", "dict_of_lists")
            assert isinstance(result, dict) and len(result) == 2
            assert "list1" in result and "list2" in result
            for members in result.values():
                assert isinstance(members, list) and len(members) == 2
                for member in members:
                    assert is_dc(member, "nested2")

            # Deeply nested: {"outer": [{"inner": [val, val]},
            #                           {"inner": [val]}]}
            result = await core_client.call_utility_async(
                "echo_dc_nested", "nested3", "deep_nested")
            assert isinstance(result, dict) and "outer" in result
            outer_list = result["outer"]
            assert isinstance(outer_list, list) and len(outer_list) == 2
            # The first "inner" list holds 2 items, the second holds 1.
            for wrapper, expected_len in zip(outer_list, (2, 1)):
                assert isinstance(wrapper, dict) and "inner" in wrapper
                inner = wrapper["inner"]
                assert isinstance(inner, list)
                assert len(inner) == expected_len
                for member in inner:
                    assert is_dc(member, "nested3")

            # None values inside nested structures.
            result = await core_client.call_utility_async(
                "echo_dc_nested", None, "list_of_dicts")
            assert isinstance(result, list) and len(result) == 2
            for entry in result:
                assert isinstance(entry, dict)
                assert all(value is None for value in entry.values())

        finally:
            client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"multiprocessing_mode,publisher_config",
|
||||
[(True, "tcp"), (False, "inproc")],
|
||||
|
||||
@ -7,7 +7,7 @@ import pickle
|
||||
from collections.abc import Sequence
|
||||
from inspect import isclass
|
||||
from types import FunctionType
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import cloudpickle
|
||||
import msgspec
|
||||
@ -59,6 +59,42 @@ def _typestr(val: Any) -> Optional[tuple[str, str]]:
|
||||
return t.__module__, t.__qualname__
|
||||
|
||||
|
||||
def _encode_type_info_recursive(obj: Any) -> Any:
    """Build a type-info structure mirroring nested lists/dicts in ``obj``.

    ``None`` maps to ``None``. An object whose type is exactly ``list`` or
    exactly ``dict`` (subclasses are not unpacked) maps to a list/dict of
    the same shape with each element encoded recursively. Every other
    object is described by ``_typestr``.
    """
    if obj is None:
        return None

    kind = type(obj)
    if kind is list:
        return [_encode_type_info_recursive(element) for element in obj]
    if kind is dict:
        return {
            key: _encode_type_info_recursive(value)
            for key, value in obj.items()
        }
    return _typestr(obj)
|
||||
|
||||
|
||||
def _decode_type_info_recursive(
        type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any],
                                                        Any]) -> Any:
    """Recursively decode type information for nested structures of
    lists/dicts.

    ``type_info`` is walked in parallel with ``data``:

    * ``None`` - ``data`` is returned unchanged.
    * ``dict`` - each key of ``type_info`` is decoded against the value
      stored under the same key in ``data``.
    * ``list`` that is not a two-element list starting with a string -
      decoded element by element against ``data``.
    * anything else (including a two-element list whose first item is a
      string) - treated as a leaf and passed to ``convert_fn``.

    Args:
        type_info: Type description with the same nesting as ``data``.
        data: The decoded payload to convert.
        convert_fn: Called as ``convert_fn(type_info, data)`` for leaves.

    Returns:
        A structure shaped like ``type_info`` with converted leaves.

    Raises:
        TypeError: ``type_info`` is a dict/list but ``data`` is not.
        ValueError: ``type_info`` and ``data`` are lists of unequal length.
        KeyError: a key of a ``type_info`` dict is missing from ``data``.
    """
    if type_info is None:
        return data

    if isinstance(type_info, dict):
        # Explicit raise rather than assert: this validates decoded
        # payloads and must not disappear when running with -O.
        if not isinstance(data, dict):
            raise TypeError("Type info describes a dict but the data is "
                            f"{type(data).__name__}")
        return {
            key: _decode_type_info_recursive(sub_info, data[key], convert_fn)
            for key, sub_info in type_info.items()
        }

    # Exclude serialized tensors/numpy arrays.
    # A two-element list starting with a str is handed to convert_fn
    # as-is instead of being walked element by element.
    is_leaf_pair = (isinstance(type_info, list) and len(type_info) == 2
                    and isinstance(type_info[0], str))
    if isinstance(type_info, list) and not is_leaf_pair:
        if not isinstance(data, list):
            raise TypeError("Type info describes a list but the data is "
                            f"{type(data).__name__}")
        # zip() would silently drop trailing items on a mismatch.
        if len(type_info) != len(data):
            raise ValueError(
                f"Type info has {len(type_info)} entries but the data has "
                f"{len(data)}")
        return [
            _decode_type_info_recursive(sub_info, item, convert_fn)
            for sub_info, item in zip(type_info, data)
        ]

    return convert_fn(type_info, data)
|
||||
|
||||
|
||||
class MsgpackEncoder:
|
||||
"""Encoder with custom torch tensor and numpy array serialization.
|
||||
|
||||
@ -129,12 +165,10 @@ class MsgpackEncoder:
|
||||
result = obj.result
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
return None, result
|
||||
# Since utility results are not strongly typed, we also encode
|
||||
# the type (or a list of types in the case it's a list) to
|
||||
# help with correct msgspec deserialization.
|
||||
return _typestr(result) if type(result) is not list else [
|
||||
_typestr(v) for v in result
|
||||
], result
|
||||
# Since utility results are not strongly typed, we recursively
|
||||
# encode type information for nested structures of lists/dicts
|
||||
# to help with correct msgspec deserialization.
|
||||
return _encode_type_info_recursive(result), result
|
||||
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
raise TypeError(f"Object of type {type(obj)} is not serializable"
|
||||
@ -288,15 +322,9 @@ class MsgpackDecoder:
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
|
||||
"be set to use custom utility result types")
|
||||
assert isinstance(result_type, list)
|
||||
if len(result_type) == 2 and isinstance(result_type[0], str):
|
||||
result = self._convert_result(result_type, result)
|
||||
else:
|
||||
assert isinstance(result, list)
|
||||
result = [
|
||||
self._convert_result(rt, r)
|
||||
for rt, r in zip(result_type, result)
|
||||
]
|
||||
# Use recursive decoding to handle nested structures
|
||||
result = _decode_type_info_recursive(result_type, result,
|
||||
self._convert_result)
|
||||
return UtilityResult(result)
|
||||
|
||||
def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user