diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 71289277eb987..5e7795a14072f 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -45,16 +45,17 @@ def parse_raw_prompts( # case 4: array of token arrays if is_list_of(prompt, list): - first = prompt[0] - if not isinstance(first, list): - raise ValueError("prompt expected to be a list of lists") - - if len(first) == 0: - raise ValueError("Please provide at least one prompt") - - # strict validation: every nested list must be list[int] - if not all(is_list_of(elem, int) for elem in prompt): - raise TypeError("Nested lists must contain only integers") + if len(prompt) == 1 and isinstance(prompt[0], list) and len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + for elem in prompt: + if not isinstance(elem, list): + raise TypeError( + "prompt must be a list of lists, but found a non-list element." + ) + if not is_list_of(elem, int): + raise TypeError( + "Nested lists of tokens must contain only integers." + ) prompt = cast(list[list[int]], prompt) return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]