[PD] add test for chat completions endpoint (#21925)
Signed-off-by: Abirdcfly <fp544037857@gmail.com>
parent 845420ac2c
commit 0d7db16a92
@@ -51,20 +51,31 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
     return False


-def run_simple_prompt(base_url: str, model_name: str,
-                      input_prompt: str) -> str:
+def run_simple_prompt(base_url: str, model_name: str, input_prompt: str,
+                      use_chat_endpoint: bool) -> str:
     client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
-    completion = client.completions.create(model=model_name,
-                                           prompt=input_prompt,
-                                           max_tokens=MAX_OUTPUT_LEN,
-                                           temperature=0.0,
-                                           seed=42)
+    if use_chat_endpoint:
+        completion = client.chat.completions.create(
+            model=model_name,
+            messages=[{
+                "role": "user",
+                "content": [{
+                    "type": "text",
+                    "text": input_prompt
+                }]
+            }],
+            max_completion_tokens=MAX_OUTPUT_LEN,
+            temperature=0.0,
+            seed=42)
+        return completion.choices[0].message.content
+    else:
+        completion = client.completions.create(model=model_name,
+                                               prompt=input_prompt,
+                                               max_tokens=MAX_OUTPUT_LEN,
+                                               temperature=0.0,
+                                               seed=42)

     # print("-" * 50)
     # print(f"Completion results for {model_name}:")
     # print(completion)
     # print("-" * 50)
-    return completion.choices[0].text
+        return completion.choices[0].text


 def main():
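For reference, a minimal usage sketch of the updated helper exercising both branches; the base URL, model name, and MAX_OUTPUT_LEN value here are illustrative assumptions, not part of this commit:

# Sketch only: call run_simple_prompt against an assumed local server.
# URL, model name, and token budget below are placeholders.
MAX_OUTPUT_LEN = 128  # assumed; the test module defines its own constant

completion_text = run_simple_prompt(base_url="http://localhost:8000/v1",
                                    model_name="Qwen/Qwen3-0.6B",
                                    input_prompt="What is the capital of France?",
                                    use_chat_endpoint=False)  # /v1/completions
chat_text = run_simple_prompt(base_url="http://localhost:8000/v1",
                              model_name="Qwen/Qwen3-0.6B",
                              input_prompt="What is the capital of France?",
                              use_chat_endpoint=True)  # /v1/chat/completions
assert completion_text and chat_text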
@@ -125,10 +136,12 @@ def main():
             f"vllm server: {args.service_url} is not ready yet!")

     output_strs = dict()
-    for prompt in SAMPLE_PROMPTS:
+    for i, prompt in enumerate(SAMPLE_PROMPTS):
+        use_chat_endpoint = (i % 2 == 1)
         output_str = run_simple_prompt(base_url=service_url,
                                        model_name=args.model_name,
-                                       input_prompt=prompt)
+                                       input_prompt=prompt,
+                                       use_chat_endpoint=use_chat_endpoint)
         print(f"Prompt: {prompt}, output: {output_str}")
         output_strs[prompt] = output_str
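Alternating on the prompt index routes even-indexed prompts through /v1/completions and odd-indexed ones through /v1/chat/completions, so a single run of the accuracy test exercises both endpoints without issuing any extra requests.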
@@ -162,6 +162,8 @@ async def send_request_to_service(client_info: dict, endpoint: str,
     }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
+    if "max_completion_tokens" in req_data:
+        req_data["max_completion_tokens"] = 1
     if "stream_options" in req_data:
         del req_data["stream_options"]
     headers = {
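The proxy forwards a trimmed copy of each request to the prefill worker so it only populates the KV cache; chat completion requests carry max_completion_tokens rather than max_tokens, hence the extra clamp added here. A minimal sketch of the same idea, with the example payloads below as assumptions (the commit itself mutates req_data inline):

# Sketch: clamp a request payload so the prefill stage decodes at most one
# token. Example payloads are assumed; the real code edits req_data in place.
def clamp_for_prefill(req_data: dict) -> dict:
    req_data["stream"] = False
    req_data["max_tokens"] = 1  # budget read by /v1/completions
    if "max_completion_tokens" in req_data:
        # /v1/chat/completions reads max_completion_tokens instead, so it
        # must be clamped too or the prefill worker would decode fully.
        req_data["max_completion_tokens"] = 1
    if "stream_options" in req_data:
        del req_data["stream_options"]  # meaningless once stream=False
    return req_data

clamp_for_prefill({"prompt": "hello", "max_tokens": 64})
clamp_for_prefill({"messages": [], "max_completion_tokens": 64})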