[Doc] Add docs about OpenAI compatible server (#3288)

This commit is contained in:
Simon Mo 2024-03-18 22:05:34 -07:00 committed by GitHub
parent 6a9c583e73
commit ef65dcfa6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 383 additions and 161 deletions

View File

@ -1,3 +1,10 @@
sphinx == 6.2.1
sphinx-book-theme == 1.0.1
sphinx-copybutton == 0.5.2
myst-parser == 2.0.0
sphinx-argparse
# packages to install to build the documentation
pydantic
-f https://download.pytorch.org/whl/cpu
torch

View File

@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
# -- Project information -----------------------------------------------------
project = 'vLLM'
copyright = '2023, vLLM Team'
copyright = '2024, vLLM Team'
author = 'the vLLM Team'
# -- General configuration ---------------------------------------------------
@ -37,6 +37,8 @@ extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"myst_parser",
"sphinxarg.ext",
]
# Add any paths that contain templates here, relative to this directory.

View File

@ -0,0 +1,4 @@
Sampling Params
===============
.. automodule:: vllm.sampling_params.SamplingParams

View File

@ -69,14 +69,11 @@ Documentation
:maxdepth: 1
:caption: Serving
serving/distributed_serving
serving/run_on_sky
serving/deploying_with_kserve
serving/deploying_with_triton
serving/deploying_with_bentoml
serving/openai_compatible_server
serving/deploying_with_docker
serving/serving_with_langchain
serving/distributed_serving
serving/metrics
serving/integrations
.. toctree::
:maxdepth: 1
@ -98,6 +95,7 @@ Documentation
:maxdepth: 2
:caption: Developer Documentation
dev/sampling_params
dev/engine/engine_index
dev/kernel/paged_attention

View File

@ -90,7 +90,7 @@ Requests can specify the LoRA adapter as if it were any other model via the ``mo
processed according to the server-wide LoRA configuration (i.e. in parallel with base model requests, and potentially other
LoRA adapter requests if they were provided and ``max_loras`` is set high enough).
The following is an example request
The following is an example request
.. code-block:: bash

View File

@ -0,0 +1,11 @@
Integrations
------------
.. toctree::
:maxdepth: 1
run_on_sky
deploying_with_kserve
deploying_with_triton
deploying_with_bentoml
serving_with_langchain

View File

@ -0,0 +1,114 @@
# OpenAI Compatible Server
vLLM provides an HTTP server that implements OpenAI's [Completions](https://platform.openai.com/docs/api-reference/completions) and [Chat](https://platform.openai.com/docs/api-reference/chat) API.
You can start the server using Python, or using [Docker](deploying_with_docker.rst):
```bash
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123
```
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="token-abc123",
)
completion = client.chat.completions.create(
model="meta-llama/Llama-2-7b-hf",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
]
)
print(completion.choices[0].message)
```
## API Reference
Please see the [OpenAI API Reference](https://platform.openai.com/docs/api-reference) for more information on the API. We support all parameters except:
- Chat: `tools`, and `tool_choice`.
- Completions: `suffix`.
## Extra Parameters
vLLM supports a set of parameters that are not part of the OpenAI API.
In order to use them, you can pass them as extra parameters in the OpenAI client.
Or directly merge them into the JSON payload if you are using HTTP call directly.
```python
completion = client.chat.completions.create(
model="meta-llama/Llama-2-7b-hf",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
],
extra_body={
"guided_choice": ["positive", "negative"]
}
)
```
### Extra Parameters for Chat API
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-chat-completion-sampling-params
:end-before: end-chat-completion-sampling-params
```
The following extra parameters are supported:
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-chat-completion-extra-params
:end-before: end-chat-completion-extra-params
```
### Extra Parameters for Completions API
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-completion-sampling-params
:end-before: end-completion-sampling-params
```
The following extra parameters are supported:
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-completion-extra-params
:end-before: end-completion-extra-params
```
## Chat Template
In order for the language model to support chat protocol, vLLM requires the model to include
a chat template in its tokenizer configuration. The chat template is a Jinja2 template that
specifies how are roles, messages, and other chat-specific tokens are encoded in the input.
An example chat template for `meta-llama/Llama-2-7b-chat-hf` can be found [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/09bd0f49e16738cdfaa6e615203e126038736eb0/tokenizer_config.json#L12)
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model,
you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat
template, or the template in string form. Without a chat template, the server will not be able to process chat
and all chat requests will error.
```bash
python -m vllm.entrypoints.openai.api_server \
--model ... \
--chat-template ./path-to-chat-template.jinja
```
vLLM community provides a set of chat templates for popular models. You can find them in the examples
directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
## Command line arguments for the server
```{argparse}
:module: vllm.entrypoints.openai.cli_args
:func: make_arg_parser
:prog: vllm-openai-server
```

View File

@ -1,11 +1,8 @@
import argparse
import asyncio
import json
from contextlib import asynccontextmanager
import os
import importlib
import inspect
import ssl
from prometheus_client import make_asgi_app
import fastapi
@ -23,9 +20,9 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
ChatCompletionRequest,
ErrorResponse)
from vllm.logger import init_logger
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import LoRA
TIMEOUT_KEEP_ALIVE = 5 # seconds
@ -51,109 +48,8 @@ async def lifespan(app: fastapi.FastAPI):
app = fastapi.FastAPI(lifespan=lifespan)
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRA(name, path))
setattr(namespace, self.dest, lora_list)
def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--uvicorn-log-level",
type=str,
default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="log level for uvicorn")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--api-key",
type=str,
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument(
"--lora-modules",
type=str,
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.")
parser.add_argument("--chat-template",
type=str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
type=str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=str,
default=None,
help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile",
type=str,
default=None,
help="The file path to the SSL cert file")
parser.add_argument("--ssl-ca-certs",
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
type=str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser = AsyncEngineArgs.add_cli_args(parser)
parser = make_arg_parser()
return parser.parse_args()

View File

@ -0,0 +1,118 @@
"""
This file contains the command line arguments for the vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""
import argparse
import json
import ssl
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.serving_engine import LoRA
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRA(name, path))
setattr(namespace, self.dest, lora_list)
def make_arg_parser():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--uvicorn-log-level",
type=str,
default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="log level for uvicorn")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--api-key",
type=str,
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument(
"--lora-modules",
type=str,
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.")
parser.add_argument("--chat-template",
type=str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
type=str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=str,
default=None,
help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile",
type=str,
default=None,
help="The file path to the SSL cert file")
parser.add_argument("--ssl-ca-certs",
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
type=str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser = AsyncEngineArgs.add_cli_args(parser)
return parser

View File

@ -61,41 +61,80 @@ class ResponseFormat(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[Dict[str, str]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
model: str
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
user: Optional[str] = None
# Additional parameters supported by vLLM
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
ignore_eos: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
add_generation_prompt: Optional[bool] = True
echo: Optional[bool] = False
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
guided_grammar: Optional[str] = None
response_format: Optional[ResponseFormat] = None
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
echo: Optional[bool] = Field(
default=False,
description=(
"If true, the new message will be prepended with the last message "
"if they belong to the same role."),
)
add_generation_prompt: Optional[bool] = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
include_stop_str_in_output: Optional[bool] = Field(
default=False,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
# doc: end-chat-completion-extra-params
def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
@ -157,41 +196,74 @@ class ChatCompletionRequest(BaseModel):
class CompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
# a string, array of strings, array of tokens, or array of token arrays
prompt: Union[List[int], List[List[int]], str, List[str]]
suffix: Optional[str] = None
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = 16
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
suffix: Optional[str] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stream: Optional[bool] = False
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
seed: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by vLLM
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
# doc: begin-completion-sampling-params
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
guided_grammar: Optional[str] = None
response_format: Optional[ResponseFormat] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
include_stop_str_in_output: Optional[bool] = Field(
default=False,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
response_format: Optional[ResponseFormat] = Field(
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
"supported."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
# doc: end-completion-extra-params
def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0