[misc] Add LoRA to benchmark_serving (#12898)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath 2025-02-08 14:45:44 +05:30 committed by GitHub
parent 2880e21e3d
commit 7e1837676a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -537,6 +537,7 @@ async def benchmark(
    ignore_eos: bool,
    goodput_config_dict: Dict[str, float],
    max_concurrency: Optional[int],
    lora_modules: Optional[List[str]],
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -562,6 +563,7 @@ async def benchmark(
        multi_modal_content=test_mm_content,
        ignore_eos=ignore_eos,
    )

    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
@@ -570,6 +572,11 @@ async def benchmark(
    else:
        print("Initial test run completed. Starting main benchmark run...")

    if lora_modules:
        # For each input request, choose a LoRA module at random.
        lora_modules = iter(
            [random.choice(lora_modules) for _ in range(len(input_requests))])
    if profile:
        print("Starting profiler...")
        profile_input = RequestFuncInput(model=model_id,
@@ -616,8 +623,13 @@ async def benchmark(
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate, burstiness):
        prompt, prompt_len, output_len, mm_content = request
        req_model_id, req_model_name = model_id, model_name
        if lora_modules:
            req_lora_module = next(lora_modules)
            req_model_id, req_model_name = req_lora_module, req_lora_module

        request_func_input = RequestFuncInput(model=req_model_id,
                                              model_name=req_model_name,
                                              prompt=prompt,
                                              api_url=api_url,
                                              prompt_len=prompt_len,
@@ -900,6 +912,7 @@ def main(args: argparse.Namespace):
        ignore_eos=args.ignore_eos,
        goodput_config_dict=goodput_config_dict,
        max_concurrency=args.max_concurrency,
        lora_modules=args.lora_modules,
    ))

    # Save config and results to json
@@ -1237,5 +1250,12 @@ if __name__ == "__main__":
        "If not specified, the model name will be the "
        "same as the ``--model`` argument. ")

    parser.add_argument("--lora-modules",
                        nargs='+',
                        default=None,
                        help="A subset of LoRA module names passed in when "
                        "launching the server. For each request, the "
                        "script chooses a LoRA module at random.")

    args = parser.parse_args()
    main(args)