Strengthen input validation and tests for 'parse_raw_prompts'. (#30652)

Signed-off-by: Kayvan Mivehnejad <K.Mivehnejad@gmail.com>
This commit is contained in:
Kayvan Mivehnejad 2025-12-18 14:51:58 -05:00 committed by GitHub
parent 24b65eff0d
commit 634a14bd7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -45,16 +45,17 @@ def parse_raw_prompts(
# case 4: array of token arrays # case 4: array of token arrays
if is_list_of(prompt, list): if is_list_of(prompt, list):
first = prompt[0] if len(prompt) == 1 and isinstance(prompt[0], list) and len(prompt[0]) == 0:
if not isinstance(first, list): raise ValueError("please provide at least one prompt")
raise ValueError("prompt expected to be a list of lists") for elem in prompt:
if not isinstance(elem, list):
if len(first) == 0: raise TypeError(
raise ValueError("Please provide at least one prompt") "prompt must be a list of lists, but found a non-list element."
)
# strict validation: every nested list must be list[int] if not is_list_of(elem, int):
if not all(is_list_of(elem, int) for elem in prompt): raise TypeError(
raise TypeError("Nested lists must contain only integers") "Nested lists of tokens must contain only integers."
)
prompt = cast(list[list[int]], prompt) prompt = cast(list[list[int]], prompt)
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt] return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]