mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 06:35:42 +08:00
[Bug]: Authorization ignored when root_path is set (#10606)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
parent
2b0879bfc2
commit
d04b13a380
103
tests/entrypoints/openai/test_root_path.py
Normal file
103
tests/entrypoints/openai/test_root_path.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
from typing import Any, List, NamedTuple
|
||||||
|
|
||||||
|
import openai # use the official client for correctness check
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
# any model with a chat template should work here
# Model passed to RemoteOpenAIServer and requested by every test case.
|
||||||
|
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
|
||||||
|
# Jinja template handed to the server through --chat-template; it renders
# each message as "<role>: <content>" followed by a newline.
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
|
||||||
|
# Exported to the server process as VLLM_API_KEY by the `server` fixture and
# used as the client key in the cases that expect no error.
API_KEY = "abc-123"
|
||||||
|
# A key that differs from API_KEY; used by the cases that expect
# openai.AuthenticationError.
ERROR_API_KEY = "abc"
|
||||||
|
# Path segment without a leading slash: the fixture passes "/" + ROOT_PATH to
# --root-path, and the prefixed cases use it as the first URL segment.
ROOT_PATH = "llm"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def server():
    """Yield one RemoteOpenAIServer shared by every test in this module.

    The server is started for MODEL_NAME with ``--root-path`` set to
    ``/`` + ROOT_PATH and with ``VLLM_API_KEY`` set to API_KEY in the
    environment handed to it.
    """
    # use half precision for speed and memory savings in CI environment
    precision_flags = ["--dtype", "float16"]
    engine_flags = ["--enforce-eager", "--max-model-len", "4080"]
    # use --root-path=/llm for testing
    root_path_flags = ["--root-path", f"/{ROOT_PATH}"]
    template_flags = ["--chat-template", DUMMY_CHAT_TEMPLATE]
    cli_args = [
        *precision_flags,
        *engine_flags,
        *root_path_flags,
        *template_flags,
    ]

    # Copy of the current environment with the API key added on top.
    server_env = {**os.environ, "VLLM_API_KEY": API_KEY}

    with RemoteOpenAIServer(MODEL_NAME, cli_args,
                            env_dict=server_env) as remote_server:
        yield remote_server
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase(NamedTuple):
    """One parametrized scenario for the root-path / API-key test."""

    # The class name matches pytest's default ``Test*`` collection pattern.
    # Opt out explicitly so pytest does not try to collect this data holder
    # (and warn that it has a ``__new__`` constructor).
    __test__ = False

    # Model name sent in the chat-completion request.
    model_name: str
    # URL path segments, unpacked into ``server.url_for(*base_url)``.
    base_url: List[str]
    # API key handed to the OpenAI client.
    api_key: str
    # Exception type the request must raise, or None when it must succeed.
    expected_error: Any
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(model_name=MODEL_NAME,
                 base_url=url_segments,
                 api_key=client_key,
                 expected_error=error_type)
        # Wrong key first, then the correct one ...
        for client_key, error_type in (
            (ERROR_API_KEY, openai.AuthenticationError),
            (API_KEY, None),
        )
        # ... each against http://localhost:8000/v1 and
        # http://localhost:8000/llm/v1
        for url_segments in (["v1"], [ROOT_PATH, "v1"])
    ],
)
async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer,
                                                   test_case: TestCase):
    """Send one chat completion per case and check the outcome.

    Cases with an ``expected_error`` must raise that exception; the other
    cases must return a single non-empty assistant choice that finished
    with reason "stop".
    """
    saying: str = "Here is a common saying about apple. An apple a day, keeps"
    conversation = [
        {
            "role": "user",
            "content": "tell me a common saying"
        },
        {
            "role": "assistant",
            "content": saying
        },
    ]
    request_extras = {
        "continue_final_message": True,
        "add_generation_prompt": False
    }

    if test_case.expected_error is None:
        outcome = contextlib.nullcontext()
    else:
        outcome = pytest.raises(test_case.expected_error)

    with outcome:
        endpoint = server.url_for(*test_case.base_url)
        client = openai.AsyncOpenAI(api_key=test_case.api_key,
                                    base_url=endpoint,
                                    max_retries=0)
        chat_completion = await client.chat.completions.create(
            model=test_case.model_name,
            messages=conversation,
            extra_body=request_extras)

        # These checks stay inside the ``with`` block: when an error is
        # expected the request raises and they are never reached.
        assert chat_completion.id is not None
        assert len(chat_completion.choices) == 1
        choice = chat_completion.choices[0]
        assert choice.finish_reason == "stop"
        message = choice.message
        assert len(message.content) > 0
        assert message.role == "assistant"
|
||||||
@ -499,10 +499,12 @@ def build_app(args: Namespace) -> FastAPI:
|
|||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def authentication(request: Request, call_next):
|
async def authentication(request: Request, call_next):
|
||||||
root_path = "" if args.root_path is None else args.root_path
|
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
if not request.url.path.startswith(f"{root_path}/v1"):
|
url_path = request.url.path
|
||||||
|
if app.root_path and url_path.startswith(app.root_path):
|
||||||
|
url_path = url_path[len(app.root_path):]
|
||||||
|
if not url_path.startswith("/v1"):
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
if request.headers.get("Authorization") != "Bearer " + token:
|
if request.headers.get("Authorization") != "Bearer " + token:
|
||||||
return JSONResponse(content={"error": "Unauthorized"},
|
return JSONResponse(content={"error": "Unauthorized"},
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user