[PD] add test for chat completions endpoint (#21925)

Signed-off-by: Abirdcfly <fp544037857@gmail.com>
This commit is contained in:
Abirdcfly 2025-08-04 10:57:03 +08:00 committed by GitHub
parent 845420ac2c
commit 0d7db16a92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 14 deletions

View File

@@ -51,20 +51,31 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
return False return False
def run_simple_prompt(base_url: str, model_name: str, def run_simple_prompt(base_url: str, model_name: str, input_prompt: str,
input_prompt: str) -> str: use_chat_endpoint: bool) -> str:
client = openai.OpenAI(api_key="EMPTY", base_url=base_url) client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
completion = client.completions.create(model=model_name, if use_chat_endpoint:
prompt=input_prompt, completion = client.chat.completions.create(
max_tokens=MAX_OUTPUT_LEN, model=model_name,
temperature=0.0, messages=[{
seed=42) "role": "user",
"content": [{
"type": "text",
"text": input_prompt
}]
}],
max_completion_tokens=MAX_OUTPUT_LEN,
temperature=0.0,
seed=42)
return completion.choices[0].message.content
else:
completion = client.completions.create(model=model_name,
prompt=input_prompt,
max_tokens=MAX_OUTPUT_LEN,
temperature=0.0,
seed=42)
# print("-" * 50) return completion.choices[0].text
# print(f"Completion results for {model_name}:")
# print(completion)
# print("-" * 50)
return completion.choices[0].text
def main(): def main():
@@ -125,10 +136,12 @@ def main():
f"vllm server: {args.service_url} is not ready yet!") f"vllm server: {args.service_url} is not ready yet!")
output_strs = dict() output_strs = dict()
for prompt in SAMPLE_PROMPTS: for i, prompt in enumerate(SAMPLE_PROMPTS):
use_chat_endpoint = (i % 2 == 1)
output_str = run_simple_prompt(base_url=service_url, output_str = run_simple_prompt(base_url=service_url,
model_name=args.model_name, model_name=args.model_name,
input_prompt=prompt) input_prompt=prompt,
use_chat_endpoint=use_chat_endpoint)
print(f"Prompt: {prompt}, output: {output_str}") print(f"Prompt: {prompt}, output: {output_str}")
output_strs[prompt] = output_str output_strs[prompt] = output_str

View File

@@ -162,6 +162,8 @@ async def send_request_to_service(client_info: dict, endpoint: str,
} }
req_data["stream"] = False req_data["stream"] = False
req_data["max_tokens"] = 1 req_data["max_tokens"] = 1
if "max_completion_tokens" in req_data:
req_data["max_completion_tokens"] = 1
if "stream_options" in req_data: if "stream_options" in req_data:
del req_data["stream_options"] del req_data["stream_options"]
headers = { headers = {