[Misc] Support more collective_rpc return types (#25294)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Nick Hill 2025-09-19 19:02:38 -07:00 committed by yewentao256
parent 123e7ad492
commit ea01b17b6f
2 changed files with 246 additions and 17 deletions

View File

@ -8,7 +8,7 @@ import time
import uuid
from dataclasses import dataclass
from threading import Thread
from typing import Optional, Union
from typing import Any, Optional, Union
from unittest.mock import MagicMock
import pytest
@ -331,6 +331,46 @@ def echo_dc(
return [val for _ in range(3)] if return_list else val
# Dummy utility function to test dict serialization with custom types.
def echo_dc_dict(
self,
msg: str,
return_dict: bool = False,
) -> Union[MyDataclass, dict[str, MyDataclass]]:
print(f"echo dc dict util function called: {msg}")
val = None if msg is None else MyDataclass(msg)
# Return dict of dataclasses to verify support for returning dicts
# with custom value types.
if return_dict:
return {"key1": val, "key2": val, "key3": val}
else:
return val
# Dummy utility function to test nested structures with custom types.
def echo_dc_nested(
self,
msg: str,
structure_type: str = "list_of_dicts",
) -> Any:
print(f"echo dc nested util function called: {msg}, "
f"structure: {structure_type}")
val = None if msg is None else MyDataclass(msg)
if structure_type == "list_of_dicts": # noqa
# Return list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
return [{"a": val, "b": val}, {"c": val, "d": val}]
elif structure_type == "dict_of_lists":
# Return dict of lists: {"list1": [val, val], "list2": [val, val]}
return {"list1": [val, val], "list2": [val, val]}
elif structure_type == "deep_nested":
# Return deeply nested: {"outer": [{"inner": [val, val]},
# {"inner": [val]}]}
return {"outer": [{"inner": [val, val]}, {"inner": [val]}]}
else:
return val
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_return(
monkeypatch: pytest.MonkeyPatch):
@ -384,6 +424,167 @@ async def test_engine_core_client_util_method_custom_return(
client.shutdown()
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_dict_return(
monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
# Must set insecure serialization to allow returning custom types.
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
# Monkey-patch core engine utility function to test.
m.setattr(EngineCore, "echo_dc_dict", echo_dc_dict, raising=False)
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)
try:
# Test utility method returning custom / non-native data type.
core_client: AsyncMPClient = client
# Test single object return
result = await core_client.call_utility_async(
"echo_dc_dict", "testarg3", False)
assert isinstance(result,
MyDataclass) and result.message == "testarg3"
# Test dict return with custom value types
result = await core_client.call_utility_async(
"echo_dc_dict", "testarg3", True)
assert isinstance(result, dict) and len(result) == 3
for key, val in result.items():
assert key in ["key1", "key2", "key3"]
assert isinstance(val,
MyDataclass) and val.message == "testarg3"
# Test returning dict with None values
result = await core_client.call_utility_async(
"echo_dc_dict", None, True)
assert isinstance(result, dict) and len(result) == 3
for key, val in result.items():
assert key in ["key1", "key2", "key3"]
assert val is None
finally:
client.shutdown()
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_nested_structures(
monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
# Must set insecure serialization to allow returning custom types.
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
# Monkey-patch core engine utility function to test.
m.setattr(EngineCore, "echo_dc_nested", echo_dc_nested, raising=False)
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)
try:
core_client: AsyncMPClient = client
# Test list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
result = await core_client.call_utility_async(
"echo_dc_nested", "nested1", "list_of_dicts")
assert isinstance(result, list) and len(result) == 2
for i, item in enumerate(result):
assert isinstance(item, dict)
if i == 0:
assert "a" in item and "b" in item
assert isinstance(
item["a"],
MyDataclass) and item["a"].message == "nested1"
assert isinstance(
item["b"],
MyDataclass) and item["b"].message == "nested1"
else:
assert "c" in item and "d" in item
assert isinstance(
item["c"],
MyDataclass) and item["c"].message == "nested1"
assert isinstance(
item["d"],
MyDataclass) and item["d"].message == "nested1"
# Test dict of lists: {"list1": [val, val], "list2": [val, val]}
result = await core_client.call_utility_async(
"echo_dc_nested", "nested2", "dict_of_lists")
assert isinstance(result, dict) and len(result) == 2
assert "list1" in result and "list2" in result
for key, lst in result.items():
assert isinstance(lst, list) and len(lst) == 2
for item in lst:
assert isinstance(
item, MyDataclass) and item.message == "nested2"
# Test deeply nested: {"outer": [{"inner": [val, val]},
# {"inner": [val]}]}
result = await core_client.call_utility_async(
"echo_dc_nested", "nested3", "deep_nested")
assert isinstance(result, dict) and "outer" in result
outer_list = result["outer"]
assert isinstance(outer_list, list) and len(outer_list) == 2
# First dict in outer list should have "inner" with 2 items
inner_dict1 = outer_list[0]
assert isinstance(inner_dict1, dict) and "inner" in inner_dict1
inner_list1 = inner_dict1["inner"]
assert isinstance(inner_list1, list) and len(inner_list1) == 2
for item in inner_list1:
assert isinstance(item,
MyDataclass) and item.message == "nested3"
# Second dict in outer list should have "inner" with 1 item
inner_dict2 = outer_list[1]
assert isinstance(inner_dict2, dict) and "inner" in inner_dict2
inner_list2 = inner_dict2["inner"]
assert isinstance(inner_list2, list) and len(inner_list2) == 1
assert isinstance(
inner_list2[0],
MyDataclass) and inner_list2[0].message == "nested3"
# Test with None values in nested structures
result = await core_client.call_utility_async(
"echo_dc_nested", None, "list_of_dicts")
assert isinstance(result, list) and len(result) == 2
for item in result:
assert isinstance(item, dict)
for val in item.values():
assert val is None
finally:
client.shutdown()
@pytest.mark.parametrize(
"multiprocessing_mode,publisher_config",
[(True, "tcp"), (False, "inproc")],

View File

@ -7,7 +7,7 @@ import pickle
from collections.abc import Sequence
from inspect import isclass
from types import FunctionType
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union
import cloudpickle
import msgspec
@ -59,6 +59,42 @@ def _typestr(val: Any) -> Optional[tuple[str, str]]:
return t.__module__, t.__qualname__
def _encode_type_info_recursive(obj: Any) -> Any:
"""Recursively encode type information for nested structures of
lists/dicts."""
if obj is None:
return None
if type(obj) is list:
return [_encode_type_info_recursive(item) for item in obj]
if type(obj) is dict:
return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
return _typestr(obj)
def _decode_type_info_recursive(
type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any],
Any]) -> Any:
"""Recursively decode type information for nested structures of
lists/dicts."""
if type_info is None:
return data
if isinstance(type_info, dict):
assert isinstance(data, dict)
return {
k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
for k in type_info
}
if isinstance(type_info, list) and (
# Exclude serialized tensors/numpy arrays.
len(type_info) != 2 or not isinstance(type_info[0], str)):
assert isinstance(data, list)
return [
_decode_type_info_recursive(ti, d, convert_fn)
for ti, d in zip(type_info, data)
]
return convert_fn(type_info, data)
class MsgpackEncoder:
"""Encoder with custom torch tensor and numpy array serialization.
@ -129,12 +165,10 @@ class MsgpackEncoder:
result = obj.result
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
return None, result
# Since utility results are not strongly typed, we also encode
# the type (or a list of types in the case it's a list) to
# help with correct msgspec deserialization.
return _typestr(result) if type(result) is not list else [
_typestr(v) for v in result
], result
# Since utility results are not strongly typed, we recursively
# encode type information for nested structures of lists/dicts
# to help with correct msgspec deserialization.
return _encode_type_info_recursive(result), result
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
raise TypeError(f"Object of type {type(obj)} is not serializable"
@ -288,15 +322,9 @@ class MsgpackDecoder:
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
"be set to use custom utility result types")
assert isinstance(result_type, list)
if len(result_type) == 2 and isinstance(result_type[0], str):
result = self._convert_result(result_type, result)
else:
assert isinstance(result, list)
result = [
self._convert_result(rt, r)
for rt, r in zip(result_type, result)
]
# Use recursive decoding to handle nested structures
result = _decode_type_info_recursive(result_type, result,
self._convert_result)
return UtilityResult(result)
def _convert_result(self, result_type: Sequence[str], result: Any) -> Any: