diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index a887e7150dc78..daaf5d46bf142 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -275,7 +275,7 @@ async def benchmark( model_id: str, model_name: str, tokenizer: PreTrainedTokenizerBase, - input_requests: list[SampleRequest], + requests: list[SampleRequest], logprobs: Optional[int], request_rate: float, burstiness: float, @@ -295,12 +295,14 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") + last_idx = len(requests) - 1 test_prompt, test_prompt_len, test_output_len, test_mm_content = ( - input_requests[0].prompt, - input_requests[0].prompt_len, - input_requests[0].expected_output_len, - input_requests[0].multi_modal_data, + requests[last_idx].prompt, + requests[last_idx].prompt_len, + requests[last_idx].expected_output_len, + requests[last_idx].multi_modal_data, ) + input_requests = requests[:last_idx] assert test_mm_content is None or isinstance(test_mm_content, dict) test_input = RequestFuncInput( @@ -615,6 +617,9 @@ def main(args: argparse.Namespace): api_url = f"http://{args.host}:{args.port}{args.endpoint}" base_url = f"http://{args.host}:{args.port}" + # Create one more request (for a test request) + total_prompts = args.num_prompts + 1 + tokenizer = get_tokenizer( tokenizer_id, tokenizer_mode=tokenizer_mode, @@ -632,7 +637,7 @@ def main(args: argparse.Namespace): # For the "sonnet" dataset, formatting depends on the backend. if args.backend == "openai-chat": input_requests = dataset.sample( - num_requests=args.num_prompts, + num_requests=total_prompts, input_len=args.sonnet_input_len, output_len=args.sonnet_output_len, prefix_len=args.sonnet_prefix_len, @@ -644,7 +649,7 @@ def main(args: argparse.Namespace): "Tokenizer/model must have chat template for sonnet dataset." ) input_requests = dataset.sample( - num_requests=args.num_prompts, + num_requests=total_prompts, input_len=args.sonnet_input_len, output_len=args.sonnet_output_len, prefix_len=args.sonnet_prefix_len, @@ -707,7 +712,7 @@ def main(args: argparse.Namespace): dataset_split=args.hf_split, random_seed=args.seed, ).sample( - num_requests=args.num_prompts, + num_requests=total_prompts, tokenizer=tokenizer, output_len=args.hf_output_len, ) @@ -719,15 +724,15 @@ def main(args: argparse.Namespace): random_seed=args.seed, dataset_path=args.dataset_path ).sample( tokenizer=tokenizer, - num_requests=args.num_prompts, + num_requests=total_prompts, output_len=args.sharegpt_output_len, ), "burstgpt": lambda: BurstGPTDataset( random_seed=args.seed, dataset_path=args.dataset_path - ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), + ).sample(tokenizer=tokenizer, num_requests=total_prompts), "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( tokenizer=tokenizer, - num_requests=args.num_prompts, + num_requests=total_prompts, prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, @@ -774,7 +779,7 @@ def main(args: argparse.Namespace): model_id=model_id, model_name=model_name, tokenizer=tokenizer, - input_requests=input_requests, + requests=input_requests, logprobs=args.logprobs, request_rate=args.request_rate, burstiness=args.burstiness,