[Frontend] Add /collective_rpc API endpoint (#23075)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>

parent 03d4235fd2
commit f7cf5b512e
@@ -126,7 +126,8 @@ steps:
   - tests/entrypoints/test_chat_utils
   commands:
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
+  - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
+  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py
   - pytest -v -s entrypoints/test_chat_utils.py

 - label: Distributed Tests (4 GPUs) # 10min
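Note: the PYTHONPATH=/vllm-workspace prefix is what lets the workers import tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension by module path; outside this CI image, pointing PYTHONPATH at the repository root should serve the same purpose (an assumption based on the comment in the step above).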
tests/entrypoints/openai/test_collective_rpc.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Any
+
+import pytest
+import requests
+
+from tests.utils import RemoteOpenAIServer
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+
+
+class TestWorkerExtension:
+
+    def get_model_name(self) -> str:
+        """Test non-pydantic return type."""
+        return MODEL_NAME
+
+    def echo_args_kwargs(self, *args, **kwargs) -> dict[str, Any]:
+        """Echo back both args and kwargs."""
+        return dict(
+            args=list(args),
+            kwargs=kwargs,
+            total_items=len(args) + len(kwargs),
+        )
+
+    def return_none(self, *args, **kwargs) -> None:
+        """Test method that does not return anything"""
+        return
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--max-model-len",
+        "8192",
+        "--max-num-seqs",
+        "128",
+        "--worker-extension-cls",
+        "tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension",
+    ]
+    with RemoteOpenAIServer(
+            MODEL_NAME,
+            args,
+            env_dict={
+                "VLLM_SERVER_DEV_MODE": "1",
+                "CUDA_VISIBLE_DEVICES": "0"
+            },
+    ) as remote_server:
+        yield remote_server
+
+
+def test_get_model_name(server):
+    """Test basic response"""
+    response = requests.post(server.url_for("collective_rpc"),
+                             json={"method": "get_model_name"})
+    assert response.status_code == 200
+    results = response.json()
+    assert "results" in results
+    assert results["results"] == [MODEL_NAME]
+
+
+def test_return_none(server):
+    """Test return none"""
+    response = requests.post(server.url_for("collective_rpc"),
+                             json={"method": "return_none"})
+    assert response.status_code == 200
+    results = response.json()
+    assert results["results"] == [None]
+
+
+def test_echo_args_kwargs(server):
+    """Test args, kwargs, and dict response"""
+    args = ["arg1", "arg2"]
+    kwargs = {"key1": "value1", "key2": "value2"}
+    response = requests.post(server.url_for("collective_rpc"),
+                             json={
+                                 "method": "echo_args_kwargs",
+                                 "args": args,
+                                 "kwargs": kwargs
+                             })
+    assert response.status_code == 200
+    results = response.json()
+    result = results["results"][0]
+    assert result["args"] == args
+    assert result["kwargs"] == kwargs
+    assert result["total_items"] == len(args) + len(kwargs)
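Since the endpoint forwards only plain JSON values (strings, lists, dicts), an extension method that needs richer inputs must deserialize them itself. A minimal, hypothetical sketch of such a method (the class and method names below are illustrative, not part of this change):

import json
from typing import Any


class JsonArgsWorkerExtension:

    def apply_config(self, config_json: str) -> dict[str, Any]:
        """Hypothetical method: the caller sends a JSON string and the
        worker-side method deserializes it, as the handler's security
        note requires."""
        # Assumes the string encodes a JSON object.
        config = json.loads(config_json)
        # ... apply `config` to the worker state here ...
        return {"applied_keys": sorted(config)}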
@@ -329,3 +329,11 @@ class EngineClient(ABC):
                            drain_timeout: int = 300) -> None:
         """Scale the engine"""
         raise NotImplementedError
+
+    async def collective_rpc(self,
+                             method: str,
+                             timeout: Optional[float] = None,
+                             args: tuple = (),
+                             kwargs: Optional[dict] = None):
+        """Perform a collective RPC call on the workers."""
+        raise NotImplementedError
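For illustration, async code holding any concrete EngineClient could drive the new method like this (a sketch; `client` is assumed to be a live engine client and `get_model_name` a method defined by the worker or its extension):

async def fetch_model_names(client) -> list:
    # One entry per worker is expected in the returned list.
    return await client.collective_rpc(method="get_model_name", timeout=5.0)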
@@ -1044,6 +1044,34 @@ if envs.VLLM_SERVER_DEV_MODE:
         is_sleeping = await engine_client(raw_request).is_sleeping()
         return JSONResponse(content={"is_sleeping": is_sleeping})

+    @router.post("/collective_rpc")
+    async def collective_rpc(raw_request: Request):
+        try:
+            body = await raw_request.json()
+        except json.JSONDecodeError as e:
+            raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
+                                detail=f"JSON decode error: {e}") from e
+        method = body.get("method")
+        if method is None:
+            raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
+                                detail="Missing 'method' in request body")
+        # For security reasons, only serialized string args/kwargs are passed.
+        # User-defined `method` is responsible for deserialization if needed.
+        args: list[str] = body.get("args", [])
+        kwargs: dict[str, str] = body.get("kwargs", {})
+        timeout: Optional[float] = body.get("timeout")
+        results = await engine_client(raw_request).collective_rpc(
+            method=method, timeout=timeout, args=tuple(args), kwargs=kwargs)
+        if results is None:
+            return Response(status_code=200)
+        response: list[Any] = []
+        for result in results:
+            if result is None or isinstance(result, (dict, list)):
+                response.append(result)
+            else:
+                response.append(str(result))
+        return JSONResponse(content={"results": response})
+
     @router.post("/scale_elastic_ep",
                  dependencies=[Depends(validate_json_request)],
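Because the route is registered only under VLLM_SERVER_DEV_MODE (per the guard in the hunk header), exercising it requires a dev-mode server. A sketch mirroring the tests above (host, port, and the extension method are assumptions):

import requests

resp = requests.post(
    "http://localhost:8000/collective_rpc",
    json={"method": "echo_args_kwargs", "args": ["a"], "kwargs": {"k": "v"}},
)
resp.raise_for_status()
print(resp.json())
# -> {"results": [{"args": ["a"], "kwargs": {"k": "v"}, "total_items": 2}]}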