diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 9d67580be26ad..e640630476630 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -68,6 +68,7 @@ class RequestFuncInput: model: str model_name: Optional[str] = None logprobs: Optional[int] = None + extra_headers: Optional[dict] = None extra_body: Optional[dict] = None multi_modal_content: Optional[Union[dict, list[dict]]] = None ignore_eos: bool = False @@ -129,6 +130,8 @@ async def async_request_openai_completions( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers if request_func_input.request_id: headers["x-request-id"] = request_func_input.request_id @@ -258,6 +261,8 @@ async def async_request_openai_chat_completions( "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers if request_func_input.request_id: headers["x-request-id"] = request_func_input.request_id @@ -364,6 +369,8 @@ async def async_request_openai_audio( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers if request_func_input.request_id: headers["x-request-id"] = request_func_input.request_id diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index a98eb2a78f103..33e831e54bbc9 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -389,6 +389,7 @@ async def benchmark( goodput_config_dict: dict[str, float], max_concurrency: Optional[int], lora_modules: Optional[Iterable[str]], + extra_headers: Optional[dict], extra_body: Optional[dict], ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, ramp_up_start_rps: Optional[int] = None, @@ -452,6 
+453,7 @@ async def benchmark( logprobs=logprobs, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, + extra_headers=extra_headers, extra_body=extra_body, ) @@ -484,6 +486,7 @@ async def benchmark( logprobs=logprobs, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, + extra_headers=extra_headers, extra_body=extra_body) profile_output = await request_func( request_func_input=profile_input, session=session) @@ -568,6 +571,7 @@ async def benchmark( logprobs=logprobs, multi_modal_content=mm_content, ignore_eos=ignore_eos, + extra_headers=extra_headers, extra_body=extra_body, request_id=request_id,) tasks.append( @@ -815,6 +819,15 @@ def add_cli_args(parser: argparse.ArgumentParser): default="/v1/completions", help="API endpoint.", ) + parser.add_argument( + "--header", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g., --header x-additional-info=0.3.3) " + "for headers to be passed with each request. These headers override " \ + "per backend constants and values set via environment variable, and " \ + "will be overridden by other arguments (such as request ids)." + ) parser.add_argument( "--max-concurrency", type=int, @@ -1104,6 +1117,19 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: api_url = f"http://{args.host}:{args.port}{args.endpoint}" base_url = f"http://{args.host}:{args.port}" + # Headers + headers = None + if args.header: + headers = {} + for item in args.header: + if "=" in item: + kvstring = item.split("=", 1) + headers[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid header format. Please use KEY=VALUE format." 
+ ) + tokenizer = get_tokenizer(tokenizer_id, tokenizer_mode=tokenizer_mode, trust_remote_code=args.trust_remote_code) @@ -1161,6 +1187,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, + extra_headers=headers, extra_body=sampling_params, ramp_up_strategy=args.ramp_up_strategy, ramp_up_start_rps=args.ramp_up_start_rps, @@ -1184,7 +1211,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if args.metadata: for item in args.metadata: if "=" in item: - kvstring = item.split("=") + kvstring = item.split("=", 1) result_json[kvstring[0].strip()] = kvstring[1].strip() else: raise ValueError(