[Benchmark] Allow arbitrary headers to be passed to benchmarked endpoints (#23937)

Signed-off-by: Clayton Coleman <smarterclayton@gmail.com>
This commit is contained in:
Clayton Coleman 2025-09-12 16:57:53 -04:00 committed by GitHub
parent 017354c0ef
commit bc636f21a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 1 deletions

View File

@ -68,6 +68,7 @@ class RequestFuncInput:
model: str model: str
model_name: Optional[str] = None model_name: Optional[str] = None
logprobs: Optional[int] = None logprobs: Optional[int] = None
extra_headers: Optional[dict] = None
extra_body: Optional[dict] = None extra_body: Optional[dict] = None
multi_modal_content: Optional[Union[dict, list[dict]]] = None multi_modal_content: Optional[Union[dict, list[dict]]] = None
ignore_eos: bool = False ignore_eos: bool = False
@ -129,6 +130,8 @@ async def async_request_openai_completions(
headers = { headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
} }
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id: if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id headers["x-request-id"] = request_func_input.request_id
@ -258,6 +261,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id: if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id headers["x-request-id"] = request_func_input.request_id
@ -364,6 +369,8 @@ async def async_request_openai_audio(
headers = { headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id: if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id headers["x-request-id"] = request_func_input.request_id

View File

@ -389,6 +389,7 @@ async def benchmark(
goodput_config_dict: dict[str, float], goodput_config_dict: dict[str, float],
max_concurrency: Optional[int], max_concurrency: Optional[int],
lora_modules: Optional[Iterable[str]], lora_modules: Optional[Iterable[str]],
extra_headers: Optional[dict],
extra_body: Optional[dict], extra_body: Optional[dict],
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
ramp_up_start_rps: Optional[int] = None, ramp_up_start_rps: Optional[int] = None,
@ -452,6 +453,7 @@ async def benchmark(
logprobs=logprobs, logprobs=logprobs,
multi_modal_content=test_mm_content, multi_modal_content=test_mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body, extra_body=extra_body,
) )
@ -484,6 +486,7 @@ async def benchmark(
logprobs=logprobs, logprobs=logprobs,
multi_modal_content=test_mm_content, multi_modal_content=test_mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body) extra_body=extra_body)
profile_output = await request_func( profile_output = await request_func(
request_func_input=profile_input, session=session) request_func_input=profile_input, session=session)
@ -568,6 +571,7 @@ async def benchmark(
logprobs=logprobs, logprobs=logprobs,
multi_modal_content=mm_content, multi_modal_content=mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body, extra_body=extra_body,
request_id=request_id,) request_id=request_id,)
tasks.append( tasks.append(
@ -815,6 +819,15 @@ def add_cli_args(parser: argparse.ArgumentParser):
default="/v1/completions", default="/v1/completions",
help="API endpoint.", help="API endpoint.",
) )
parser.add_argument(
"--header",
metavar="KEY=VALUE",
nargs="*",
help="Key-value pairs (e.g., --header x-additional-info=0.3.3) "
"for headers to be passed with each request. These headers override " \
"per-backend constants and values set via environment variables, and " \
"will be overridden by other arguments (such as request ids)."
)
parser.add_argument( parser.add_argument(
"--max-concurrency", "--max-concurrency",
type=int, type=int,
@ -1104,6 +1117,19 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
api_url = f"http://{args.host}:{args.port}{args.endpoint}" api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}" base_url = f"http://{args.host}:{args.port}"
# Headers
headers = None
if args.header:
headers = {}
for item in args.header:
if "=" in item:
kvstring = item.split("=", 1)
headers[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid header format. Please use KEY=VALUE format."
)
tokenizer = get_tokenizer(tokenizer_id, tokenizer = get_tokenizer(tokenizer_id,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
trust_remote_code=args.trust_remote_code) trust_remote_code=args.trust_remote_code)
@ -1161,6 +1187,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
extra_headers=headers,
extra_body=sampling_params, extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy, ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps, ramp_up_start_rps=args.ramp_up_start_rps,
@ -1184,7 +1211,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if args.metadata: if args.metadata:
for item in args.metadata: for item in args.metadata:
if "=" in item: if "=" in item:
kvstring = item.split("=") kvstring = item.split("=", 1)
result_json[kvstring[0].strip()] = kvstring[1].strip() result_json[kvstring[0].strip()] = kvstring[1].strip()
else: else:
raise ValueError( raise ValueError(