Support OpenAI API server in benchmark_serving.py (#2172)
parent dd7e8f5f64
commit 2709c0009a
.gitignore (vendored) | 3 +++
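In short: the benchmark previously hard-coded the standalone server's /generate route. This commit adds --endpoint and --model flags, threads the model name through benchmark() and send_request() into the request payload (the OpenAI API expects a "model" field), and swaps asyncio.gather for tqdm.gather to show progress. A plausible invocation against vLLM's OpenAI-compatible server (the model, tokenizer, and dataset names below are illustrative, not taken from this diff):

    python benchmark_serving.py \
        --backend vllm \
        --endpoint /v1/completions \
        --model facebook/opt-125m \
        --tokenizer facebook/opt-125m \
        --dataset ShareGPT_V3_unfiltered_cleaned_split.json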
.gitignore:

@@ -181,3 +181,6 @@ _build/
 # hip files generated by PyTorch
 *.hip
 *_hip*
+
+# Benchmark dataset
+*.json
benchmark_serving.py:

@@ -24,6 +24,7 @@ from typing import AsyncGenerator, List, Tuple
 
 import aiohttp
 import numpy as np
+from tqdm.asyncio import tqdm
 from transformers import PreTrainedTokenizerBase
 
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -40,15 +41,10 @@ def sample_requests(
     with open(dataset_path) as f:
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
-    dataset = [
-        data for data in dataset
-        if len(data["conversations"]) >= 2
-    ]
+    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     # Only keep the first two turns of each conversation.
-    dataset = [
-        (data["conversations"][0]["value"], data["conversations"][1]["value"])
-        for data in dataset
-    ]
+    dataset = [(data["conversations"][0]["value"],
+                data["conversations"][1]["value"]) for data in dataset]
 
     # Tokenize the prompts and completions.
     prompts = [prompt for prompt, _ in dataset]
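The comprehensions in this hunk assume ShareGPT-style records, where each conversation is a list of turns carrying a "value" field. A minimal sketch of the shape sample_requests expects (field contents are illustrative):

    # One ShareGPT-style record as consumed by sample_requests (illustrative).
    record = {
        "conversations": [
            {"from": "human", "value": "What is the capital of France?"},
            {"from": "gpt", "value": "The capital of France is Paris."},
        ]
    }
    # The script keeps (prompt, completion) = the first two turns.
    prompt, completion = (record["conversations"][0]["value"],
                          record["conversations"][1]["value"])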
@@ -98,6 +94,7 @@ async def get_request(
 
 async def send_request(
     backend: str,
+    model: str,
     api_url: str,
     prompt: str,
     prompt_len: int,
@@ -120,6 +117,8 @@ async def send_request(
             "ignore_eos": True,
             "stream": False,
         }
+        if model is not None:
+            pload["model"] = model
     elif backend == "tgi":
         assert not use_beam_search
         params = {
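With the two added lines, requests on the vllm backend carry the model name whenever --model is given, which the OpenAI-compatible server needs to route the request; the None guard keeps the old /generate behavior unchanged. A sketch of the payload builder, with the fields outside this hunk reconstructed from the surrounding script rather than copied from the diff:

    from typing import Optional


    def build_vllm_payload(prompt: str, output_len: int, best_of: int,
                           use_beam_search: bool,
                           model: Optional[str]) -> dict:
        # Fields above "ignore_eos" are reconstructed, not shown in this hunk.
        pload = {
            "prompt": prompt,
            "n": 1,
            "best_of": best_of,
            "use_beam_search": use_beam_search,
            "temperature": 0.0 if use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": output_len,
            "ignore_eos": True,
            "stream": False,
        }
        if model is not None:
            # Only set when targeting the OpenAI-compatible server.
            pload["model"] = model
        return pload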
@@ -137,7 +136,8 @@ async def send_request(
     timeout = aiohttp.ClientTimeout(total=3 * 3600)
     async with aiohttp.ClientSession(timeout=timeout) as session:
         while True:
-            async with session.post(api_url, headers=headers, json=pload) as response:
+            async with session.post(api_url, headers=headers,
+                                    json=pload) as response:
                 chunks = []
                 async for chunk, _ in response.content.iter_chunks():
                     chunks.append(chunk)
@@ -155,6 +155,7 @@ async def send_request(
 
 async def benchmark(
     backend: str,
+    model: str,
     api_url: str,
     input_requests: List[Tuple[str, int, int]],
     best_of: int,
@@ -164,11 +165,11 @@ async def benchmark(
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
         prompt, prompt_len, output_len = request
-        task = asyncio.create_task(send_request(backend, api_url, prompt,
-                                                prompt_len, output_len,
-                                                best_of, use_beam_search))
+        task = asyncio.create_task(
+            send_request(backend, model, api_url, prompt, prompt_len,
+                         output_len, best_of, use_beam_search))
         tasks.append(task)
-    await asyncio.gather(*tasks)
+    await tqdm.gather(*tasks)
 
 
 def main(args: argparse.Namespace):
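The switch from asyncio.gather to tqdm.gather (from tqdm.asyncio, imported above) keeps the same awaiting semantics and result ordering but renders a progress bar as tasks complete. A self-contained sketch:

    import asyncio

    from tqdm.asyncio import tqdm


    async def work(i: int) -> int:
        await asyncio.sleep(0.01)  # stand-in for send_request
        return i


    async def main() -> None:
        tasks = [asyncio.create_task(work(i)) for i in range(100)]
        # Equivalent to asyncio.gather(*tasks), plus a progress bar.
        results = await tqdm.gather(*tasks)
        assert results == list(range(100))


    asyncio.run(main())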
@@ -176,13 +177,15 @@ def main(args: argparse.Namespace):
     random.seed(args.seed)
     np.random.seed(args.seed)
 
-    api_url = f"http://{args.host}:{args.port}/generate"
-    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+    api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+    tokenizer = get_tokenizer(args.tokenizer,
+                              trust_remote_code=args.trust_remote_code)
     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 
     benchmark_start_time = time.perf_counter()
-    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
-                          args.use_beam_search, args.request_rate))
+    asyncio.run(
+        benchmark(args.backend, args.model, api_url, input_requests,
+                  args.best_of, args.use_beam_search, args.request_rate))
     benchmark_end_time = time.perf_counter()
     benchmark_time = benchmark_end_time - benchmark_start_time
     print(f"Total time: {benchmark_time:.2f} s")
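The URL is now assembled from --host, --port, and the new --endpoint flag instead of a hard-coded /generate path, so retargeting the benchmark at the OpenAI-compatible server is just a flag change, e.g.:

    # Defaults reproduce the old behavior: http://localhost:8000/generate
    host, port, endpoint = "localhost", 8000, "/v1/completions"
    api_url = f"http://{host}:{port}{endpoint}"
    print(api_url)  # http://localhost:8000/v1/completions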
@@ -196,10 +199,8 @@ def main(args: argparse.Namespace):
         for prompt_len, output_len, latency in REQUEST_LATENCY
     ])
     print(f"Average latency per token: {avg_per_token_latency:.2f} s")
-    avg_per_output_token_latency = np.mean([
-        latency / output_len
-        for _, output_len, latency in REQUEST_LATENCY
-    ])
+    avg_per_output_token_latency = np.mean(
+        [latency / output_len for _, output_len, latency in REQUEST_LATENCY])
     print("Average latency per output token: "
           f"{avg_per_output_token_latency:.2f} s")
 
@@ -207,27 +208,42 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Benchmark the online serving throughput.")
-    parser.add_argument("--backend", type=str, default="vllm",
+    parser.add_argument("--backend",
+                        type=str,
+                        default="vllm",
                         choices=["vllm", "tgi"])
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8000)
-    parser.add_argument("--dataset", type=str, required=True,
+    parser.add_argument("--endpoint", type=str, default="/generate")
+    parser.add_argument("--model", type=str, default=None)
+    parser.add_argument("--dataset",
+                        type=str,
+                        required=True,
                         help="Path to the dataset.")
-    parser.add_argument("--tokenizer", type=str, required=True,
+    parser.add_argument("--tokenizer",
+                        type=str,
+                        required=True,
                         help="Name or path of the tokenizer.")
-    parser.add_argument("--best-of", type=int, default=1,
+    parser.add_argument("--best-of",
+                        type=int,
+                        default=1,
                         help="Generates `best_of` sequences per prompt and "
-                             "returns the best one.")
+                        "returns the best one.")
     parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument("--num-prompts", type=int, default=1000,
+    parser.add_argument("--num-prompts",
+                        type=int,
+                        default=1000,
                         help="Number of prompts to process.")
-    parser.add_argument("--request-rate", type=float, default=float("inf"),
+    parser.add_argument("--request-rate",
+                        type=float,
+                        default=float("inf"),
                         help="Number of requests per second. If this is inf, "
-                             "then all the requests are sent at time 0. "
-                             "Otherwise, we use Poisson process to synthesize "
-                             "the request arrival times.")
+                        "then all the requests are sent at time 0. "
+                        "Otherwise, we use Poisson process to synthesize "
+                        "the request arrival times.")
     parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument('--trust-remote-code', action='store_true',
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
                         help='trust remote code from huggingface')
     args = parser.parse_args()
     main(args)
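On the --request-rate help text above: the script's get_request generator paces requests by sleeping for exponentially distributed intervals between sends, which yields a Poisson arrival process. A self-contained sketch of that pacing, reconstructed from the script's described behavior rather than copied from this diff:

    import asyncio
    from typing import AsyncGenerator, List, Tuple

    import numpy as np


    async def get_request(
        input_requests: List[Tuple[str, int, int]],
        request_rate: float,
    ) -> AsyncGenerator[Tuple[str, int, int], None]:
        for request in input_requests:
            yield request
            if request_rate == float("inf"):
                continue  # inf rate: issue all requests at time 0
            # Exponential inter-arrival gaps give Poisson arrivals.
            interval = np.random.exponential(1.0 / request_rate)
            await asyncio.sleep(interval)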