mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 16:25:02 +08:00
Improve parse_raw_prompt test cases for invalid input .v2 (#30512)
Signed-off-by: Kayvan Mivehnejad <K.Mivehnejad@gmail.com>
This commit is contained in:
parent
dc7fb5bebe
commit
29f7d97715
@ -34,6 +34,13 @@ INPUTS_SLICES = [
|
||||
]
|
||||
|
||||
|
||||
# Test that a nested mixed-type list of lists raises a TypeError.
|
||||
@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
|
||||
def test_invalid_input_raise_type_error(invalid_input):
|
||||
with pytest.raises(TypeError):
|
||||
parse_raw_prompts(invalid_input)
|
||||
|
||||
|
||||
def test_parse_raw_single_batch_empty():
|
||||
with pytest.raises(ValueError, match="at least one prompt"):
|
||||
parse_raw_prompts([])
|
||||
|
||||
@ -33,22 +33,31 @@ def parse_raw_prompts(
|
||||
if len(prompt) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
|
||||
# case 2: array of strings
|
||||
if is_list_of(prompt, str):
|
||||
# case 2: array of strings
|
||||
prompt = cast(list[str], prompt)
|
||||
return [TextPrompt(prompt=elem) for elem in prompt]
|
||||
|
||||
# case 3: array of tokens
|
||||
if is_list_of(prompt, int):
|
||||
# case 3: array of tokens
|
||||
prompt = cast(list[int], prompt)
|
||||
return [TokensPrompt(prompt_token_ids=prompt)]
|
||||
if is_list_of(prompt, list):
|
||||
prompt = cast(list[list[int]], prompt)
|
||||
if len(prompt[0]) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
|
||||
if is_list_of(prompt[0], int):
|
||||
# case 4: array of token arrays
|
||||
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
|
||||
# case 4: array of token arrays
|
||||
if is_list_of(prompt, list):
|
||||
first = prompt[0]
|
||||
if not isinstance(first, list):
|
||||
raise ValueError("prompt expected to be a list of lists")
|
||||
|
||||
if len(first) == 0:
|
||||
raise ValueError("Please provide at least one prompt")
|
||||
|
||||
# strict validation: every nested list must be list[int]
|
||||
if not all(is_list_of(elem, int) for elem in prompt):
|
||||
raise TypeError("Nested lists must contain only integers")
|
||||
|
||||
prompt = cast(list[list[int]], prompt)
|
||||
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
|
||||
|
||||
raise TypeError(
|
||||
"prompt must be a string, array of strings, "
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user