[Bugfix][Frontend] Validate the priority argument in the frontend LLM class before adding requests (#27596)

Signed-off-by: Junpu Fan <junpufan@gmail.com>
This commit is contained in:
Junpu Fan 2025-10-28 07:02:43 -07:00 committed by GitHub
parent 2abbd351ef
commit b186149e8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 0 deletions

View File

@ -71,6 +71,26 @@ def test_multiple_sampling_params(llm: LLM):
assert len(PROMPTS) == len(outputs)
def test_multiple_priority(llm: LLM):
    """Check priority-length validation in ``LLM.generate``.

    Generation succeeds when ``priority`` is omitted or has one entry per
    prompt, and raises ``ValueError`` when the lengths disagree (including
    an empty priority list).
    """
    # Omitting priority entirely is accepted.
    results = llm.generate(PROMPTS, sampling_params=None, priority=None)
    assert len(results) == len(PROMPTS)

    # One priority per prompt is accepted.
    per_prompt = [0 for _ in PROMPTS]
    results = llm.generate(PROMPTS, sampling_params=None, priority=per_prompt)
    assert len(results) == len(PROMPTS)

    # A priority list shorter than the prompt list must be rejected.
    with pytest.raises(ValueError):
        llm.generate(PROMPTS, sampling_params=None, priority=per_prompt[:-1])

    # An empty priority list must be rejected as well.
    with pytest.raises(ValueError):
        llm.generate(PROMPTS, sampling_params=None, priority=[])
def test_max_model_len():
max_model_len = 20
llm = LLM(

View File

@ -1565,6 +1565,12 @@ class LLM:
raise ValueError(
"The lengths of prompts and lora_request must be the same."
)
if priority is not None and len(priority) != num_requests:
raise ValueError(
"The lengths of prompts "
f"({num_requests}) and priority ({len(priority)}) "
"must be the same."
)
for sp in params if isinstance(params, Sequence) else (params,):
if isinstance(sp, SamplingParams):