[Fix] Add chat completion Example and simplify dependencies (#576)

Zhuohan Li 2023-07-25 23:45:48 -07:00 committed by GitHub
parent df5dd3c68e
commit 82ad323dee
4 changed files with 52 additions and 11 deletions
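
The new client example assumes a vLLM OpenAI-compatible server is already running locally. Around the time of this commit such a server was typically launched with something like the following (invocation hedged; the model name is illustrative and is the same one the second example stops hardcoding):

python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m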

examples/openai_chatcompletion_client.py

@@ -0,0 +1,33 @@
+import openai
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai.api_key = "EMPTY"
+openai.api_base = "http://localhost:8000/v1"
+
+# List models API
+models = openai.Model.list()
+print("Models:", models)
+
+model = models["data"][0]["id"]
+
+# Chat completion API
+chat_completion = openai.ChatCompletion.create(
+    model=model,
+    messages=[{
+        "role": "system",
+        "content": "You are a helpful assistant."
+    }, {
+        "role": "user",
+        "content": "Who won the world series in 2020?"
+    }, {
+        "role":
+        "assistant",
+        "content":
+        "The Los Angeles Dodgers won the World Series in 2020."
+    }, {
+        "role": "user",
+        "content": "Where was it played?"
+    }])
+
+print("Chat completion results:")
+print(chat_completion)
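
A side note on the example above (not part of the commit): the same pre-1.0 `openai` client can also stream the chat completion instead of blocking on the full response. The chunk/delta field names below follow that client's streaming interface and are an assumption here, not something this diff exercises; the sketch assumes it runs in the same script, after the configuration above.

# Sketch only, not in the commit: streaming variant of the call above,
# assuming the pre-1.0 `openai` client already configured for vLLM.
chat_stream = openai.ChatCompletion.create(
    model=model,
    messages=[{"role": "user", "content": "Where was it played?"}],
    stream=True)
for chunk in chat_stream:
    # Each streamed chunk carries an incremental "delta"; "content" may be
    # absent in the first (role-only) and final chunks, hence the default.
    print(chunk["choices"][0]["delta"].get("content", ""), end="")
print()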

examples/openai_completion_client.py

@@ -3,26 +3,26 @@ import openai
 # Modify OpenAI's API key and API base to use vLLM's API server.
 openai.api_key = "EMPTY"
 openai.api_base = "http://localhost:8000/v1"
-model = "facebook/opt-125m"

-# Test list models API
+# List models API
 models = openai.Model.list()
 print("Models:", models)

-# Test completion API
-stream = True
+model = models["data"][0]["id"]
+# Completion API
+stream = False
 completion = openai.Completion.create(
     model=model,
     prompt="A robot may not injure a human being",
     echo=False,
     n=2,
     best_of=3,
     stream=stream,
     logprobs=3)

-# print the completion
+print("Completion results:")
 if stream:
     for c in completion:
         print(c)
 else:
-    print("Completion result:", completion)
+    print(completion)
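
Two behavior changes are worth noting in this example: the hardcoded model = "facebook/opt-125m" is replaced by querying the models API (models["data"][0]["id"]), so the client follows whatever model the server is actually serving, and the default switches to the non-streaming path (stream = False) while the if stream: branch keeps streaming one flag away.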

requirements.txt

@@ -9,4 +9,3 @@ xformers >= 0.0.19
 fastapi
 uvicorn
 pydantic < 2  # Required for OpenAI server.
-fschat  # Required for OpenAI ChatCompletion Endpoint.
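
With this line removed, fschat stops being a hard install requirement; it becomes an optional dependency (`pip install fschat`) needed only by users of the chat completion endpoint, which the guarded import in the next file enforces.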

vllm/entrypoints/openai/api_server.py

@@ -13,9 +13,6 @@ from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-from fastchat.conversation import Conversation, SeparatorStyle
-from fastchat.model.model_adapter import get_conversation_template
 import uvicorn

 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -33,6 +30,13 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid

+try:
+    from fastchat.conversation import Conversation, SeparatorStyle
+    from fastchat.model.model_adapter import get_conversation_template
+    _fastchat_available = True
+except ImportError:
+    _fastchat_available = False
+
 TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
@@ -63,6 +67,11 @@ async def check_model(request) -> Optional[JSONResponse]:
 async def get_gen_prompt(request) -> str:
+    if not _fastchat_available:
+        raise ModuleNotFoundError(
+            "fastchat is not installed. Please install fastchat to use "
+            "the chat completion and conversation APIs: `$ pip install fschat`"
+        )
     conv = get_conversation_template(request.model)
     conv = Conversation(
         name=conv.name,
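
The shape of this change: probe the fastchat import once at module load, record a flag, and raise only when the chat-specific code path actually runs, so every other endpoint keeps working without fschat installed. A minimal self-contained sketch of the same pattern follows; `optional_pkg` and `render_prompt` are hypothetical names, not vLLM APIs.

# Hedged sketch of the optional-dependency pattern above; `optional_pkg`
# and `render_prompt` are hypothetical names, not from vLLM.
try:
    import optional_pkg
    _optional_pkg_available = True
except ImportError:
    _optional_pkg_available = False


def render_prompt(messages):
    # Fail loudly, but only on the code path that needs the package.
    if not _optional_pkg_available:
        raise ModuleNotFoundError(
            "optional_pkg is not installed: `$ pip install optional_pkg`")
    return optional_pkg.render(messages)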