[Frontend] Add /collective_rpc API endpoint (#23075)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
22quinn 2025-08-19 10:29:32 -07:00 committed by GitHub
parent 03d4235fd2
commit f7cf5b512e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 126 additions and 1 deletion

View File

@ -126,7 +126,8 @@ steps:
- tests/entrypoints/test_chat_utils
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py
- pytest -v -s entrypoints/test_chat_utils.py
- label: Distributed Tests (4 GPUs) # 10min

View File

@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import pytest
import requests
from tests.utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen3-0.6B"
class TestWorkerExtension:
    """Worker extension whose methods are invoked through /collective_rpc."""

    def get_model_name(self) -> str:
        """Return a plain string (exercises a non-pydantic return type)."""
        return MODEL_NAME

    def echo_args_kwargs(self, *args, **kwargs) -> dict[str, Any]:
        """Echo back both args and kwargs."""
        return {
            "args": list(args),
            "kwargs": kwargs,
            "total_items": len(args) + len(kwargs),
        }

    def return_none(self, *args, **kwargs) -> None:
        """Method that deliberately returns nothing."""
        return None
@pytest.fixture(scope="module")
def server():
    """Start one shared vLLM OpenAI-compatible server for this module's tests."""
    cli_args = [
        "--max-model-len", "8192",
        "--max-num-seqs", "128",
        "--worker-extension-cls",
        "tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension",
    ]
    env = {
        # /collective_rpc is only mounted when dev mode is enabled.
        "VLLM_SERVER_DEV_MODE": "1",
        "CUDA_VISIBLE_DEVICES": "0",
    }
    with RemoteOpenAIServer(MODEL_NAME, cli_args, env_dict=env) as remote:
        yield remote
def test_get_model_name(server):
    """A string-returning method yields one result entry per worker."""
    resp = requests.post(server.url_for("collective_rpc"),
                         json={"method": "get_model_name"})
    assert resp.status_code == 200
    payload = resp.json()
    assert "results" in payload
    assert payload["results"] == [MODEL_NAME]
def test_return_none(server):
    """A None-returning method comes back as a JSON null result."""
    resp = requests.post(server.url_for("collective_rpc"),
                         json={"method": "return_none"})
    assert resp.status_code == 200
    assert resp.json()["results"] == [None]
def test_echo_args_kwargs(server):
    """args/kwargs round-trip through the endpoint, dict results pass as JSON."""
    sent_args = ["arg1", "arg2"]
    sent_kwargs = {"key1": "value1", "key2": "value2"}
    payload = {
        "method": "echo_args_kwargs",
        "args": sent_args,
        "kwargs": sent_kwargs,
    }
    resp = requests.post(server.url_for("collective_rpc"), json=payload)
    assert resp.status_code == 200
    echoed = resp.json()["results"][0]
    assert echoed["args"] == sent_args
    assert echoed["kwargs"] == sent_kwargs
    assert echoed["total_items"] == len(sent_args) + len(sent_kwargs)

View File

@ -329,3 +329,11 @@ class EngineClient(ABC):
drain_timeout: int = 300) -> None:
"""Scale the engine"""
raise NotImplementedError
async def collective_rpc(self,
                         method: str,
                         timeout: Optional[float] = None,
                         args: tuple = (),
                         kwargs: Optional[dict] = None):
    """Invoke ``method`` on the engine's workers via a collective RPC.

    Args:
        method: Name of the worker (or worker-extension) method to call.
        timeout: Optional timeout in seconds for the call — TODO confirm units.
        args: Positional arguments forwarded to ``method``.
        kwargs: Keyword arguments forwarded to ``method``.

    Raises:
        NotImplementedError: Always; concrete engine clients must override.
    """
    raise NotImplementedError

View File

@ -1044,6 +1044,34 @@ if envs.VLLM_SERVER_DEV_MODE:
is_sleeping = await engine_client(raw_request).is_sleeping()
return JSONResponse(content={"is_sleeping": is_sleeping})
@router.post("/collective_rpc")
async def collective_rpc(raw_request: Request):
    """Dev-mode endpoint: forward a collective RPC request to the engine."""
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
                            detail=f"JSON decode error: {e}") from e

    method = body.get("method")
    if method is None:
        raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
                            detail="Missing 'method' in request body")

    # For security reasons, only serialized string args/kwargs are passed;
    # the user-defined `method` is responsible for deserialization if needed.
    args: list[str] = body.get("args", [])
    kwargs: dict[str, str] = body.get("kwargs", {})
    timeout: Optional[float] = body.get("timeout")

    results = await engine_client(raw_request).collective_rpc(
        method=method, timeout=timeout, args=tuple(args), kwargs=kwargs)
    if results is None:
        return Response(status_code=200)

    # Pass through JSON-friendly values (None/dict/list); stringify the rest.
    serializable: list[Any] = [
        item if item is None or isinstance(item, (dict, list)) else str(item)
        for item in results
    ]
    return JSONResponse(content={"results": serializable})
@router.post("/scale_elastic_ep",
dependencies=[Depends(validate_json_request)],