diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst
index 5ce3c096cb44..7c44a96865a5 100644
--- a/docs/source/getting_started/quickstart.rst
+++ b/docs/source/getting_started/quickstart.rst
@@ -63,38 +63,10 @@ Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM
 The code example can also be found in `examples/offline_inference.py `_.
 
-
-API Server
-----------
-
-vLLM can be deployed as an LLM service. We provide an example `FastAPI `_ server. Check `vllm/entrypoints/api_server.py `_ for the server implementation. The server uses ``AsyncLLMEngine`` class to support asynchronous processing of incoming requests.
-
-Start the server:
-
-.. code-block:: console
-
-    $ python -m vllm.entrypoints.api_server
-
-By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model.
-
-Query the model in shell:
-
-.. code-block:: console
-
-    $ curl http://localhost:8000/generate \
-    $     -d '{
-    $         "prompt": "San Francisco is a",
-    $         "use_beam_search": true,
-    $         "n": 4,
-    $         "temperature": 0
-    $     }'
-
-See `examples/api_client.py `_ for a more detailed client example.
-
 OpenAI-Compatible Server
 ------------------------
 
-vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
+vLLM can be deployed as a server that implements the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
 By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the command below) and implements `list models `_, `create chat completion `_, and `create completion `_ endpoints. We are actively adding support for more endpoints.
 
 Start the server:
 
@@ -118,6 +90,8 @@ This server can be queried in the same format as OpenAI API. For example, list t
 
     $ curl http://localhost:8000/v1/models
 
+You can pass in the argument ``--api-key`` or environment variable ``VLLM_API_KEY`` to enable the server to check for API key in the header.
+
 Using OpenAI Completions API with vLLM
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index b10e83903737..deb0fddd643c 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -2,6 +2,10 @@ import argparse
 import asyncio
 import json
 from contextlib import asynccontextmanager
+import os
+import importlib
+import inspect
+
 from aioprometheus import MetricsMiddleware
 from aioprometheus.asgi.starlette import metrics
 import fastapi
@@ -64,6 +68,13 @@ def parse_args():
                         type=json.loads,
                         default=["*"],
                         help="allowed headers")
+    parser.add_argument(
+        "--api-key",
+        type=str,
+        default=None,
+        help=
+        "If provided, the server will require this key to be presented in the header."
+    )
     parser.add_argument("--served-model-name",
                         type=str,
                         default=None,
@@ -94,6 +105,17 @@ def parse_args():
         type=str,
         default=None,
         help="FastAPI root_path when app is behind a path based routing proxy")
+    parser.add_argument(
+        "--middleware",
+        type=str,
+        action="append",
+        default=[],
+        help="Additional ASGI middleware to apply to the app. "
+        "We accept multiple --middleware arguments. "
+        "The value should be an import path. "
+        "If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
+        "If a class is provided, vLLM will add it to the server using app.add_middleware(). "
+    )
     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser.parse_args()
 
@@ -161,6 +183,29 @@ if __name__ == "__main__":
         allow_headers=args.allowed_headers,
     )
 
+    if token := os.environ.get("VLLM_API_KEY") or args.api_key:
+
+        @app.middleware("http")
+        async def authentication(request: Request, call_next):
+            if not request.url.path.startswith("/v1"):
+                return await call_next(request)
+            if request.headers.get("Authorization") != "Bearer " + token:
+                return JSONResponse(content={"error": "Unauthorized"},
+                                    status_code=401)
+            return await call_next(request)
+
+    for middleware in args.middleware:
+        module_path, object_name = middleware.rsplit(".", 1)
+        imported = getattr(importlib.import_module(module_path), object_name)
+        if inspect.isclass(imported):
+            app.add_middleware(imported)
+        elif inspect.iscoroutinefunction(imported):
+            app.middleware("http")(imported)
+        else:
+            raise ValueError(
+                f"Invalid middleware {middleware}. Must be a function or a class."
+            )
+
     logger.info(f"args: {args}")
 
     if args.served_model_name is not None: