[Benchmark] Allow arbitrary headers to be passed to benchmarked endpoints (#23937)

Signed-off-by: Clayton Coleman <smarterclayton@gmail.com>
This commit is contained in:
Clayton Coleman 2025-09-12 16:57:53 -04:00 committed by GitHub
parent 017354c0ef
commit bc636f21a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 1 deletions

View File

@ -68,6 +68,7 @@ class RequestFuncInput:
model: str
model_name: Optional[str] = None
logprobs: Optional[int] = None
extra_headers: Optional[dict] = None
extra_body: Optional[dict] = None
multi_modal_content: Optional[Union[dict, list[dict]]] = None
ignore_eos: bool = False
@ -129,6 +130,8 @@ async def async_request_openai_completions(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
@ -258,6 +261,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
@ -364,6 +369,8 @@ async def async_request_openai_audio(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id

View File

@ -389,6 +389,7 @@ async def benchmark(
goodput_config_dict: dict[str, float],
max_concurrency: Optional[int],
lora_modules: Optional[Iterable[str]],
extra_headers: Optional[dict],
extra_body: Optional[dict],
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
ramp_up_start_rps: Optional[int] = None,
@ -452,6 +453,7 @@ async def benchmark(
logprobs=logprobs,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
)
@ -484,6 +486,7 @@ async def benchmark(
logprobs=logprobs,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body)
profile_output = await request_func(
request_func_input=profile_input, session=session)
@ -568,6 +571,7 @@ async def benchmark(
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
request_id=request_id,)
tasks.append(
@ -815,6 +819,15 @@ def add_cli_args(parser: argparse.ArgumentParser):
default="/v1/completions",
help="API endpoint.",
)
parser.add_argument(
"--header",
metavar="KEY=VALUE",
nargs="*",
help="Key-value pairs (e.g, --header x-additional-info=0.3.3) "
"for headers to be passed with each request. These headers override " \
"per backend constants and values set via environment variable, and " \
"will be overriden by other arguments (such as request ids)."
)
parser.add_argument(
"--max-concurrency",
type=int,
@ -1104,6 +1117,19 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}"
# Headers
headers = None
if args.header:
headers = {}
for item in args.header:
if "=" in item:
kvstring = item.split("=", 1)
headers[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid header format. Please use KEY=VALUE format."
)
tokenizer = get_tokenizer(tokenizer_id,
tokenizer_mode=tokenizer_mode,
trust_remote_code=args.trust_remote_code)
@ -1161,6 +1187,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_headers=headers,
extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps,
@ -1184,7 +1211,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if args.metadata:
for item in args.metadata:
if "=" in item:
kvstring = item.split("=")
kvstring = item.split("=", 1)
result_json[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(