Fix latency benchmark script (#118)

This commit is contained in:
Woosuk Kwon 2023-05-22 17:03:40 -07:00 committed by GitHub
parent 19d2899439
commit 3f942acfe1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 31 deletions

View File

@ -1,71 +1,75 @@
import argparse import argparse
import time import time
from typing import List
from tqdm import tqdm
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm
from cacheflow.core.server import ( from cacheflow import LLM, SamplingParams
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
server, frontend = init_local_server_and_frontend_with_arguments(args) print(args)
# Process all the requests in a single batch if possible.
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the server will automatically process the request in multiple batches.
llm = LLM(
model=args.model,
tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len,
)
sampling_params = SamplingParams( sampling_params = SamplingParams(
n=args.n, n=args.n,
temperature=0.0 if args.use_beam_search else 1.0, temperature=0.0 if args.use_beam_search else 1.0,
top_p=1.0, top_p=1.0,
use_beam_search=args.use_beam_search, use_beam_search=args.use_beam_search,
stop_token_ids=set(), ignore_eos=True,
max_tokens=args.output_len, max_tokens=args.output_len,
) )
print(sampling_params) print(sampling_params)
input_token_ids = [0] * args.input_len dummy_prompts = [""] * args.batch_size
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
def profile_step(profile=False): def run_to_completion(profile: bool = False):
if profile: if profile:
torch.cuda.cudart().cudaProfilerStart() torch.cuda.cudart().cudaProfilerStart()
for _ in range(args.batch_size):
dummy_prompt = ""
frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
server.add_sequence_groups(frontend.get_inputs())
start_time = time.time() start_time = time.time()
while True:
server.step() llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
if not server.has_unfinished_requests(): use_tqdm=False)
break
end_time = time.time() end_time = time.time()
latency = end_time - start_time latency = end_time - start_time
if profile: if profile:
torch.cuda.cudart().cudaProfilerStop() torch.cuda.cudart().cudaProfilerStop()
return latency return latency
print("Warm up step") print("Warming up...")
profile_step() run_to_completion(profile=False)
# Benchmark. # Benchmark.
latencies = [] latencies = []
for _ in tqdm(range(3), desc="Profile step"): for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(profile_step()) latencies.append(run_to_completion(profile=False))
print(f'Avg latency: {np.mean(latencies)} seconds') print(f'Avg latency: {np.mean(latencies)} seconds')
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Benchmark the latency of decoding a single sentence.') description='Benchmark the latency of processing a single batch of '
parser = add_server_arguments(parser) 'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32) parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128) parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8) parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n', type=int, default=1) parser.add_argument('--n', type=int, default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true') parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters', type=int, default=3,
help='Number of iterations to run.')
args = parser.parse_args() args = parser.parse_args()
args = process_server_arguments(args)
args.max_num_batched_tokens = max(
args.max_num_batched_tokens, args.batch_size * args.input_len)
print(args)
main(args) main(args)

View File

@ -35,18 +35,26 @@ class LLM:
self, self,
prompts: List[str], prompts: List[str],
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
if sampling_params is None: if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams() sampling_params = SamplingParams()
# Initialize tqdm. # Initialize tqdm.
if use_tqdm: if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Processed prompts") pbar = tqdm(total=len(prompts), desc="Processed prompts")
# Add requests to the server. # Add requests to the server.
for prompt in prompts: for i in range(len(prompts)):
prompt = prompts[i]
if prompt_token_ids is None:
token_ids = None
else:
token_ids = prompt_token_ids[i]
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_server.add_request(request_id, prompt, sampling_params) self.llm_server.add_request(request_id, prompt, sampling_params,
token_ids)
# Run the server. # Run the server.
outputs: List[RequestOutput] = [] outputs: List[RequestOutput] = []