From 6e244ae09121b2c1cdcd4db51076decc4a724c5c Mon Sep 17 00:00:00 2001 From: Yazan Sharaya <97323283+Yazan-Sharaya@users.noreply.github.com> Date: Fri, 27 Jun 2025 07:44:14 +0300 Subject: [PATCH] [Perf][Frontend] eliminate api_key and x_request_id headers middleware overhead (#19946) Signed-off-by: Yazan-Sharaya --- docs/serving/openai_compatible_server.md | 5 - .../openai/test_optional_middleware.py | 116 ++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 100 +++++++++++---- vllm/entrypoints/openai/cli_args.py | 2 +- 4 files changed, 190 insertions(+), 33 deletions(-) create mode 100644 tests/entrypoints/openai/test_optional_middleware.py diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index 00756e719992..a3f1ef9fd8b6 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -146,11 +146,6 @@ completion = client.chat.completions.create( Only `X-Request-Id` HTTP request header is supported for now. It can be enabled with `--enable-request-id-headers`. -> Note that enablement of the headers can impact performance significantly at high QPS -> rates. We recommend implementing HTTP headers at the router level (e.g. via Istio), -> rather than within the vLLM layer for this reason. -> See [this PR](https://github.com/vllm-project/vllm/pull/11529) for more details. - ??? Code ```python diff --git a/tests/entrypoints/openai/test_optional_middleware.py b/tests/entrypoints/openai/test_optional_middleware.py new file mode 100644 index 000000000000..882fa0886ce3 --- /dev/null +++ b/tests/entrypoints/openai/test_optional_middleware.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for middleware that's off by default and can be toggled through +server arguments, mainly --api-key and --enable-request-id-headers. 
+""" + +from http import HTTPStatus + +import pytest +import requests + +from ...utils import RemoteOpenAIServer + +# Use a small embeddings model for faster startup and smaller memory footprint. +# Since we are not testing any chat functionality, +# using a chat capable model is overkill. +MODEL_NAME = "intfloat/multilingual-e5-small" + + +@pytest.fixture(scope="module") +def server(request: pytest.FixtureRequest): + passed_params = [] + if hasattr(request, "param"): + passed_params = request.param + if isinstance(passed_params, str): + passed_params = [passed_params] + + args = [ + "--task", + "embed", + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "512", + "--enforce-eager", + "--max-num-seqs", + "2", + *passed_params + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +async def test_no_api_token(server: RemoteOpenAIServer): + response = requests.get(server.url_for("v1/models")) + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_no_request_id_header(server: RemoteOpenAIServer): + response = requests.get(server.url_for("health")) + assert "X-Request-Id" not in response.headers + + +@pytest.mark.parametrize( + "server", + [["--api-key", "test"]], + indirect=True, +) +@pytest.mark.asyncio +async def test_missing_api_token(server: RemoteOpenAIServer): + response = requests.get(server.url_for("v1/models")) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +@pytest.mark.parametrize( + "server", + [["--api-key", "test"]], + indirect=True, +) +@pytest.mark.asyncio +async def test_passed_api_token(server: RemoteOpenAIServer): + response = requests.get(server.url_for("v1/models"), + headers={"Authorization": "Bearer test"}) + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.parametrize( + "server", + [["--api-key", "test"]], + indirect=True, +) +@pytest.mark.asyncio +async def 
test_not_v1_api_token(server: RemoteOpenAIServer): + # Authorization check is skipped for any paths that + # don't start with /v1 (e.g. /health). + response = requests.get(server.url_for("health")) + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.parametrize( + "server", + ["--enable-request-id-headers"], + indirect=True, +) +@pytest.mark.asyncio +async def test_enable_request_id_header(server: RemoteOpenAIServer): + response = requests.get(server.url_for("health")) + assert "X-Request-Id" in response.headers + assert len(response.headers.get("X-Request-Id", "")) == 32 + + +@pytest.mark.parametrize( + "server", + ["--enable-request-id-headers"], + indirect=True, +) +@pytest.mark.asyncio +async def test_custom_request_id_header(server: RemoteOpenAIServer): + response = requests.get(server.url_for("health"), + headers={"X-Request-Id": "Custom"}) + assert "X-Request-Id" in response.headers + assert response.headers.get("X-Request-Id") == "Custom" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 681633a2aff7..f3fd15486271 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -14,7 +14,7 @@ import socket import tempfile import uuid from argparse import Namespace -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus @@ -30,8 +30,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import make_asgi_app from prometheus_fastapi_instrumentator import Instrumentator from starlette.concurrency import iterate_in_threadpool -from starlette.datastructures import State +from starlette.datastructures import URL, Headers, MutableHeaders, State from starlette.routing import Mount +from starlette.types import ASGIApp, Message, Receive, Scope, Send from typing_extensions import 
assert_never import vllm.envs as envs @@ -1061,6 +1062,74 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]: return None +class AuthenticationMiddleware: + """ + Pure ASGI middleware that authenticates each request by checking + if the Authorization header exists and equals "Bearer {api_key}". + + Notes + ----- + There are two cases in which authentication is skipped: + 1. The HTTP method is OPTIONS. + 2. The request path doesn't start with /v1 (e.g. /health). + """ + + def __init__(self, app: ASGIApp, api_token: str) -> None: + self.app = app + self.api_token = api_token + + def __call__(self, scope: Scope, receive: Receive, + send: Send) -> Awaitable[None]: + if scope["type"] not in ("http", + "websocket") or scope["method"] == "OPTIONS": + # scope["type"] can be "lifespan" for example, + # in which case we don't need to do anything + return self.app(scope, receive, send) + root_path = scope.get("root_path", "") + url_path = URL(scope=scope).path.removeprefix(root_path) + headers = Headers(scope=scope) + # Type narrow to satisfy mypy. + if url_path.startswith("/v1") and headers.get( + "Authorization") != f"Bearer {self.api_token}": + response = JSONResponse(content={"error": "Unauthorized"}, + status_code=401) + return response(scope, receive, send) + return self.app(scope, receive, send) + + +class XRequestIdMiddleware: + """ + Middleware that sets the X-Request-Id header for each response + to a random uuid4 (hex) value if the header isn't already + present in the request, otherwise use the provided request id. + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + def __call__(self, scope: Scope, receive: Receive, + send: Send) -> Awaitable[None]: + if scope["type"] not in ("http", "websocket"): + return self.app(scope, receive, send) + + # Extract the request headers. 
+ request_headers = Headers(scope=scope) + + async def send_with_request_id(message: Message) -> None: + """ + Custom send function to mutate the response headers + and append X-Request-Id to it. + """ + if message["type"] == "http.response.start": + response_headers = MutableHeaders(raw=message["headers"]) + request_id = request_headers.get("X-Request-Id", + uuid.uuid4().hex) + response_headers.append("X-Request-Id", request_id) + await send(message) + + return self.app(scope, receive, send_with_request_id) + + def build_app(args: Namespace) -> FastAPI: if args.disable_fastapi_docs: app = FastAPI(openapi_url=None, @@ -1108,33 +1177,10 @@ def build_app(args: Namespace) -> FastAPI: # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY if token := args.api_key or envs.VLLM_API_KEY: - - @app.middleware("http") - async def authentication(request: Request, call_next): - if request.method == "OPTIONS": - return await call_next(request) - url_path = request.url.path - if app.root_path and url_path.startswith(app.root_path): - url_path = url_path[len(app.root_path):] - if not url_path.startswith("/v1"): - return await call_next(request) - if request.headers.get("Authorization") != "Bearer " + token: - return JSONResponse(content={"error": "Unauthorized"}, - status_code=401) - return await call_next(request) + app.add_middleware(AuthenticationMiddleware, api_token=token) if args.enable_request_id_headers: - logger.warning( - "CAUTION: Enabling X-Request-Id headers in the API Server. " - "This can harm performance at high QPS.") - - @app.middleware("http") - async def add_request_id(request: Request, call_next): - request_id = request.headers.get( - "X-Request-Id") or uuid.uuid4().hex - response = await call_next(request) - response.headers["X-Request-Id"] = request_id - return response + app.add_middleware(XRequestIdMiddleware) if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE: logger.warning("CAUTION: Enabling log response in the API Server. 
" diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index dd4bd53046a3..f9bec8451868 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -216,7 +216,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--enable-request-id-headers", action="store_true", help="If specified, API server will add X-Request-Id header to " - "responses. Caution: this hurts performance at high QPS.") + "responses.") parser.add_argument( "--enable-auto-tool-choice", action="store_true",