mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:06:06 +08:00
[Benchmark] Allow arbitrary headers to be passed to benchmarked endpoints (#23937)
Signed-off-by: Clayton Coleman <smarterclayton@gmail.com>
This commit is contained in:
parent
017354c0ef
commit
bc636f21a6
@ -68,6 +68,7 @@ class RequestFuncInput:
|
||||
model: str
|
||||
model_name: Optional[str] = None
|
||||
logprobs: Optional[int] = None
|
||||
extra_headers: Optional[dict] = None
|
||||
extra_body: Optional[dict] = None
|
||||
multi_modal_content: Optional[Union[dict, list[dict]]] = None
|
||||
ignore_eos: bool = False
|
||||
@ -129,6 +130,8 @@ async def async_request_openai_completions(
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
if request_func_input.extra_headers:
|
||||
headers |= request_func_input.extra_headers
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
@ -258,6 +261,8 @@ async def async_request_openai_chat_completions(
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
if request_func_input.extra_headers:
|
||||
headers |= request_func_input.extra_headers
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
@ -364,6 +369,8 @@ async def async_request_openai_audio(
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
if request_func_input.extra_headers:
|
||||
headers |= request_func_input.extra_headers
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
|
||||
@ -389,6 +389,7 @@ async def benchmark(
|
||||
goodput_config_dict: dict[str, float],
|
||||
max_concurrency: Optional[int],
|
||||
lora_modules: Optional[Iterable[str]],
|
||||
extra_headers: Optional[dict],
|
||||
extra_body: Optional[dict],
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
|
||||
ramp_up_start_rps: Optional[int] = None,
|
||||
@ -452,6 +453,7 @@ async def benchmark(
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
@ -484,6 +486,7 @@ async def benchmark(
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body)
|
||||
profile_output = await request_func(
|
||||
request_func_input=profile_input, session=session)
|
||||
@ -568,6 +571,7 @@ async def benchmark(
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
request_id=request_id,)
|
||||
tasks.append(
|
||||
@ -815,6 +819,15 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
default="/v1/completions",
|
||||
help="API endpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--header",
|
||||
metavar="KEY=VALUE",
|
||||
nargs="*",
|
||||
help="Key-value pairs (e.g, --header x-additional-info=0.3.3) "
|
||||
"for headers to be passed with each request. These headers override " \
|
||||
"per backend constants and values set via environment variable, and " \
|
||||
"will be overriden by other arguments (such as request ids)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrency",
|
||||
type=int,
|
||||
@ -1104,6 +1117,19 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
|
||||
# Headers
|
||||
headers = None
|
||||
if args.header:
|
||||
headers = {}
|
||||
for item in args.header:
|
||||
if "=" in item:
|
||||
kvstring = item.split("=", 1)
|
||||
headers[kvstring[0].strip()] = kvstring[1].strip()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid header format. Please use KEY=VALUE format."
|
||||
)
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
@ -1161,6 +1187,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
max_concurrency=args.max_concurrency,
|
||||
lora_modules=args.lora_modules,
|
||||
extra_headers=headers,
|
||||
extra_body=sampling_params,
|
||||
ramp_up_strategy=args.ramp_up_strategy,
|
||||
ramp_up_start_rps=args.ramp_up_start_rps,
|
||||
@ -1184,7 +1211,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
if args.metadata:
|
||||
for item in args.metadata:
|
||||
if "=" in item:
|
||||
kvstring = item.split("=")
|
||||
kvstring = item.split("=", 1)
|
||||
result_json[kvstring[0].strip()] = kvstring[1].strip()
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user