[Frontend] Add unix domain socket support (#18097)

Signed-off-by: <yyweiss@gmail.com>
Signed-off-by: yyw <yyweiss@gmail.com>
This commit is contained in:
yyweiss 2025-08-09 02:23:44 +03:00 committed by GitHub
parent 2fcf6b27b6
commit baece8c3d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 86 additions and 16 deletions

View File

@ -29,6 +29,9 @@ Start the vLLM OpenAI Compatible API server.
# Specify the port # Specify the port
vllm serve meta-llama/Llama-2-7b-hf --port 8100 vllm serve meta-llama/Llama-2-7b-hf --port 8100
# Serve over a Unix domain socket
vllm serve meta-llama/Llama-2-7b-hf --uds /tmp/vllm.sock
# Check with --help for more options # Check with --help for more options
# To list all groups # To list all groups
vllm serve --help=listgroup vllm serve --help=listgroup

View File

@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from tempfile import TemporaryDirectory
import httpx
import pytest
from vllm.version import __version__ as VLLM_VERSION
from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def server():
    """Start one vLLM server per test module, listening on a Unix socket.

    The socket file lives in a temporary directory that is removed once the
    server context has exited.
    """
    with TemporaryDirectory() as tmpdir:
        socket_path = f"{tmpdir}/vllm.sock"
        serve_args = [
            # use half precision for speed and memory savings in CI environment
            "--dtype", "bfloat16",
            "--max-model-len", "8192",
            "--enforce-eager",
            "--max-num-seqs", "128",
            # Serve over the Unix domain socket instead of a TCP port.
            "--uds", socket_path,
        ]
        with RemoteOpenAIServer(MODEL_NAME, serve_args) as remote_server:
            yield remote_server
@pytest.mark.asyncio
async def test_show_version(server: RemoteOpenAIServer):
    """Check that the `version` endpoint is reachable over the Unix socket.

    The request goes through an async httpx transport bound to the server's
    socket path, so the coroutine does not block the event loop, and the
    client is closed when the test finishes.
    """
    transport = httpx.AsyncHTTPTransport(uds=server.uds)
    async with httpx.AsyncClient(transport=transport) as client:
        response = await client.get(server.url_for("version"))
    response.raise_for_status()
    assert response.json() == {"version": VLLM_VERSION}

View File

@ -17,6 +17,7 @@ from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union from typing import Any, Callable, Literal, Optional, Union
import cloudpickle import cloudpickle
import httpx
import openai import openai
import pytest import pytest
import requests import requests
@ -88,6 +89,8 @@ class RemoteOpenAIServer:
raise ValueError("You have manually specified the port " raise ValueError("You have manually specified the port "
"when `auto_port=True`.") "when `auto_port=True`.")
# No need for a port if using unix sockets
if "--uds" not in vllm_serve_args:
# Don't mutate the input args # Don't mutate the input args
vllm_serve_args = vllm_serve_args + [ vllm_serve_args = vllm_serve_args + [
"--port", str(get_open_port()) "--port", str(get_open_port())
@ -104,6 +107,11 @@ class RemoteOpenAIServer:
subparsers = parser.add_subparsers(required=False, dest="subparser") subparsers = parser.add_subparsers(required=False, dest="subparser")
parser = ServeSubcommand().subparser_init(subparsers) parser = ServeSubcommand().subparser_init(subparsers)
args = parser.parse_args(["--model", model, *vllm_serve_args]) args = parser.parse_args(["--model", model, *vllm_serve_args])
self.uds = args.uds
if args.uds:
self.host = None
self.port = None
else:
self.host = str(args.host or 'localhost') self.host = str(args.host or 'localhost')
self.port = int(args.port) self.port = int(args.port)
@ -150,9 +158,11 @@ class RemoteOpenAIServer:
def _wait_for_server(self, *, url: str, timeout: float): def _wait_for_server(self, *, url: str, timeout: float):
# run health check # run health check
start = time.time() start = time.time()
client = (httpx.Client(transport=httpx.HTTPTransport(
uds=self.uds)) if self.uds else requests)
while True: while True:
try: try:
if requests.get(url).status_code == 200: if client.get(url).status_code == 200:
break break
except Exception: except Exception:
# this exception can only be raised by requests.get, # this exception can only be raised by requests.get,
@ -170,7 +180,8 @@ class RemoteOpenAIServer:
@property @property
def url_root(self) -> str: def url_root(self) -> str:
return f"http://{self.host}:{self.port}" return (f"http://{self.uds.split('/')[-1]}"
if self.uds else f"http://{self.host}:{self.port}")
def url_for(self, *parts: str) -> str: def url_for(self, *parts: str) -> str:
return self.url_root + "/" + "/".join(parts) return self.url_root + "/" + "/".join(parts)

View File

@ -1777,6 +1777,12 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
return sock return sock
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a stream socket bound to the Unix domain socket at `path`.

    Args:
        path: Filesystem path at which to bind the socket.

    Returns:
        The bound (not yet listening) `AF_UNIX` / `SOCK_STREAM` socket.

    Raises:
        OSError: If binding fails, e.g. because `path` already exists.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except Exception:
        # Do not leak the file descriptor when the bind fails.
        sock.close()
        raise
    return sock
def validate_api_server_args(args): def validate_api_server_args(args):
valid_tool_parses = ToolParserManager.tool_parsers.keys() valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \ if args.enable_auto_tool_choice \
@ -1807,6 +1813,9 @@ def setup_server(args):
# workaround to make sure that we bind the port before the engine is set up. # workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray. # This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204 # see https://github.com/vllm-project/vllm/issues/8204
if args.uds:
sock = create_server_unix_socket(args.uds)
else:
sock_addr = (args.host or "", args.port) sock_addr = (args.host or "", args.port)
sock = create_server_socket(sock_addr) sock = create_server_socket(sock_addr)
@ -1820,12 +1829,14 @@ def setup_server(args):
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
if args.uds:
listen_address = f"unix:{args.uds}"
else:
addr, port = sock_addr addr, port = sock_addr
is_ssl = args.ssl_keyfile and args.ssl_certfile is_ssl = args.ssl_keyfile and args.ssl_certfile
host_part = f"[{addr}]" if is_valid_ipv6_address( host_part = f"[{addr}]" if is_valid_ipv6_address(
addr) else addr or "0.0.0.0" addr) else addr or "0.0.0.0"
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
return listen_address, sock return listen_address, sock

View File

@ -72,6 +72,8 @@ class FrontendArgs:
"""Host name.""" """Host name."""
port: int = 8000 port: int = 8000
"""Port number.""" """Port number."""
uds: Optional[str] = None
"""Unix domain socket path. If set, host and port arguments are ignored."""
uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical", uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical",
"trace"] = "info" "trace"] = "info"
"""Log level for uvicorn.""" """Log level for uvicorn."""