[Feature] Simple API token authentication and pluggable middlewares (#1106)

Erfan Al-Hossami 2024-01-23 18:13:00 -05:00 committed by GitHub
parent 7a0b011dd5
commit 9c1352eb57
2 changed files with 48 additions and 29 deletions


@@ -63,38 +63,10 @@ Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM
The code example can also be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_.
API Server
----------
vLLM can be deployed as an LLM service. We provide an example `FastAPI <https://fastapi.tiangolo.com/>`_ server. Check `vllm/entrypoints/api_server.py <https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py>`_ for the server implementation. The server uses ``AsyncLLMEngine`` class to support asynchronous processing of incoming requests.
Start the server:
.. code-block:: console

    $ python -m vllm.entrypoints.api_server
By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model.
Query the model in shell:
.. code-block:: console

    $ curl http://localhost:8000/generate \
    $     -d '{
    $         "prompt": "San Francisco is a",
    $         "use_beam_search": true,
    $         "n": 4,
    $         "temperature": 0
    $     }'
See `examples/api_client.py <https://github.com/vllm-project/vllm/blob/main/examples/api_client.py>`_ for a more detailed client example.
OpenAI-Compatible Server
------------------------
vLLM can be deployed as a server that implements the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using the OpenAI API.
By default, it starts the server at ``http://localhost:8000``. You can specify the address with the ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the command below) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
Start the server:
@@ -118,6 +90,8 @@ This server can be queried in the same format as OpenAI API. For example, list t
$ curl http://localhost:8000/v1/models
You can pass the ``--api-key`` argument or set the ``VLLM_API_KEY`` environment variable to make the server check for the API key in the request header.
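For illustration, a minimal client sketch assuming the server was started with ``--api-key s3cret`` (the key and prompt below are placeholders, not from the original docs); requests to ``/v1`` routes without a matching ``Authorization: Bearer`` header receive a ``401`` response:

.. code-block:: python

    import requests

    response = requests.post(
        "http://localhost:8000/v1/completions",
        # Must match the key passed via --api-key or the VLLM_API_KEY environment variable.
        headers={"Authorization": "Bearer s3cret"},
        json={
            "model": "facebook/opt-125m",
            "prompt": "San Francisco is a",
            "max_tokens": 16,
        },
    )
    print(response.json())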
Using OpenAI Completions API with vLLM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


@@ -2,6 +2,10 @@ import argparse
import asyncio
import json
from contextlib import asynccontextmanager
import os
import importlib
import inspect
from aioprometheus import MetricsMiddleware
from aioprometheus.asgi.starlette import metrics
import fastapi
@@ -64,6 +68,13 @@ def parse_args():
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help=
        "If provided, the server will require this key to be presented in the header."
    )
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
@@ -94,6 +105,17 @@ def parse_args():
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument(
        "--middleware",
        type=str,
        action="append",
        default=[],
        help="Additional ASGI middleware to apply to the app. "
        "We accept multiple --middleware arguments. "
        "The value should be an import path. "
        "If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
        "If a class is provided, vLLM will add it to the server using app.add_middleware(). "
    )
    parser = AsyncEngineArgs.add_cli_args(parser)
    return parser.parse_args()
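For illustration, a module that ``--middleware`` could point to might look like the sketch below; the module and function names are hypothetical and not part of this commit. Because the loader further down accepts coroutine functions, the function form takes the request and ``call_next``, matching FastAPI's ``@app.middleware("http")`` signature:

# my_middlewares.py  (hypothetical module; enable with
#   --middleware my_middlewares.add_request_id)
import uuid

from fastapi import Request


async def add_request_id(request: Request, call_next):
    # Function form: the server registers this via @app.middleware("http").
    response = await call_next(request)
    response.headers["X-Request-Id"] = str(uuid.uuid4())
    return response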
@@ -161,6 +183,29 @@ if __name__ == "__main__":
        allow_headers=args.allowed_headers,
    )

    if token := os.environ.get("VLLM_API_KEY") or args.api_key:

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith("/v1"):
                return await call_next(request)
            if request.headers.get("Authorization") != "Bearer " + token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )
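Since a class is handed to ``app.add_middleware()`` with no extra options, it should be constructible from the app alone. A plain ASGI-style sketch with hypothetical names, for illustration only:

# my_middlewares.py  (hypothetical; enable with
#   --middleware my_middlewares.RequestLogger)
class RequestLogger:

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        # Class form: wraps the ASGI app; log the path of each HTTP request, then delegate.
        if scope["type"] == "http":
            print("request path:", scope["path"])
        await self.app(scope, receive, send)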
    logger.info(f"args: {args}")
    if args.served_model_name is not None: