[Fix] Add chat completion example and simplify dependencies (#576)

Author: Zhuohan Li
Date: 2023-07-25 23:45:48 -07:00 (committed by GitHub)
Parent: df5dd3c68e
Commit: 82ad323dee
4 changed files with 52 additions and 11 deletions

View File

@@ -0,0 +1,33 @@
import openai

# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"

# List models API
models = openai.Model.list()
print("Models:", models)

model = models["data"][0]["id"]

# Chat completion API
chat_completion = openai.ChatCompletion.create(
    model=model,
    messages=[{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role": "user",
        "content": "Who won the world series in 2020?"
    }, {
        "role": "assistant",
        "content": "The Los Angeles Dodgers won the World Series in 2020."
    }, {
        "role": "user",
        "content": "Where was it played?"
    }])

print("Chat completion results:")
print(chat_completion)
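
For reference, a streaming variant of the same chat call. This is a minimal sketch, not part of the commit; it assumes the pre-1.0 openai client used above and that the server accepts stream=True for chat completions:

import openai

openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"

model = openai.Model.list()["data"][0]["id"]

# Print the response incrementally as chunks arrive (assumed streaming support).
for chunk in openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": "Who won the World Series in 2020?"}],
        stream=True):
    # Each chunk carries a partial "delta"; the content key may be absent.
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
print()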

View File

@@ -3,26 +3,26 @@ import openai
 # Modify OpenAI's API key and API base to use vLLM's API server.
 openai.api_key = "EMPTY"
 openai.api_base = "http://localhost:8000/v1"
-model = "facebook/opt-125m"
 
-# Test list models API
+# List models API
 models = openai.Model.list()
 print("Models:", models)
 
-# Test completion API
-stream = True
+model = models["data"][0]["id"]
+
+# Completion API
+stream = False
 completion = openai.Completion.create(
     model=model,
     prompt="A robot may not injure a human being",
     echo=False,
     n=2,
-    best_of=3,
     stream=stream,
     logprobs=3)
 
-# print the completion
+print("Completion results:")
 if stream:
     for c in completion:
         print(c)
 else:
-    print("Completion result:", completion)
+    print(completion)

View File

@@ -9,4 +9,3 @@ xformers >= 0.0.19
 fastapi
 uvicorn
 pydantic < 2  # Required for OpenAI server.
-fschat  # Required for OpenAI ChatCompletion Endpoint.

View File

@@ -13,9 +13,6 @@ from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-from fastchat.conversation import Conversation, SeparatorStyle
-from fastchat.model.model_adapter import get_conversation_template
 import uvicorn
 
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -33,6 +30,13 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid
 
+try:
+    from fastchat.conversation import Conversation, SeparatorStyle
+    from fastchat.model.model_adapter import get_conversation_template
+    _fastchat_available = True
+except ImportError:
+    _fastchat_available = False
+
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
 logger = init_logger(__name__)
@@ -63,6 +67,11 @@ async def check_model(request) -> Optional[JSONResponse]:
 
 async def get_gen_prompt(request) -> str:
+    if not _fastchat_available:
+        raise ModuleNotFoundError(
+            "fastchat is not installed. Please install fastchat to use "
+            "the chat completion and conversation APIs: `$ pip install fschat`"
+        )
     conv = get_conversation_template(request.model)
     conv = Conversation(
         name=conv.name,
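
For completeness, the chat endpoint backed by get_gen_prompt can also be exercised without the openai client. This is a minimal sketch, not part of the commit; it assumes the server is running on localhost:8000, exposes the standard OpenAI path /v1/chat/completions, and has fschat installed so the guard above does not fire:

import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "facebook/opt-125m",  # replace with a model id from /v1/models
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Where was the 2020 World Series played?"},
        ],
    },
)
# Without fschat installed, the server errors inside get_gen_prompt instead of answering.
print(response.status_code)
print(response.json())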