From cb234955dfd4cbad552f4bfe1de6c5a3981766a7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 2 May 2025 23:11:53 +0800 Subject: [PATCH] [Misc] Clean up input processing (#17582) Signed-off-by: DarkLight1337 --- .../multimodal/pooling/test_intern_vit.py | 14 +- vllm/engine/async_llm_engine.py | 4 - vllm/engine/llm_engine.py | 34 +- vllm/engine/protocol.py | 3 + vllm/entrypoints/llm.py | 6 +- vllm/inputs/data.py | 23 +- vllm/inputs/parse.py | 27 +- vllm/inputs/preprocess.py | 521 ++++++++++-------- vllm/multimodal/processing.py | 12 +- 9 files changed, 359 insertions(+), 285 deletions(-) diff --git a/tests/models/multimodal/pooling/test_intern_vit.py b/tests/models/multimodal/pooling/test_intern_vit.py index 038405ded9ebe..76f9fbe025505 100644 --- a/tests/models/multimodal/pooling/test_intern_vit.py +++ b/tests/models/multimodal/pooling/test_intern_vit.py @@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModel, CLIPImageProcessor from vllm.distributed import cleanup_dist_env_and_memory +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets @@ -14,6 +15,7 @@ from ....conftest import ImageTestAssets DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] +@torch.inference_mode() def run_intern_vit_test( image_assets: ImageTestAssets, model_id: str, @@ -21,11 +23,12 @@ def run_intern_vit_test( dtype: str, ): model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN) + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] img_processor = CLIPImageProcessor.from_pretrained(model) images = [asset.pil_image for asset in image_assets] pixel_values = [ - img_processor(images, return_tensors='pt').pixel_values.to(dtype) + img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype) for images in images ] @@ -34,7 +37,7 @@ def run_intern_vit_test( config.norm_type = "rms_norm" hf_model = AutoModel.from_pretrained(model, - torch_dtype=dtype, + torch_dtype=torch_dtype, trust_remote_code=True).to("cuda") hf_outputs_per_image = [ hf_model(pixel_value.to("cuda")).last_hidden_state @@ -48,7 +51,7 @@ def run_intern_vit_test( del hf_model cleanup_dist_env_and_memory() - vllm_model = vllm_model.to("cuda", dtype) + vllm_model = vllm_model.to("cuda", torch_dtype) vllm_outputs_per_image = [ vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values @@ -66,9 +69,8 @@ def run_intern_vit_test( "OpenGVLab/InternViT-300M-448px", "OpenGVLab/InternViT-6B-448px-V1-5", ]) -@pytest.mark.parametrize("dtype", [torch.half]) -@torch.inference_mode() -def test_models(image_assets, model_id, dtype: str) -> None: +@pytest.mark.parametrize("dtype", ["half"]) +def test_models(dist_init, image_assets, model_id, dtype: str) -> None: run_intern_vit_test( image_assets, model_id, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index cb0902c3a5b84..50da9679d5aae 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -497,10 +497,6 @@ class _AsyncLLMEngine(LLMEngine): prompt["prompt_token_ids"] = [0 ] * prompt["prompt_embeds"].shape[-2] - if self.tokenizer is not None: - tokenizer = await self.get_tokenizer_async(lora_request) - self._validate_token_prompt(prompt, tokenizer=tokenizer) - processed_inputs = await self.input_preprocessor.preprocess_async( prompt, lora_request=lora_request, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 142c8fe99b67e..4398852daac98 100644 --- a/vllm/engine/llm_engine.py +++ 
b/vllm/engine/llm_engine.py @@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import ( get_logits_processors as get_openai_logits_processors) from vllm.executor.executor_base import ExecutorBase from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs -from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs +from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.logits_process import get_bad_words_logits_processors @@ -759,11 +759,6 @@ class LLMEngine: seq_len = prompt["prompt_embeds"].shape[0] prompt["prompt_token_ids"] = [0] * seq_len - if self.tokenizer is not None: - self._validate_token_prompt( - prompt, - tokenizer=self.get_tokenizer(lora_request=lora_request)) - processed_inputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, @@ -782,27 +777,6 @@ class LLMEngine: priority=priority, ) - def _validate_token_prompt(self, prompt: PromptType, - tokenizer: AnyTokenizer): - # Guard against out-of-vocab tokens. - # For some tokenizers, tokenizer.decode will happily return empty text - # for token ids that are out of vocab, and we don't detect token ids - # that are greater than the max token id before running the model. - # However, these token ids will later crash a cuda kernel at runtime - # with an index out of bounds error. This will crash the entire engine. - # This needs to happen before multimodal input pre-processing, which - # may add dummy tokens that aren't part of the tokenizer's - # vocabulary. - if is_token_prompt(prompt): - prompt_ids = prompt["prompt_token_ids"] - if len(prompt_ids) == 0: - # Empty prompt check is handled later - return - max_input_id = max(prompt_ids) - if max_input_id > tokenizer.max_token_id: - raise ValueError( - "Token id {} is out of vocabulary".format(max_input_id)) - def _create_sequence_group_with_sampling( self, request_id: str, @@ -2049,6 +2023,12 @@ class LLMEngine: else: raise ValueError(f"The {prompt_type} prompt cannot be empty") + if tokenizer is not None: + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") + max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 5632e8ad446df..e9350612ee57f 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -83,6 +83,9 @@ class EngineClient(ABC): else: processed_inputs = preprocessor._prompt_to_llm_inputs(prompt) + if processed_inputs["type"] == "embeds": + raise NotImplementedError + prompt_token_ids = processed_inputs["prompt_token_ids"] prompt_text = processed_inputs.get("prompt") multi_modal_data = processed_inputs.get("multi_modal_data") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0a302872d2633..69523f36ffc41 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -27,7 +27,7 @@ from vllm.entrypoints.score_utils import (_cosine_similarity, _validate_score_input_lens) from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt -from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt +from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from 
vllm.model_executor.guided_decoding.guided_fields import ( @@ -567,10 +567,12 @@ class LLM: mm_kwargs["mm_processor_kwargs"] = prompt[ "mm_processor_kwargs"] - if is_token_prompt(prompt): + if "prompt_token_ids" in prompt: + prompt = cast(TokensPrompt, prompt) # Needed for mypy prompt_tokens = prompt["prompt_token_ids"] else: prompt_tokens = tokenizer.encode(prompt["prompt"]) + instances.append( BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs)) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 6a56d044c9f9c..86dbca1804126 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -70,6 +70,11 @@ class EmbedsPrompt(TypedDict): prompt_embeds: torch.Tensor """The embeddings of the prompt.""" + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ @@ -195,13 +200,21 @@ class EmbedsInputs(TypedDict): prompt_embeds: torch.Tensor """The embeddings of the prompt.""" + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ -def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs: + +def embeds_inputs( + prompt_embeds: torch.Tensor, + cache_salt: Optional[str] = None, +) -> EmbedsInputs: """Construct :class:`EmbedsInputs` from optional values.""" - inputs = EmbedsInputs( - type="embeds", - prompt_embeds=prompt_embeds, - ) + inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds) + + if cache_salt is not None: + inputs["cache_salt"] = cache_salt return inputs diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 397344e402305..d17122b483446 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -6,9 +6,9 @@ from typing_extensions import TypeIs from vllm.utils import is_list_of -from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt, - ProcessorInputs, PromptType, SingletonInputs, - SingletonPrompt, TextPrompt, TokensPrompt) +from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs, + PromptType, SingletonInputs, SingletonPrompt, TextPrompt, + TokensPrompt) class ParsedText(TypedDict): @@ -90,6 +90,10 @@ class ParsedEmbedsPrompt(TypedDict): content: EmbedsPrompt +ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt, + ParsedTokensPrompt, ParsedEmbedsPrompt] + + @overload def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt: ... @@ -110,10 +114,7 @@ def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt: ... 
-def parse_singleton_prompt( - prompt: SingletonPrompt, -) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, - ParsedEmbedsPrompt]: +def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: if isinstance(prompt, str): return ParsedStrPrompt(type="str", content=prompt) elif isinstance(prompt, dict): @@ -131,23 +132,11 @@ def parse_singleton_prompt( "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") -def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: - return isinstance(prompt, dict) and "prompt_token_ids" in prompt - - -def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]: - return isinstance(prompt, dict) and "prompt_embeds" in prompt - - def is_explicit_encoder_decoder_prompt( prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: return isinstance(prompt, dict) and "encoder_prompt" in prompt -def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]: - return isinstance(inputs, dict) and inputs["type"] == "embeds" - - def split_enc_dec_inputs( inputs: ProcessorInputs, ) -> tuple[Optional[SingletonInputs], SingletonInputs]: diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 5a9e3643dcad3..97a2ce5c615e0 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -14,14 +14,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs) from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import TokenizerGroup -from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs, - ProcessorInputs, PromptType, SingletonInputs, - SingletonPrompt, TokenInputs, embeds_inputs, token_inputs) -from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt, - ParsedTokensPrompt, is_embeds_inputs, - is_explicit_encoder_decoder_prompt, parse_singleton_prompt) +from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, + EncoderDecoderInputs, ProcessorInputs, PromptType, + SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, embeds_inputs, token_inputs) +from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) @@ -140,13 +140,10 @@ class InputPreprocessor: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. - Based on - - https://github.com/huggingface/transformers/blob/ - 4037a2b5b1278736e566aec12e169100275545ea/ - src/transformers/generation/utils.py - - specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + Based on: + https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py + specifically, + `GenerationMixin._prepare_decoder_input_ids_for_generation()`. Arguments: @@ -183,6 +180,23 @@ class InputPreprocessor: return prompt_token_ids + def _get_tokenization_kw( + self, + overrides: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + kwargs = dict[str, Any]() + + if self.model_config.hf_config.model_type == "whisper": + # For Whisper, special tokens should be provided by the user based + # on the task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. 
+ kwargs["add_special_tokens"] = False + + if overrides: + kwargs.update(overrides) + + return kwargs + def _tokenize_prompt( self, prompt: str, @@ -194,18 +208,11 @@ class InputPreprocessor: corresponding token IDs. """ tokenizer = self.get_tokenizer_group() - if tokenization_kwargs is None: - tokenization_kwargs = {} + tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) - if self.model_config.hf_config.model_type == "whisper": - # For Whisper, special tokens should be provided by the user based - # on the task and language of their request. Also needed to avoid - # appending an EOS token to the prompt which disrupts generation. - tokenization_kwargs["add_special_tokens"] = False + encoder_config = self.model_config.encoder_config - if (self.model_config.encoder_config is not None - and self.model_config.encoder_config.get( - "do_lower_case", False)): + if encoder_config and encoder_config.get("do_lower_case", False): prompt = prompt.lower() return tokenizer.encode(prompt=prompt, @@ -220,18 +227,36 @@ class InputPreprocessor: ) -> list[int]: """Async version of :meth:`_tokenize_prompt`.""" tokenizer = self.get_tokenizer_group() - if tokenization_kwargs is None: - tokenization_kwargs = {} + tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) - if self.model_config.hf_config.model_type == "whisper": - # For Whisper, special tokens should be provided by the user based - # on the task and language of their request. Also needed to avoid - # appending an EOS token to the prompt which disrupts generation. - tokenization_kwargs["add_special_tokens"] = False return await tokenizer.encode_async(prompt=prompt, lora_request=lora_request, **tokenization_kwargs) + def _get_mm_tokenizer( + self, + lora_request: Optional[LoRARequest], + ) -> AnyTokenizer: + # PrithviGeoSpatialMAE needs to be initialized without a tokenizer + # while using also multi-modal input + if not self.tokenizer: + return cast(AnyTokenizer, object()) # Dummy + + tokenizer_group = self.get_tokenizer_group() + return tokenizer_group.get_lora_tokenizer(lora_request) + + async def _get_mm_tokenizer_async( + self, + lora_request: Optional[LoRARequest], + ) -> AnyTokenizer: + # PrithviGeoSpatialMAE needs to be initialized without a tokenizer + # while using also multi-modal input + if not self.tokenizer: + return cast(AnyTokenizer, object()) # Dummy + + tokenizer_group = self.get_tokenizer_group() + return await tokenizer_group.get_lora_tokenizer_async(lora_request) + def _process_multimodal( self, prompt: Union[str, list[int]], @@ -244,13 +269,7 @@ class InputPreprocessor: Apply the model's multi-modal processor to a multi-modal prompt, returning the corresponding token IDs and metadata. 
""" - # At the moment on model (PrithviGeoSpatialMAE) requires to be - # initialized without a tokenizer while using also multi-modal input - if not self.tokenizer: - tokenizer = object() # Dummy - else: - tokenizer_group = self.get_tokenizer_group() - tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) + tokenizer = self._get_mm_tokenizer(lora_request) mm_processor = self.mm_registry.create_processor(self.model_config, tokenizer=tokenizer) @@ -270,14 +289,7 @@ class InputPreprocessor: return_mm_hashes: bool = False, ) -> MultiModalInputs: """Async version of :meth:`_process_multimodal`.""" - # At the moment on model (PrithviGeoSpatialMAE) requires to be - # initialized without a tokenizer while using also multi-modal input - if not self.tokenizer: - tokenizer = object() # Dummy - else: - tokenizer_group = self.get_tokenizer_group() - tokenizer = await tokenizer_group.get_lora_tokenizer_async( - lora_request) + tokenizer = await self._get_mm_tokenizer_async(lora_request) mm_processor = self.mm_registry.create_processor(self.model_config, tokenizer=tokenizer) @@ -287,28 +299,160 @@ class InputPreprocessor: return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, return_mm_hashes) - def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt, - ParsedTextPrompt, - ParsedTokensPrompt]): - prompt_text = None - prompt_token_ids = None - token_type_ids = None - cache_salt = None + def _process_embeds( + self, + parsed_content: EmbedsPrompt, + ) -> EmbedsInputs: + if envs.VLLM_USE_V1: + raise ValueError("prompt_embeds is only available in V0.") - if parsed_prompt["type"] == "str": - prompt_text = parsed_prompt["content"] + prompt_embeds = parsed_content["prompt_embeds"] + + # prompt_embeds must be (seq_len, hidden_size), but if the user + # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), + # we can unambiguously process the intent by squeezing the batch + # dimension. 
+ if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.squeeze(dim=0) + + if prompt_embeds.ndim != 2: + raise ValueError( + "prompt_embeds must be of shape (seq_len, hidden_size).") + + return embeds_inputs(prompt_embeds=prompt_embeds, + cache_salt=parsed_content.get("cache_salt")) + + async def _process_embeds_async( + self, + parsed_content: EmbedsPrompt, + ) -> EmbedsInputs: + return self._process_embeds(parsed_content) + + def _process_tokens( + self, + parsed_content: TokensPrompt, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_token_ids = parsed_content["prompt_token_ids"] + token_type_ids = parsed_content.get("token_type_ids") + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = self._process_multimodal( + prompt_token_ids, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) else: - cache_salt = parsed_prompt["content"].get("cache_salt") - if parsed_prompt["type"] == "text": - prompt_text = parsed_prompt["content"]["prompt"] - elif parsed_prompt["type"] == "tokens": - prompt_token_ids = parsed_prompt["content"].get( - "prompt_token_ids") - token_type_ids = parsed_prompt["content"].get("token_type_ids") - else: - assert_never(parsed_prompt) + inputs = token_inputs( + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + ) - return prompt_text, prompt_token_ids, token_type_ids, cache_salt + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + async def _process_tokens_async( + self, + parsed_content: TokensPrompt, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_token_ids = parsed_content["prompt_token_ids"] + token_type_ids = parsed_content.get("token_type_ids") + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = await self._process_multimodal_async( + prompt_token_ids, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + inputs = token_inputs( + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + def _process_text( + self, + parsed_content: TextPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_text = parsed_content["prompt"] + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = self._process_multimodal( + prompt_text, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + prompt_token_ids = self._tokenize_prompt( + prompt_text, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) + inputs = token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + async def _process_text_async( + self, + parsed_content: TextPrompt, + tokenization_kwargs: 
Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_text = parsed_content["prompt"] + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = await self._process_multimodal_async( + prompt_text, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + prompt_token_ids = await self._tokenize_prompt_async( + prompt_text, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) + inputs = token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs def _prompt_to_llm_inputs( self, @@ -333,38 +477,27 @@ class InputPreprocessor: parsed = parse_singleton_prompt(prompt) if parsed["type"] == "embeds": - return self._process_prompt_embeds(parsed) - - prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ - self._get_prompt_data(parsed) - - # If multimodal data is present, process and return immediately - if parsed["type"] != "str" and parsed["content"].get( - "multi_modal_data") is not None: - inputs = self._process_multimodal( - prompt_text if prompt_text is not None else prompt_token_ids, - parsed["content"]["multi_modal_data"], - parsed["content"].get("mm_processor_kwargs"), + return self._process_embeds(parsed["content"]) + if parsed["type"] == "tokens": + return self._process_tokens( + parsed["content"], lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) - if cache_salt is not None: - inputs["cache_salt"] = cache_salt - return inputs - - if prompt_token_ids is None: - prompt_token_ids = self._tokenize_prompt( - prompt_text, - lora_request=lora_request, + if parsed["type"] == "text": + return self._process_text( + parsed["content"], tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "str": + return self._process_text( + TextPrompt(prompt=parsed["content"]), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, ) - - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - cache_salt=cache_salt, - ) assert_never(parsed) @@ -375,79 +508,49 @@ class InputPreprocessor: lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> SingletonInputs: - """Async version of :meth:`_extract_prompt_components`.""" + """Async version of :meth:`_prompt_to_llm_inputs`.""" parsed = parse_singleton_prompt(prompt) if parsed["type"] == "embeds": - return self._process_prompt_embeds(parsed) - - prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ - self._get_prompt_data(parsed) - - if parsed["type"] != "str" and parsed["content"].get( - "multi_modal_data") is not None: - inputs = await self._process_multimodal_async( - prompt_token_ids if prompt_text is None else prompt_text, - parsed["content"]["multi_modal_data"], - parsed["content"].get("mm_processor_kwargs"), + return await self._process_embeds_async(parsed["content"]) + if parsed["type"] == "tokens": + return await self._process_tokens_async( + parsed["content"], lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) - if cache_salt is not None: - inputs["cache_salt"] = cache_salt - return inputs - - if 
prompt_token_ids is None: - prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, - lora_request=lora_request, + if parsed["type"] == "text": + return await self._process_text_async( + parsed["content"], tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "str": + return await self._process_text_async( + TextPrompt(prompt=parsed["content"]), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, ) - - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - cache_salt=cache_salt, - ) - - def _process_prompt_embeds(self, - parsed: ParsedEmbedsPrompt) -> EmbedsInputs: - if envs.VLLM_USE_V1: - raise ValueError("prompt_embeds is only available in V0.") - - prompt_embeds_content = parsed["content"] - - prompt_embeds = prompt_embeds_content["prompt_embeds"] - - # prompt_embeds must be (seq_len, hidden_size), but if the user - # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), - # we can unambiguously process the intent by squeezing the batch - # dimension. - if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1: - prompt_embeds = prompt_embeds.squeeze(dim=0) - - if prompt_embeds.ndim != 2: - raise ValueError( - "prompt_embeds must be of shape (seq_len, hidden_size).") - - return embeds_inputs(prompt_embeds=prompt_embeds) assert_never(parsed) def _build_enc_dec_llm_inputs( self, - encoder_inputs: Union[TokenInputs, MultiModalInputs], - decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]], + encoder_inputs: SingletonInputs, + decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: - if (encoder_inputs["type"] == "token" - or encoder_inputs["type"] == "multimodal"): - pass - else: - assert_never(encoder_inputs) # type: ignore[arg-type] + if (encoder_inputs["type"] == "embeds" + or decoder_inputs and decoder_inputs["type"] == "embeds"): + raise ValueError("Embedding inputs are not supported for encoder-" + "decoder models") - # Mypy does not correctly infer that EmbedsInputs is impossible - assert "prompt_token_ids" in encoder_inputs + # Needed for mypy + encoder_inputs = cast(Union[TokenInputs, MultiModalInputs], + encoder_inputs) + decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]], + decoder_inputs) if decoder_inputs is None: if self.model_config.hf_config.model_type == "whisper": @@ -460,74 +563,78 @@ class InputPreprocessor: dec_token_ids = self._prepare_decoder_input_ids_for_generation( None) decoder_inputs = token_inputs(dec_token_ids) - elif (decoder_inputs["type"] == "token" - or decoder_inputs["type"] == "multimodal"): - dec_token_ids = self._prepare_decoder_input_ids_for_generation( - decoder_inputs["prompt_token_ids"]) - decoder_inputs["prompt_token_ids"] = dec_token_ids - + else: if "multi_modal_data" in decoder_inputs: raise ValueError("Multi-modal decoder inputs of encoder-" "decoder models are not supported yet") - else: - assert_never(encoder_inputs) # type: ignore[arg-type] + + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + decoder_inputs["prompt_token_ids"]) + decoder_inputs["prompt_token_ids"] = dec_token_ids return EncoderDecoderInputs( encoder=encoder_inputs, decoder=decoder_inputs, ) - def _separate_enc_dec_inputs_from_mm_processor_outputs( + def _split_enc_dec_mm_inputs( self, - inputs: SingletonInputs, - decoder_inputs_to_override: Optional[Union[TokenInputs, - MultiModalInputs]] = None, + 
inputs: Union[SingletonInputs, MultiModalEncDecInputs], + decoder_inputs_to_override: Optional[SingletonInputs] = None, ) -> tuple[SingletonInputs, SingletonInputs]: """ For encoder/decoder models only: Separate Encoder/Decoder inputs from a MultiModalEncDecInputs """ + if (inputs["type"] == "embeds" or decoder_inputs_to_override + and decoder_inputs_to_override["type"] == "embeds"): + raise ValueError("Embedding inputs are not supported for encoder-" + "decoder models") + + # Needed for mypy + inputs = cast( + Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs], + inputs, + ) + decoder_inputs_to_override = cast( + Optional[Union[TokenInputs, MultiModalInputs]], + decoder_inputs_to_override, + ) + encoder_inputs: SingletonInputs decoder_inputs: SingletonInputs - if inputs["type"] == "multimodal": - # Multimodal data inputs - assert ("encoder_prompt" in inputs - and "encoder_prompt_token_ids" in inputs) + + if inputs["type"] == "multimodal": # Multimodal data inputs + if not ("encoder_prompt" in inputs + and "encoder_prompt_token_ids" in inputs): + raise RuntimeError("You should register an encoder-decoder " + "multi-modal processor for encoder-decoder " + "models.") inputs = cast(MultiModalEncDecInputs, inputs) + encoder_inputs = token_inputs( prompt=inputs["encoder_prompt"], prompt_token_ids=inputs["encoder_prompt_token_ids"], ) - if decoder_inputs_to_override is not None: - decoder_inputs = MultiModalInputs( - type="multimodal", - prompt=decoder_inputs_to_override.get("prompt", ""), - prompt_token_ids=decoder_inputs_to_override[ - "prompt_token_ids"], - mm_kwargs=inputs["mm_kwargs"], - mm_hashes=inputs["mm_hashes"], - mm_placeholders=inputs["mm_placeholders"], - ) - else: - decoder_inputs = MultiModalInputs( - type="multimodal", - prompt=inputs["prompt"], - prompt_token_ids=inputs["prompt_token_ids"], - mm_kwargs=inputs["mm_kwargs"], - mm_hashes=inputs["mm_hashes"], - mm_placeholders=inputs["mm_placeholders"], - ) - cache_salt = inputs.get("cache_salt") - if cache_salt is not None: + decoder_prompt_inputs = decoder_inputs_to_override or inputs + decoder_inputs = MultiModalInputs( + type="multimodal", + prompt=decoder_prompt_inputs.get("prompt", ""), + prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"], + mm_kwargs=inputs["mm_kwargs"], + mm_hashes=inputs["mm_hashes"], + mm_placeholders=inputs["mm_placeholders"], + ) + if cache_salt := inputs.get("cache_salt"): decoder_inputs["cache_salt"] = cache_salt - elif inputs["type"] == "token": - # Text-only inputs + elif inputs["type"] == "token": # Text-only inputs encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) decoder_inputs = decoder_inputs_to_override or inputs else: assert_never(inputs) # type: ignore[arg-type] + return encoder_inputs, decoder_inputs def _process_encoder_decoder_prompt( @@ -580,11 +687,9 @@ class InputPreprocessor: # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. 
if self.model_config.is_multimodal_model: - assert decoder_inputs is None or not is_embeds_inputs( - decoder_inputs) encoder_inputs, decoder_inputs = ( - self._separate_enc_dec_inputs_from_mm_processor_outputs( - encoder_inputs, decoder_inputs)) + self._split_enc_dec_mm_inputs(encoder_inputs, + decoder_inputs)) else: inputs = self._prompt_to_llm_inputs( prompt, @@ -593,16 +698,11 @@ class InputPreprocessor: if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( - self._separate_enc_dec_inputs_from_mm_processor_outputs( - inputs)) + self._split_enc_dec_mm_inputs(inputs)) else: encoder_inputs = inputs decoder_inputs = None - # Mypy does not do type inference well with TypedDicts with Literal - # values. - assert not is_embeds_inputs(encoder_inputs) - assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) async def _process_encoder_decoder_prompt_async( @@ -635,11 +735,9 @@ class InputPreprocessor: # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. if self.model_config.is_multimodal_model: - assert decoder_inputs is None or not is_embeds_inputs( - decoder_inputs) encoder_inputs, decoder_inputs = ( - self._separate_enc_dec_inputs_from_mm_processor_outputs( - encoder_inputs, decoder_inputs)) + self._split_enc_dec_mm_inputs(encoder_inputs, + decoder_inputs)) else: inputs = await self._prompt_to_llm_inputs_async( prompt, @@ -648,16 +746,11 @@ class InputPreprocessor: if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( - self._separate_enc_dec_inputs_from_mm_processor_outputs( - inputs)) + self._split_enc_dec_mm_inputs(inputs)) else: encoder_inputs = inputs decoder_inputs = None - # Mypy does not do type inference well with TypedDicts with Literal - # values. 
-        assert not is_embeds_inputs(encoder_inputs)
-        assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
         return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
 
     def _build_decoder_only_llm_inputs(
@@ -665,19 +758,13 @@ class InputPreprocessor:
         prompt_inputs: DecoderOnlyInputs,
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> DecoderOnlyInputs:
-        if (prompt_inputs["type"] == "token"
-                or prompt_inputs["type"] == "multimodal"):
-            # Mypy does not do type inference well with typedicts and Literal
-            # values
-            assert not is_embeds_inputs(prompt_inputs)
+        if "prompt_token_ids" in prompt_inputs:
+            prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
+                                 prompt_inputs)  # Needed for mypy
             prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
                 prompt_inputs["prompt_token_ids"],
                 prompt_adapter_request=prompt_adapter_request,
             )
-        elif (prompt_inputs["type"] == "embeds"):
-            pass
-        else:
-            assert_never(prompt_inputs)  # type: ignore[arg-type]
 
         return prompt_inputs
 
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index e8745a8f1f901..58168d0e850c2 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1670,15 +1670,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             placeholders = mm_placeholders.get(modality, [])
 
             if len(placeholders) != item_count:
+                # NOTE: If you are a model developer, this can also arise from
+                # an inconsistency between `_call_hf_processor` and
+                # `_get_mm_fields_config` implementations
                 raise RuntimeError(
                     f"Expected there to be {item_count} prompt updates "
                     f"corresponding to {item_count} {modality} items, but "
                     f"instead found {len(placeholders)} prompt updates! "
-                    "Either the prompt text has missing/incorrect tokens for "
-                    "multi-modal inputs, or there is a problem with your "
-                    "implementation of merged multi-modal processor for this "
-                    "model (usually arising from an inconsistency between "
-                    "`_call_hf_processor` and `_get_prompt_updates`).")
+                    "This is likely because you forgot to include input "
+                    "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
+                    "in the prompt. If the model has a chat template, make "
+                    "sure you have applied it before calling `LLM.generate`.")
 
     def _maybe_apply_prompt_updates(
         self,
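
A minimal usage sketch (not part of the patch): after this refactor, `InputPreprocessor._prompt_to_llm_inputs` no longer funnels every prompt through the shared `_get_prompt_data` helper; it dispatches on the `type` tag returned by `parse_singleton_prompt` to `_process_text`, `_process_tokens`, or `_process_embeds`. The snippet below only illustrates that parsing step, assuming a vLLM build that includes this change; the example prompts themselves are made up.

from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_singleton_prompt

# Each SingletonPrompt variant is parsed into a tagged dict; the "type"
# field is what _prompt_to_llm_inputs() now branches on:
#   "str" / "text" -> _process_text
#   "tokens"       -> _process_tokens
#   "embeds"       -> _process_embeds (V0 only; V1 raises ValueError)
examples = [
    "Hello, world!",                           # plain string prompt
    TextPrompt(prompt="Hello, world!"),        # TextPrompt TypedDict
    TokensPrompt(prompt_token_ids=[1, 2, 3]),  # pre-tokenized prompt
]

for prompt in examples:
    parsed = parse_singleton_prompt(prompt)
    print(parsed["type"], parsed["content"])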