mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 02:44:26 +08:00
[Misc] Support more collective_rpc return types (#25294)
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
123e7ad492
commit
ea01b17b6f
@ -8,7 +8,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Optional, Union
|
from typing import Any, Optional, Union
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -331,6 +331,46 @@ def echo_dc(
|
|||||||
return [val for _ in range(3)] if return_list else val
|
return [val for _ in range(3)] if return_list else val
|
||||||
|
|
||||||
|
|
||||||
|
# Dummy utility function to test dict serialization with custom types.
|
||||||
|
def echo_dc_dict(
|
||||||
|
self,
|
||||||
|
msg: str,
|
||||||
|
return_dict: bool = False,
|
||||||
|
) -> Union[MyDataclass, dict[str, MyDataclass]]:
|
||||||
|
print(f"echo dc dict util function called: {msg}")
|
||||||
|
val = None if msg is None else MyDataclass(msg)
|
||||||
|
# Return dict of dataclasses to verify support for returning dicts
|
||||||
|
# with custom value types.
|
||||||
|
if return_dict:
|
||||||
|
return {"key1": val, "key2": val, "key3": val}
|
||||||
|
else:
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
# Dummy utility function to test nested structures with custom types.
|
||||||
|
def echo_dc_nested(
|
||||||
|
self,
|
||||||
|
msg: str,
|
||||||
|
structure_type: str = "list_of_dicts",
|
||||||
|
) -> Any:
|
||||||
|
print(f"echo dc nested util function called: {msg}, "
|
||||||
|
f"structure: {structure_type}")
|
||||||
|
val = None if msg is None else MyDataclass(msg)
|
||||||
|
|
||||||
|
if structure_type == "list_of_dicts": # noqa
|
||||||
|
# Return list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
|
||||||
|
return [{"a": val, "b": val}, {"c": val, "d": val}]
|
||||||
|
elif structure_type == "dict_of_lists":
|
||||||
|
# Return dict of lists: {"list1": [val, val], "list2": [val, val]}
|
||||||
|
return {"list1": [val, val], "list2": [val, val]}
|
||||||
|
elif structure_type == "deep_nested":
|
||||||
|
# Return deeply nested: {"outer": [{"inner": [val, val]},
|
||||||
|
# {"inner": [val]}]}
|
||||||
|
return {"outer": [{"inner": [val, val]}, {"inner": [val]}]}
|
||||||
|
else:
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="function")
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_engine_core_client_util_method_custom_return(
|
async def test_engine_core_client_util_method_custom_return(
|
||||||
monkeypatch: pytest.MonkeyPatch):
|
monkeypatch: pytest.MonkeyPatch):
|
||||||
@ -384,6 +424,167 @@ async def test_engine_core_client_util_method_custom_return(
|
|||||||
client.shutdown()
|
client.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
|
async def test_engine_core_client_util_method_custom_dict_return(
|
||||||
|
monkeypatch: pytest.MonkeyPatch):
|
||||||
|
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
|
# Must set insecure serialization to allow returning custom types.
|
||||||
|
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||||
|
|
||||||
|
# Monkey-patch core engine utility function to test.
|
||||||
|
m.setattr(EngineCore, "echo_dc_dict", echo_dc_dict, raising=False)
|
||||||
|
|
||||||
|
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
|
||||||
|
vllm_config = engine_args.create_engine_config(
|
||||||
|
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||||
|
executor_class = Executor.get_class(vllm_config)
|
||||||
|
|
||||||
|
with set_default_torch_num_threads(1):
|
||||||
|
client = EngineCoreClient.make_client(
|
||||||
|
multiprocess_mode=True,
|
||||||
|
asyncio_mode=True,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
executor_class=executor_class,
|
||||||
|
log_stats=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test utility method returning custom / non-native data type.
|
||||||
|
core_client: AsyncMPClient = client
|
||||||
|
|
||||||
|
# Test single object return
|
||||||
|
result = await core_client.call_utility_async(
|
||||||
|
"echo_dc_dict", "testarg3", False)
|
||||||
|
assert isinstance(result,
|
||||||
|
MyDataclass) and result.message == "testarg3"
|
||||||
|
|
||||||
|
# Test dict return with custom value types
|
||||||
|
result = await core_client.call_utility_async(
|
||||||
|
"echo_dc_dict", "testarg3", True)
|
||||||
|
assert isinstance(result, dict) and len(result) == 3
|
||||||
|
for key, val in result.items():
|
||||||
|
assert key in ["key1", "key2", "key3"]
|
||||||
|
assert isinstance(val,
|
||||||
|
MyDataclass) and val.message == "testarg3"
|
||||||
|
|
||||||
|
# Test returning dict with None values
|
||||||
|
result = await core_client.call_utility_async(
|
||||||
|
"echo_dc_dict", None, True)
|
||||||
|
assert isinstance(result, dict) and len(result) == 3
|
||||||
|
for key, val in result.items():
|
||||||
|
assert key in ["key1", "key2", "key3"]
|
||||||
|
assert val is None
|
||||||
|
|
||||||
|
finally:
|
||||||
|
client.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
|
async def test_engine_core_client_util_method_nested_structures(
|
||||||
|
monkeypatch: pytest.MonkeyPatch):
|
||||||
|
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
|
# Must set insecure serialization to allow returning custom types.
|
||||||
|
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||||
|
|
||||||
|
# Monkey-patch core engine utility function to test.
|
||||||
|
m.setattr(EngineCore, "echo_dc_nested", echo_dc_nested, raising=False)
|
||||||
|
|
||||||
|
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
|
||||||
|
vllm_config = engine_args.create_engine_config(
|
||||||
|
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||||
|
executor_class = Executor.get_class(vllm_config)
|
||||||
|
|
||||||
|
with set_default_torch_num_threads(1):
|
||||||
|
client = EngineCoreClient.make_client(
|
||||||
|
multiprocess_mode=True,
|
||||||
|
asyncio_mode=True,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
executor_class=executor_class,
|
||||||
|
log_stats=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
core_client: AsyncMPClient = client
|
||||||
|
|
||||||
|
# Test list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
|
||||||
|
result = await core_client.call_utility_async(
|
||||||
|
"echo_dc_nested", "nested1", "list_of_dicts")
|
||||||
|
assert isinstance(result, list) and len(result) == 2
|
||||||
|
for i, item in enumerate(result):
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
if i == 0:
|
||||||
|
assert "a" in item and "b" in item
|
||||||
|
assert isinstance(
|
||||||
|
item["a"],
|
||||||
|
MyDataclass) and item["a"].message == "nested1"
|
||||||
|
assert isinstance(
|
||||||
|
item["b"],
|
||||||
|
MyDataclass) and item["b"].message == "nested1"
|
||||||
|
else:
|
||||||
|
assert "c" in item and "d" in item
|
||||||
|
assert isinstance(
|
||||||
|
item["c"],
|
||||||
|
MyDataclass) and item["c"].message == "nested1"
|
||||||
|
assert isinstance(
|
||||||
|
item["d"],
|
||||||
|
MyDataclass) and item["d"].message == "nested1"
|
||||||
|
|
||||||
|
# Test dict of lists: {"list1": [val, val], "list2": [val, val]}
|
||||||
|
result = await core_client.call_utility_async(
|
||||||
|
"echo_dc_nested", "nested2", "dict_of_lists")
|
||||||
|
assert isinstance(result, dict) and len(result) == 2
|
||||||
|
assert "list1" in result and "list2" in result
|
||||||
|
for key, lst in result.items():
|
||||||
|
assert isinstance(lst, list) and len(lst) == 2
|
||||||
|
for item in lst:
|
||||||
|
assert isinstance(
|
||||||
|
item, MyDataclass) and item.message == "nested2"
|
||||||
|
|
||||||
|
# Test deeply nested: {"outer": [{"inner": [val, val]},
|
||||||
|
# {"inner": [val]}]}
|
||||||
|
result = await core_client.call_utility_async(
|
||||||
|
"echo_dc_nested", "nested3", "deep_nested")
|
||||||
|
assert isinstance(result, dict) and "outer" in result
|
||||||
|
outer_list = result["outer"]
|
||||||
|
assert isinstance(outer_list, list) and len(outer_list) == 2
|
||||||
|
|
||||||
|
# First dict in outer list should have "inner" with 2 items
|
||||||
|
inner_dict1 = outer_list[0]
|
||||||
|
assert isinstance(inner_dict1, dict) and "inner" in inner_dict1
|
||||||
|
inner_list1 = inner_dict1["inner"]
|
||||||
|
assert isinstance(inner_list1, list) and len(inner_list1) == 2
|
||||||
|
for item in inner_list1:
|
||||||
|
assert isinstance(item,
|
||||||
|
MyDataclass) and item.message == "nested3"
|
||||||
|
|
||||||
|
# Second dict in outer list should have "inner" with 1 item
|
||||||
|
inner_dict2 = outer_list[1]
|
||||||
|
assert isinstance(inner_dict2, dict) and "inner" in inner_dict2
|
||||||
|
inner_list2 = inner_dict2["inner"]
|
||||||
|
assert isinstance(inner_list2, list) and len(inner_list2) == 1
|
||||||
|
assert isinstance(
|
||||||
|
inner_list2[0],
|
||||||
|
MyDataclass) and inner_list2[0].message == "nested3"
|
||||||
|
|
||||||
|
# Test with None values in nested structures
|
||||||
|
result = await core_client.call_utility_async(
|
||||||
|
"echo_dc_nested", None, "list_of_dicts")
|
||||||
|
assert isinstance(result, list) and len(result) == 2
|
||||||
|
for item in result:
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
for val in item.values():
|
||||||
|
assert val is None
|
||||||
|
|
||||||
|
finally:
|
||||||
|
client.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"multiprocessing_mode,publisher_config",
|
"multiprocessing_mode,publisher_config",
|
||||||
[(True, "tcp"), (False, "inproc")],
|
[(True, "tcp"), (False, "inproc")],
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import pickle
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from inspect import isclass
|
from inspect import isclass
|
||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
import msgspec
|
import msgspec
|
||||||
@ -59,6 +59,42 @@ def _typestr(val: Any) -> Optional[tuple[str, str]]:
|
|||||||
return t.__module__, t.__qualname__
|
return t.__module__, t.__qualname__
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_type_info_recursive(obj: Any) -> Any:
|
||||||
|
"""Recursively encode type information for nested structures of
|
||||||
|
lists/dicts."""
|
||||||
|
if obj is None:
|
||||||
|
return None
|
||||||
|
if type(obj) is list:
|
||||||
|
return [_encode_type_info_recursive(item) for item in obj]
|
||||||
|
if type(obj) is dict:
|
||||||
|
return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
|
||||||
|
return _typestr(obj)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_type_info_recursive(
|
||||||
|
type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any],
|
||||||
|
Any]) -> Any:
|
||||||
|
"""Recursively decode type information for nested structures of
|
||||||
|
lists/dicts."""
|
||||||
|
if type_info is None:
|
||||||
|
return data
|
||||||
|
if isinstance(type_info, dict):
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
return {
|
||||||
|
k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
|
||||||
|
for k in type_info
|
||||||
|
}
|
||||||
|
if isinstance(type_info, list) and (
|
||||||
|
# Exclude serialized tensors/numpy arrays.
|
||||||
|
len(type_info) != 2 or not isinstance(type_info[0], str)):
|
||||||
|
assert isinstance(data, list)
|
||||||
|
return [
|
||||||
|
_decode_type_info_recursive(ti, d, convert_fn)
|
||||||
|
for ti, d in zip(type_info, data)
|
||||||
|
]
|
||||||
|
return convert_fn(type_info, data)
|
||||||
|
|
||||||
|
|
||||||
class MsgpackEncoder:
|
class MsgpackEncoder:
|
||||||
"""Encoder with custom torch tensor and numpy array serialization.
|
"""Encoder with custom torch tensor and numpy array serialization.
|
||||||
|
|
||||||
@ -129,12 +165,10 @@ class MsgpackEncoder:
|
|||||||
result = obj.result
|
result = obj.result
|
||||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||||
return None, result
|
return None, result
|
||||||
# Since utility results are not strongly typed, we also encode
|
# Since utility results are not strongly typed, we recursively
|
||||||
# the type (or a list of types in the case it's a list) to
|
# encode type information for nested structures of lists/dicts
|
||||||
# help with correct msgspec deserialization.
|
# to help with correct msgspec deserialization.
|
||||||
return _typestr(result) if type(result) is not list else [
|
return _encode_type_info_recursive(result), result
|
||||||
_typestr(v) for v in result
|
|
||||||
], result
|
|
||||||
|
|
||||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||||
raise TypeError(f"Object of type {type(obj)} is not serializable"
|
raise TypeError(f"Object of type {type(obj)} is not serializable"
|
||||||
@ -288,15 +322,9 @@ class MsgpackDecoder:
|
|||||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||||
raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
|
raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
|
||||||
"be set to use custom utility result types")
|
"be set to use custom utility result types")
|
||||||
assert isinstance(result_type, list)
|
# Use recursive decoding to handle nested structures
|
||||||
if len(result_type) == 2 and isinstance(result_type[0], str):
|
result = _decode_type_info_recursive(result_type, result,
|
||||||
result = self._convert_result(result_type, result)
|
self._convert_result)
|
||||||
else:
|
|
||||||
assert isinstance(result, list)
|
|
||||||
result = [
|
|
||||||
self._convert_result(rt, r)
|
|
||||||
for rt, r in zip(result_type, result)
|
|
||||||
]
|
|
||||||
return UtilityResult(result)
|
return UtilityResult(result)
|
||||||
|
|
||||||
def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
|
def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user