diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index ca8d218581e77..6d52b51a9fcd0 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -948,7 +948,10 @@ def add_cli_args(parser: argparse.ArgumentParser): ) -def main(args: argparse.Namespace): +def main(args: argparse.Namespace) -> dict[str, Any]: + return asyncio.run(main_async(args)) + +async def main_async(args: argparse.Namespace) -> dict[str, Any]: print(args) random.seed(args.seed) np.random.seed(args.seed) @@ -1025,8 +1028,7 @@ def main(args: argparse.Namespace): gc.collect() gc.freeze() - benchmark_result = asyncio.run( - benchmark( + benchmark_result = await benchmark( endpoint_type=args.endpoint_type, api_url=api_url, base_url=base_url, @@ -1052,62 +1054,62 @@ def main(args: argparse.Namespace): ramp_up_start_rps=args.ramp_up_start_rps, ramp_up_end_rps=args.ramp_up_end_rps, ready_check_timeout_sec=args.ready_check_timeout_sec, - )) + ) # Save config and results to json - if args.save_result or args.append_result: - result_json: dict[str, Any] = {} + result_json: dict[str, Any] = {} - # Setup - current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") - result_json["date"] = current_dt - result_json["endpoint_type"] = args.endpoint_type - result_json["label"] = label - result_json["model_id"] = model_id - result_json["tokenizer_id"] = tokenizer_id - result_json["num_prompts"] = args.num_prompts + # Setup + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + result_json["date"] = current_dt + result_json["endpoint_type"] = args.endpoint_type + result_json["label"] = label + result_json["model_id"] = model_id + result_json["tokenizer_id"] = tokenizer_id + result_json["num_prompts"] = args.num_prompts - # Metadata - if args.metadata: - for item in args.metadata: - if "=" in item: - kvstring = item.split("=") - result_json[kvstring[0].strip()] = kvstring[1].strip() - else: - raise ValueError( - "Invalid metadata format. Please use KEY=VALUE format." - ) + # Metadata + if args.metadata: + for item in args.metadata: + if "=" in item: + kvstring = item.split("=") + result_json[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid metadata format. Please use KEY=VALUE format." + ) - # Traffic - result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") - result_json["burstiness"] = args.burstiness - result_json["max_concurrency"] = args.max_concurrency + # Traffic + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency - if args.ramp_up_strategy is not None: - result_json["ramp_up_strategy"] = args.ramp_up_strategy - result_json["ramp_up_start_rps"] = args.ramp_up_start_rps - result_json["ramp_up_end_rps"] = args.ramp_up_end_rps + if args.ramp_up_strategy is not None: + result_json["ramp_up_strategy"] = args.ramp_up_strategy + result_json["ramp_up_start_rps"] = args.ramp_up_start_rps + result_json["ramp_up_end_rps"] = args.ramp_up_end_rps - # Merge with benchmark result - result_json = {**result_json, **benchmark_result} + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} - if not args.save_detailed: - # Remove fields with too many data points - for field in [ - "input_lens", - "output_lens", - "ttfts", - "itls", - "generated_texts", - "errors", - ]: - if field in result_json: - del result_json[field] - if field in benchmark_result: - del benchmark_result[field] + if not args.save_detailed: + # Remove fields with too many data points + for field in [ + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", + ]: + if field in result_json: + del result_json[field] + if field in benchmark_result: + del benchmark_result[field] # Save to file + if args.save_result or args.append_result: base_model_id = model_id.split("/")[-1] max_concurrency_str = (f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else "") @@ -1129,3 +1131,5 @@ def main(args: argparse.Namespace): outfile.write("\n") json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name) + + return result_json \ No newline at end of file