diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py
index 04bc29b07aac9..78a6c96ebb4f3 100644
--- a/vllm/benchmarks/throughput.py
+++ b/vllm/benchmarks/throughput.py
@@ -59,16 +59,16 @@ def run_vllm(
     prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     for request in requests:
-        prompts.append(
-            TokensPrompt(
-                prompt_token_ids=request.prompt["prompt_token_ids"],
-                multi_modal_data=request.multi_modal_data,
-            )
+        prompt = (
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"])
             if "prompt_token_ids" in request.prompt
-            else TextPrompt(
-                prompt=request.prompt, multi_modal_data=request.multi_modal_data
-            )
+            else TextPrompt(prompt=request.prompt)
         )
+        if request.multi_modal_data:
+            assert isinstance(request.multi_modal_data, dict)
+            prompt["multi_modal_data"] = request.multi_modal_data
+        prompts.append(prompt)
+
         sampling_params.append(
             SamplingParams(
                 n=n,
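
The item-assignment in the new code (`prompt["multi_modal_data"] = ...`) works because vLLM's `TextPrompt` and `TokensPrompt` are `TypedDict`s, so optional keys can be attached after construction; the change also appears to avoid passing `multi_modal_data=None` into the constructor for requests with no multi-modal payload. Below is a minimal, self-contained sketch of the same pattern; the simplified `TextPrompt`/`TokensPrompt` definitions and the `build_prompt` helper are illustrative stand-ins for this sketch, not vLLM's actual definitions:

```python
# Sketch of the pattern used in the diff: TypedDict prompts where optional
# keys (multi_modal_data) are set by item assignment only when present.
# These simplified type definitions are stand-ins, not vLLM's own.
from typing import Any, Optional, TypedDict, Union


class TextPrompt(TypedDict, total=False):
    prompt: str
    multi_modal_data: dict[str, Any]


class TokensPrompt(TypedDict, total=False):
    prompt_token_ids: list[int]
    multi_modal_data: dict[str, Any]


def build_prompt(
    raw: Union[str, dict],
    multi_modal_data: Optional[dict[str, Any]],
) -> Union[TextPrompt, TokensPrompt]:
    # Prefer pre-tokenized input when available, mirroring the diff's
    # conditional expression.
    prompt: Union[TextPrompt, TokensPrompt] = (
        TokensPrompt(prompt_token_ids=raw["prompt_token_ids"])
        if isinstance(raw, dict) and "prompt_token_ids" in raw
        else TextPrompt(prompt=raw)
    )
    # Attach multi-modal data only when it exists, instead of always
    # forwarding it (possibly as None) through the constructor.
    if multi_modal_data:
        assert isinstance(multi_modal_data, dict)
        prompt["multi_modal_data"] = multi_modal_data
    return prompt


if __name__ == "__main__":
    print(build_prompt({"prompt_token_ids": [1, 2, 3]}, {"image": "img.png"}))
    print(build_prompt("Hello, world!", None))
```

One side effect of this restructuring: text-only requests no longer carry an explicit `multi_modal_data: None` entry in the resulting dict, since the key is only ever set when the value is truthy.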