From b186149e8e9d067c17b5067f73c420b2d8317580 Mon Sep 17 00:00:00 2001 From: Junpu Fan Date: Tue, 28 Oct 2025 07:02:43 -0700 Subject: [PATCH] [Bugfix][Frontend] validate arg priority in frontend LLM class before add request (#27596) Signed-off-by: Junpu Fan --- tests/entrypoints/llm/test_generate.py | 20 ++++++++++++++++++++ vllm/entrypoints/llm.py | 6 ++++++ 2 files changed, 26 insertions(+) diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index e9993fd840619..34465b7d27080 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -71,6 +71,26 @@ def test_multiple_sampling_params(llm: LLM): assert len(PROMPTS) == len(outputs) +def test_multiple_priority(llm: LLM): + # Generate works when priority is None + outputs = llm.generate(PROMPTS, sampling_params=None, priority=None) + assert len(PROMPTS) == len(outputs) + + # Generate works when length of priority is same as the len(PROMPTS) + outputs = llm.generate(PROMPTS, sampling_params=None, priority=[0] * len(PROMPTS)) + assert len(PROMPTS) == len(outputs) + + # Exception raised, if the length of priority does not match the length of prompts + with pytest.raises(ValueError): + outputs = llm.generate( + PROMPTS, sampling_params=None, priority=[0] * (len(PROMPTS) - 1) + ) + + # Exception raised, if the priority list is empty + with pytest.raises(ValueError): + outputs = llm.generate(PROMPTS, sampling_params=None, priority=[]) + + def test_max_model_len(): max_model_len = 20 llm = LLM( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ce5cf0aae3a37..758e16c89e694 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1565,6 +1565,12 @@ class LLM: raise ValueError( "The lengths of prompts and lora_request must be the same." ) + if priority is not None and len(priority) != num_requests: + raise ValueError( + "The lengths of prompts " + f"({num_requests}) and priority ({len(priority)}) " + "must be the same." + ) for sp in params if isinstance(params, Sequence) else (params,): if isinstance(sp, SamplingParams):