Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 03:15:20 +08:00)
Fix latency benchmark script (#118)
This commit is contained in:
parent 19d2899439
commit 3f942acfe1
@@ -1,71 +1,75 @@
 import argparse
 import time
-from typing import List

-from tqdm import tqdm
 import numpy as np
 import torch
+from tqdm import tqdm

-from cacheflow.core.server import (
-    add_server_arguments, process_server_arguments,
-    init_local_server_and_frontend_with_arguments)
-from cacheflow.sampling_params import SamplingParams
+from cacheflow import LLM, SamplingParams


 def main(args: argparse.Namespace):
-    server, frontend = init_local_server_and_frontend_with_arguments(args)
+    print(args)
+
+    # Process all the requests in a single batch if possible.
+    # NOTE(woosuk): If the request cannot be processed in a single batch,
+    # the server will automatically process the request in multiple batches.
+    llm = LLM(
+        model=args.model,
+        tensor_parallel_size=args.tensor_parallel_size,
+        max_num_seqs=args.batch_size,
+        max_num_batched_tokens=args.batch_size * args.input_len,
+    )

     sampling_params = SamplingParams(
         n=args.n,
         temperature=0.0 if args.use_beam_search else 1.0,
         top_p=1.0,
         use_beam_search=args.use_beam_search,
-        stop_token_ids=set(),
+        ignore_eos=True,
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    input_token_ids = [0] * args.input_len
+    dummy_prompts = [""] * args.batch_size
+    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

-    def profile_step(profile=False):
+    def run_to_completion(profile: bool = False):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
-        for _ in range(args.batch_size):
-            dummy_prompt = ""
-            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
-        server.add_sequence_groups(frontend.get_inputs())
         start_time = time.time()
-        while True:
-            server.step()
-            if not server.has_unfinished_requests():
-                break
+
+        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+                     use_tqdm=False)
+
         end_time = time.time()
         latency = end_time - start_time
         if profile:
             torch.cuda.cudart().cudaProfilerStop()
         return latency

-    print("Warm up step")
-    profile_step()
+    print("Warming up...")
+    run_to_completion(profile=False)

     # Benchmark.
     latencies = []
-    for _ in tqdm(range(3), desc="Profile step"):
-        latencies.append(profile_step())
+    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+        latencies.append(run_to_completion(profile=False))
     print(f'Avg latency: {np.mean(latencies)} seconds')


 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Benchmark the latency of decoding a single sentence.')
-    parser = add_server_arguments(parser)
+        description='Benchmark the latency of processing a single batch of '
+                    'requests till completion.')
+    parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
-    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--n', type=int, default=1,
+                        help='Number of generated sequences per prompt.')
     parser.add_argument('--use-beam-search', action='store_true')
+    parser.add_argument('--num-iters', type=int, default=3,
+                        help='Number of iterations to run.')
     args = parser.parse_args()
-    args = process_server_arguments(args)
-    args.max_num_batched_tokens = max(
-        args.max_num_batched_tokens, args.batch_size * args.input_len)
-    print(args)
     main(args)
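The NOTE(woosuk) comment in the new script encodes the sizing rule the benchmark relies on: with max_num_batched_tokens set to batch_size * input_len, every dummy prompt can be scheduled in the same prefill step. A minimal sketch of that arithmetic, using a hypothetical helper that is not part of cacheflow:

def fits_in_one_batch(batch_size: int, input_len: int,
                      max_num_batched_tokens: int) -> bool:
    # Each dummy prompt contributes exactly input_len prompt tokens, so the
    # whole batch fits in one step iff their sum stays within the token budget.
    return batch_size * input_len <= max_num_batched_tokens

assert fits_in_one_batch(8, 32, 8 * 32)          # the script's defaults
assert not fits_in_one_batch(8, 32, 8 * 32 - 1)  # budget one token too small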
@@ -35,18 +35,26 @@ class LLM:
         self,
         prompts: List[str],
         sampling_params: Optional[SamplingParams] = None,
+        prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
         if sampling_params is None:
+            # Use default sampling params.
             sampling_params = SamplingParams()
         # Initialize tqdm.
         if use_tqdm:
             pbar = tqdm(total=len(prompts), desc="Processed prompts")

         # Add requests to the server.
-        for prompt in prompts:
+        for i in range(len(prompts)):
+            prompt = prompts[i]
+            if prompt_token_ids is None:
+                token_ids = None
+            else:
+                token_ids = prompt_token_ids[i]
             request_id = str(next(self.request_counter))
-            self.llm_server.add_request(request_id, prompt, sampling_params)
+            self.llm_server.add_request(request_id, prompt, sampling_params,
+                                        token_ids)

         # Run the server.
         outputs: List[RequestOutput] = []
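For context, a minimal sketch of how the extended generate() signature is meant to be called from user code, assuming the cacheflow package at this commit; the model name, lengths, and sampling values are illustrative only:

from cacheflow import LLM, SamplingParams

# Constructor and sampling arguments mirror those used in the benchmark diff;
# any argument not shown here is assumed to keep its default value.
llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(n=1, temperature=1.0, top_p=1.0,
                                 use_beam_search=False, ignore_eos=True,
                                 max_tokens=16)

# When prompt_token_ids is given, the i-th token-ID list is paired with the
# i-th prompt string, so the prompt text itself can be left empty.
prompts = [""] * 4
prompt_token_ids = [[0] * 32 for _ in range(4)]
outputs = llm.generate(prompts, sampling_params, prompt_token_ids,
                       use_tqdm=False)
for output in outputs:
    print(output)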