Improve parse_raw_prompts test cases for invalid input, v2 (#30512)

Signed-off-by: Kayvan Mivehnejad <K.Mivehnejad@gmail.com>
This commit is contained in:
Kayvan Mivehnejad 2025-12-13 22:18:41 -05:00 committed by GitHub
parent dc7fb5bebe
commit 29f7d97715
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 9 deletions

View File

@ -34,6 +34,13 @@ INPUTS_SLICES = [
]
# A nested list whose inner lists mix integers and strings must be rejected
# with a TypeError.
@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
def test_invalid_input_raise_type_error(invalid_input):
    """Check that parse_raw_prompts raises TypeError for mixed-type nested lists."""
    expect_type_error = pytest.raises(TypeError)
    with expect_type_error:
        parse_raw_prompts(invalid_input)
def test_parse_raw_single_batch_empty():
    """Check that an empty prompt list is rejected with a ValueError."""
    no_prompts = []
    # The error message is expected to mention "at least one prompt".
    with pytest.raises(ValueError, match="at least one prompt"):
        parse_raw_prompts(no_prompts)

View File

@ -33,22 +33,31 @@ def parse_raw_prompts(
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
# case 2: array of strings
if is_list_of(prompt, str):
# case 2: array of strings
prompt = cast(list[str], prompt)
return [TextPrompt(prompt=elem) for elem in prompt]
# case 3: array of tokens
if is_list_of(prompt, int):
# case 3: array of tokens
prompt = cast(list[int], prompt)
return [TokensPrompt(prompt_token_ids=prompt)]
if is_list_of(prompt, list):
prompt = cast(list[list[int]], prompt)
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if is_list_of(prompt[0], int):
# case 4: array of token arrays
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
# case 4: array of token arrays
if is_list_of(prompt, list):
first = prompt[0]
if not isinstance(first, list):
raise ValueError("prompt expected to be a list of lists")
if len(first) == 0:
raise ValueError("Please provide at least one prompt")
# strict validation: every nested list must be list[int]
if not all(is_list_of(elem, int) for elem in prompt):
raise TypeError("Nested lists must contain only integers")
prompt = cast(list[list[int]], prompt)
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
raise TypeError(
"prompt must be a string, array of strings, "