[Frontend] Add /collective_rpc API endpoint (#23075)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>

parent 03d4235fd2
commit f7cf5b512e
@@ -126,7 +126,8 @@ steps:
   - tests/entrypoints/test_chat_utils
   commands:
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
+  - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
+  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py
   - pytest -v -s entrypoints/test_chat_utils.py
 
 - label: Distributed Tests (4 GPUs) # 10min
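The PYTHONPATH=/vllm-workspace prefix matters because --worker-extension-cls takes a dotted import path that each worker process resolves by importing it. A minimal sketch of how such a path is typically resolved (illustrative only, not vLLM's actual loader code):

```python
import importlib


def resolve_class(dotted_path: str) -> type:
    # e.g. "tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension";
    # the containing module must be importable, hence the PYTHONPATH
    # in the CI command above.
    module_name, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
```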
tests/entrypoints/openai/test_collective_rpc.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Any
+
+import pytest
+import requests
+
+from tests.utils import RemoteOpenAIServer
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+
+
+class TestWorkerExtension:
+
+    def get_model_name(self) -> str:
+        """Test non-pydantic return type."""
+        return MODEL_NAME
+
+    def echo_args_kwargs(self, *args, **kwargs) -> dict[str, Any]:
+        """Echo back both args and kwargs."""
+        return dict(
+            args=list(args),
+            kwargs=kwargs,
+            total_items=len(args) + len(kwargs),
+        )
+
+    def return_none(self, *args, **kwargs) -> None:
+        """Test method that does not return anything"""
+        return
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--max-model-len",
+        "8192",
+        "--max-num-seqs",
+        "128",
+        "--worker-extension-cls",
+        "tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension",
+    ]
+    with RemoteOpenAIServer(
+            MODEL_NAME,
+            args,
+            env_dict={
+                "VLLM_SERVER_DEV_MODE": "1",
+                "CUDA_VISIBLE_DEVICES": "0"
+            },
+    ) as remote_server:
+        yield remote_server
+
+
+def test_get_model_name(server):
+    """Test basic response"""
+    response = requests.post(server.url_for("collective_rpc"),
+                             json={"method": "get_model_name"})
+    assert response.status_code == 200
+    results = response.json()
+    assert "results" in results
+    assert results["results"] == [MODEL_NAME]
+
+
+def test_return_none(server):
+    """Test return none"""
+    response = requests.post(server.url_for("collective_rpc"),
+                             json={"method": "return_none"})
+    assert response.status_code == 200
+    results = response.json()
+    assert results["results"] == [None]
+
+
+def test_echo_args_kwargs(server):
+    """Test args, kwargs, and dict response"""
+    args = ["arg1", "arg2"]
+    kwargs = {"key1": "value1", "key2": "value2"}
+    response = requests.post(server.url_for("collective_rpc"),
+                             json={
+                                 "method": "echo_args_kwargs",
+                                 "args": args,
+                                 "kwargs": kwargs
+                             })
+    assert response.status_code == 200
+    results = response.json()
+    result = results["results"][0]
+    assert result["args"] == args
+    assert result["kwargs"] == kwargs
+    assert result["total_items"] == len(args) + len(kwargs)
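Outside of pytest, the same endpoint can be exercised by hand. A hedged example, assuming a dev-mode server with the worker extension above is already running on localhost:8000 (host, port, and the single-worker setup are illustrative):

```python
import requests

# Assumes a server launched with VLLM_SERVER_DEV_MODE=1 and
# --worker-extension-cls pointing at TestWorkerExtension (one worker).
resp = requests.post(
    "http://localhost:8000/collective_rpc",
    json={"method": "echo_args_kwargs", "args": ["a"], "kwargs": {"k": "v"}},
)
resp.raise_for_status()
# One result per worker; with a single worker this prints:
# {'results': [{'args': ['a'], 'kwargs': {'k': 'v'}, 'total_items': 2}]}
print(resp.json())
```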
@@ -329,3 +329,11 @@ class EngineClient(ABC):
                                drain_timeout: int = 300) -> None:
         """Scale the engine"""
         raise NotImplementedError
+
+    async def collective_rpc(self,
+                             method: str,
+                             timeout: Optional[float] = None,
+                             args: tuple = (),
+                             kwargs: Optional[dict] = None):
+        """Perform a collective RPC call to the given path."""
+        raise NotImplementedError
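Concrete clients override this stub. A sketch of what an implementation might look like, assuming the client holds some engine-core handle with a matching async RPC method (the attribute self.engine_core and its method name are assumptions, not vLLM's actual internals):

```python
from typing import Any, Optional

from vllm.engine.protocol import EngineClient  # module path per this diff's context


class SketchEngineClient(EngineClient):
    """Sketch only: forwards the RPC to a hypothetical engine-core handle.

    A real subclass must also implement the ABC's other abstract methods.
    """

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None) -> list[Any]:
        # Hypothetical delegation; the real client wires this through its
        # own worker/engine communication layer.
        return await self.engine_core.collective_rpc_async(
            method, timeout, args, kwargs or {})
```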
@@ -1044,6 +1044,34 @@ if envs.VLLM_SERVER_DEV_MODE:
         is_sleeping = await engine_client(raw_request).is_sleeping()
         return JSONResponse(content={"is_sleeping": is_sleeping})
 
+    @router.post("/collective_rpc")
+    async def collective_rpc(raw_request: Request):
+        try:
+            body = await raw_request.json()
+        except json.JSONDecodeError as e:
+            raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
+                                detail=f"JSON decode error: {e}") from e
+        method = body.get("method")
+        if method is None:
+            raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
+                                detail="Missing 'method' in request body")
+        # For security reasons, only serialized string args/kwargs are passed.
+        # User-defined `method` is responsible for deserialization if needed.
+        args: list[str] = body.get("args", [])
+        kwargs: dict[str, str] = body.get("kwargs", {})
+        timeout: Optional[float] = body.get("timeout")
+        results = await engine_client(raw_request).collective_rpc(
+            method=method, timeout=timeout, args=tuple(args), kwargs=kwargs)
+        if results is None:
+            return Response(status_code=200)
+        response: list[Any] = []
+        for result in results:
+            if result is None or isinstance(result, (dict, list)):
+                response.append(result)
+            else:
+                response.append(str(result))
+        return JSONResponse(content={"results": response})
+
 
 @router.post("/scale_elastic_ep",
              dependencies=[Depends(validate_json_request)],
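Since worker return values are not guaranteed to be JSON-serializable, the handler passes None, dict, and list through untouched and stringifies everything else. The same rule restated in isolation, mirroring the loop above:

```python
def coerce_results(results: list) -> list:
    # None and JSON-native containers pass through; everything else
    # (ints, bytes, custom objects, ...) is rendered via str().
    return [
        r if r is None or isinstance(r, (dict, list)) else str(r)
        for r in results
    ]


assert coerce_results([None, {"a": 1}, [1, 2], 3]) == [None, {"a": 1}, [1, 2], "3"]
```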