mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-26 02:44:27 +08:00
[Bench] Split serve.py:main into async/async versions (#22405)
Signed-off-by: Linkun <github@lkchen.net>
This commit is contained in:
parent
2a4c825523
commit
4d4297e8fe
@ -948,7 +948,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace) -> dict[str, Any]:
|
||||||
|
return asyncio.run(main_async(args))
|
||||||
|
|
||||||
|
async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||||
print(args)
|
print(args)
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@ -1025,8 +1028,7 @@ def main(args: argparse.Namespace):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
gc.freeze()
|
gc.freeze()
|
||||||
|
|
||||||
benchmark_result = asyncio.run(
|
benchmark_result = await benchmark(
|
||||||
benchmark(
|
|
||||||
endpoint_type=args.endpoint_type,
|
endpoint_type=args.endpoint_type,
|
||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
@ -1052,62 +1054,62 @@ def main(args: argparse.Namespace):
|
|||||||
ramp_up_start_rps=args.ramp_up_start_rps,
|
ramp_up_start_rps=args.ramp_up_start_rps,
|
||||||
ramp_up_end_rps=args.ramp_up_end_rps,
|
ramp_up_end_rps=args.ramp_up_end_rps,
|
||||||
ready_check_timeout_sec=args.ready_check_timeout_sec,
|
ready_check_timeout_sec=args.ready_check_timeout_sec,
|
||||||
))
|
)
|
||||||
|
|
||||||
# Save config and results to json
|
# Save config and results to json
|
||||||
if args.save_result or args.append_result:
|
result_json: dict[str, Any] = {}
|
||||||
result_json: dict[str, Any] = {}
|
|
||||||
|
|
||||||
# Setup
|
# Setup
|
||||||
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
result_json["date"] = current_dt
|
result_json["date"] = current_dt
|
||||||
result_json["endpoint_type"] = args.endpoint_type
|
result_json["endpoint_type"] = args.endpoint_type
|
||||||
result_json["label"] = label
|
result_json["label"] = label
|
||||||
result_json["model_id"] = model_id
|
result_json["model_id"] = model_id
|
||||||
result_json["tokenizer_id"] = tokenizer_id
|
result_json["tokenizer_id"] = tokenizer_id
|
||||||
result_json["num_prompts"] = args.num_prompts
|
result_json["num_prompts"] = args.num_prompts
|
||||||
|
|
||||||
# Metadata
|
# Metadata
|
||||||
if args.metadata:
|
if args.metadata:
|
||||||
for item in args.metadata:
|
for item in args.metadata:
|
||||||
if "=" in item:
|
if "=" in item:
|
||||||
kvstring = item.split("=")
|
kvstring = item.split("=")
|
||||||
result_json[kvstring[0].strip()] = kvstring[1].strip()
|
result_json[kvstring[0].strip()] = kvstring[1].strip()
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid metadata format. Please use KEY=VALUE format."
|
"Invalid metadata format. Please use KEY=VALUE format."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Traffic
|
# Traffic
|
||||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||||
< float("inf") else "inf")
|
< float("inf") else "inf")
|
||||||
result_json["burstiness"] = args.burstiness
|
result_json["burstiness"] = args.burstiness
|
||||||
result_json["max_concurrency"] = args.max_concurrency
|
result_json["max_concurrency"] = args.max_concurrency
|
||||||
|
|
||||||
if args.ramp_up_strategy is not None:
|
if args.ramp_up_strategy is not None:
|
||||||
result_json["ramp_up_strategy"] = args.ramp_up_strategy
|
result_json["ramp_up_strategy"] = args.ramp_up_strategy
|
||||||
result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
|
result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
|
||||||
result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
|
result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
|
||||||
|
|
||||||
# Merge with benchmark result
|
# Merge with benchmark result
|
||||||
result_json = {**result_json, **benchmark_result}
|
result_json = {**result_json, **benchmark_result}
|
||||||
|
|
||||||
if not args.save_detailed:
|
if not args.save_detailed:
|
||||||
# Remove fields with too many data points
|
# Remove fields with too many data points
|
||||||
for field in [
|
for field in [
|
||||||
"input_lens",
|
"input_lens",
|
||||||
"output_lens",
|
"output_lens",
|
||||||
"ttfts",
|
"ttfts",
|
||||||
"itls",
|
"itls",
|
||||||
"generated_texts",
|
"generated_texts",
|
||||||
"errors",
|
"errors",
|
||||||
]:
|
]:
|
||||||
if field in result_json:
|
if field in result_json:
|
||||||
del result_json[field]
|
del result_json[field]
|
||||||
if field in benchmark_result:
|
if field in benchmark_result:
|
||||||
del benchmark_result[field]
|
del benchmark_result[field]
|
||||||
|
|
||||||
# Save to file
|
# Save to file
|
||||||
|
if args.save_result or args.append_result:
|
||||||
base_model_id = model_id.split("/")[-1]
|
base_model_id = model_id.split("/")[-1]
|
||||||
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
||||||
if args.max_concurrency is not None else "")
|
if args.max_concurrency is not None else "")
|
||||||
@ -1129,3 +1131,5 @@ def main(args: argparse.Namespace):
|
|||||||
outfile.write("\n")
|
outfile.write("\n")
|
||||||
json.dump(result_json, outfile)
|
json.dump(result_json, outfile)
|
||||||
save_to_pytorch_benchmark_format(args, result_json, file_name)
|
save_to_pytorch_benchmark_format(args, result_json, file_name)
|
||||||
|
|
||||||
|
return result_json
|
||||||
Loading…
x
Reference in New Issue
Block a user