diff --git a/docs/cli/README.md b/docs/cli/README.md
index b1371c82a4c4..a7de6d7192ac 100644
--- a/docs/cli/README.md
+++ b/docs/cli/README.md
@@ -29,6 +29,9 @@ Start the vLLM OpenAI Compatible API server.
 # Specify the port
 vllm serve meta-llama/Llama-2-7b-hf --port 8100
 
+# Serve over a Unix domain socket
+vllm serve meta-llama/Llama-2-7b-hf --uds /tmp/vllm.sock
+
 # Check with --help for more options
 # To list all groups
 vllm serve --help=listgroup
diff --git a/tests/entrypoints/openai/test_uds.py b/tests/entrypoints/openai/test_uds.py
new file mode 100644
index 000000000000..5c39869a794f
--- /dev/null
+++ b/tests/entrypoints/openai/test_uds.py
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from tempfile import TemporaryDirectory
+
+import httpx
+import pytest
+
+from vllm.version import __version__ as VLLM_VERSION
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def server():
+    with TemporaryDirectory() as tmpdir:
+        args = [
+            # use half precision for speed and memory savings in CI environment
+            "--dtype",
+            "bfloat16",
+            "--max-model-len",
+            "8192",
+            "--enforce-eager",
+            "--max-num-seqs",
+            "128",
+            "--uds",
+            f"{tmpdir}/vllm.sock",
+        ]
+
+        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+            yield remote_server
+
+
+@pytest.mark.asyncio
+async def test_show_version(server: RemoteOpenAIServer):
+    transport = httpx.HTTPTransport(uds=server.uds)
+    client = httpx.Client(transport=transport)
+    response = client.get(server.url_for("version"))
+    response.raise_for_status()
+
+    assert response.json() == {"version": VLLM_VERSION}
diff --git a/tests/utils.py b/tests/utils.py
index 741b4401cc21..18fcde949160 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -17,6 +17,7 @@ from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
 
 import cloudpickle
+import httpx
 import openai
 import pytest
 import requests
@@ -88,10 +89,12 @@ class RemoteOpenAIServer:
                 raise ValueError("You have manually specified the port "
                                  "when `auto_port=True`.")
 
-            # Don't mutate the input args
-            vllm_serve_args = vllm_serve_args + [
-                "--port", str(get_open_port())
-            ]
+            # No need for a port if using unix sockets
+            if "--uds" not in vllm_serve_args:
+                # Don't mutate the input args
+                vllm_serve_args = vllm_serve_args + [
+                    "--port", str(get_open_port())
+                ]
         if seed is not None:
             if "--seed" in vllm_serve_args:
                 raise ValueError("You have manually specified the seed "
@@ -104,8 +107,13 @@ class RemoteOpenAIServer:
         subparsers = parser.add_subparsers(required=False, dest="subparser")
         parser = ServeSubcommand().subparser_init(subparsers)
         args = parser.parse_args(["--model", model, *vllm_serve_args])
-        self.host = str(args.host or 'localhost')
-        self.port = int(args.port)
+        self.uds = args.uds
+        if args.uds:
+            self.host = None
+            self.port = None
+        else:
+            self.host = str(args.host or 'localhost')
+            self.port = int(args.port)
 
         self.show_hidden_metrics = \
             args.show_hidden_metrics_for_version is not None
@@ -150,9 +158,11 @@ class RemoteOpenAIServer:
     def _wait_for_server(self, *, url: str, timeout: float):
         # run health check
         start = time.time()
+        client = (httpx.Client(transport=httpx.HTTPTransport(
+            uds=self.uds)) if self.uds else requests)
         while True:
             try:
-                if requests.get(url).status_code == 200:
+                if client.get(url).status_code == 200:
                     break
             except Exception:
                 # this exception can only be raised by requests.get,
@@ -170,7 +180,8 @@ class RemoteOpenAIServer:
 
     @property
     def url_root(self) -> str:
-        return f"http://{self.host}:{self.port}"
+        return (f"http://{self.uds.split('/')[-1]}"
+                if self.uds else f"http://{self.host}:{self.port}")
 
     def url_for(self, *parts: str) -> str:
         return self.url_root + "/" + "/".join(parts)
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 00eaba8c872f..e5d31c1fd03f 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1777,6 +1777,12 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
     return sock
 
 
+def create_server_unix_socket(path: str) -> socket.socket:
+    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
+    sock.bind(path)
+    return sock
+
+
 def validate_api_server_args(args):
     valid_tool_parses = ToolParserManager.tool_parsers.keys()
     if args.enable_auto_tool_choice \
@@ -1807,8 +1813,11 @@ def setup_server(args):
     # workaround to make sure that we bind the port before the engine is set up.
     # This avoids race conditions with ray.
     # see https://github.com/vllm-project/vllm/issues/8204
-    sock_addr = (args.host or "", args.port)
-    sock = create_server_socket(sock_addr)
+    if args.uds:
+        sock = create_server_unix_socket(args.uds)
+    else:
+        sock_addr = (args.host or "", args.port)
+        sock = create_server_socket(sock_addr)
 
     # workaround to avoid footguns where uvicorn drops requests with too
     # many concurrent requests active
@@ -1820,12 +1829,14 @@ def setup_server(args):
 
     signal.signal(signal.SIGTERM, signal_handler)
 
-    addr, port = sock_addr
-    is_ssl = args.ssl_keyfile and args.ssl_certfile
-    host_part = f"[{addr}]" if is_valid_ipv6_address(
-        addr) else addr or "0.0.0.0"
-    listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
-
+    if args.uds:
+        listen_address = f"unix:{args.uds}"
+    else:
+        addr, port = sock_addr
+        is_ssl = args.ssl_keyfile and args.ssl_certfile
+        host_part = f"[{addr}]" if is_valid_ipv6_address(
+            addr) else addr or "0.0.0.0"
+        listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
     return listen_address, sock
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index e89463a03cda..e15f65b43082 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -72,6 +72,8 @@ class FrontendArgs:
     """Host name."""
     port: int = 8000
     """Port number."""
+    uds: Optional[str] = None
+    """Unix domain socket path. If set, host and port arguments are ignored."""
     uvicorn_log_level: Literal["debug", "info", "warning", "error",
                                "critical", "trace"] = "info"
     """Log level for uvicorn."""
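For anyone trying the patch locally, the snippet below shows one way to point the official `openai` Python client at a UDS-served endpoint, using the same `httpx` transport mechanism the new test relies on. This is a minimal sketch, not part of the patch: the socket path, placeholder hostname, and `EMPTY` API key are illustrative assumptions.

```python
# Minimal sketch: querying a server started with
#   vllm serve <model> --uds /tmp/vllm.sock
# The socket path and the base_url host are illustrative.
import httpx
from openai import OpenAI

# All traffic is routed through the Unix domain socket; the hostname in
# base_url is only a syntactic placeholder and is never resolved.
transport = httpx.HTTPTransport(uds="/tmp/vllm.sock")

client = OpenAI(
    base_url="http://localhost/v1",
    api_key="EMPTY",  # vLLM ignores the key unless --api-key is set
    http_client=httpx.Client(transport=transport),
)

print([model.id for model in client.models.list().data])
```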
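On the server side, the patch pre-binds an `AF_UNIX` socket (mirroring the existing TCP pre-bind workaround for ray race conditions) and hands it to uvicorn. A standalone sketch of that pattern follows; the FastAPI app, socket path, and stale-socket cleanup are assumptions for illustration, not taken from the patch:

```python
# Minimal sketch of serving an ASGI app over a pre-bound AF_UNIX socket,
# mirroring the create_server_unix_socket() + uvicorn flow in this patch.
import os
import socket

import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
def health() -> dict[str, str]:
    return {"status": "ok"}


path = "/tmp/demo.sock"
if os.path.exists(path):
    os.unlink(path)  # bind() fails if a stale socket file is left behind

sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
sock.bind(path)

# uvicorn accepts pre-created listening sockets; binding before startup is
# how the patch avoids a race between socket setup and engine setup.
uvicorn.Server(uvicorn.Config(app)).run(sockets=[sock])
```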