From ae03f4c010a7275bbc6c816ac710cbe8e7cd87b1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 Oct 2025 15:23:21 +0800 Subject: [PATCH] [Input] Remove unused `prompt` field (#26097) Signed-off-by: DarkLight1337 Signed-off-by: yewentao256 --- .../processing/test_transformers.py | 3 +- vllm/engine/protocol.py | 11 ++++-- vllm/inputs/data.py | 8 ---- vllm/inputs/preprocess.py | 37 +++++++++---------- vllm/model_executor/models/llava.py | 3 +- vllm/model_executor/models/paligemma.py | 1 - vllm/model_executor/models/phi3v.py | 10 ++--- .../models/qwen2_5_omni_thinker.py | 24 ++++-------- vllm/model_executor/models/terratorch.py | 1 - vllm/model_executor/models/transformers.py | 1 - vllm/multimodal/inputs.py | 6 --- vllm/multimodal/processing.py | 28 ++++---------- vllm/v1/engine/async_llm.py | 14 ++++--- vllm/v1/engine/llm_engine.py | 13 ++++--- vllm/v1/engine/processor.py | 8 +--- 15 files changed, 67 insertions(+), 101 deletions(-) diff --git a/tests/models/multimodal/processing/test_transformers.py b/tests/models/multimodal/processing/test_transformers.py index 54a0be99384a8..c0e043ade736a 100644 --- a/tests/models/multimodal/processing/test_transformers.py +++ b/tests/models/multimodal/processing/test_transformers.py @@ -37,4 +37,5 @@ def test_multimodal_processor(model_id): hf_processor_mm_kwargs={}, ) - assert str_processed_inputs["prompt"] == ids_processed_inputs["prompt"] + assert (str_processed_inputs["prompt_token_ids"] + == ids_processed_inputs["prompt_token_ids"]) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 9aea74d0c8f3c..997c99af24089 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -94,10 +94,15 @@ class EngineClient(ABC): # this happens again in generation, so the double expansion causes # a mismatch. # TODO - would be ideal to handle this more gracefully. - prompt_token_ids = prompt.get("prompt_token_ids") - multi_modal_data = prompt.get("multi_modal_data") + if isinstance(prompt, str): + prompt_text = prompt + prompt_token_ids = [] + multi_modal_data = None + else: + prompt_text = prompt.get("prompt") + prompt_token_ids = prompt.get("prompt_token_ids", []) + multi_modal_data = prompt.get("multi_modal_data") - prompt_text = processed_inputs.get("prompt") mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs") tokenized_length = len(prompt_token_ids) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 1718c0767ab68..562e73eead66a 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -205,11 +205,6 @@ class TokenInputs(TypedDict): prompt_token_ids: list[int] """The token IDs of the prompt.""" - prompt: NotRequired[str] - """ - The original prompt text corresponding to the token IDs, if available. - """ - cache_salt: NotRequired[str] """ Optional cache salt to be used for prefix caching. @@ -218,15 +213,12 @@ class TokenInputs(TypedDict): def token_inputs( prompt_token_ids: list[int], - prompt: Optional[str] = None, cache_salt: Optional[str] = None, ) -> TokenInputs: """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) - if prompt is not None: - inputs["prompt"] = prompt if cache_salt is not None: inputs["cache_salt"] = cache_salt diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 7518cd8fc897f..65460b46cb5a6 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -16,9 +16,10 @@ from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.transformers_utils.tokenizer import AnyTokenizer from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, - EncoderDecoderInputs, ProcessorInputs, PromptType, - SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, - TokensPrompt, embeds_inputs, token_inputs) + EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, + ProcessorInputs, PromptType, SingletonInputs, + SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, + embeds_inputs, token_inputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) @@ -322,7 +323,7 @@ class InputPreprocessor: mm_uuids=mm_uuids, ) else: - inputs = token_inputs(prompt_token_ids=prompt_token_ids) + inputs = token_inputs(prompt_token_ids) if cache_salt := parsed_content.get("cache_salt"): inputs["cache_salt"] = cache_salt @@ -352,10 +353,7 @@ class InputPreprocessor: prompt_text, tokenization_kwargs=tokenization_kwargs, ) - inputs = token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - ) + inputs = token_inputs(prompt_token_ids) if cache_salt := parsed_content.get("cache_salt"): inputs["cache_salt"] = cache_salt @@ -473,22 +471,17 @@ class InputPreprocessor: decoder_inputs: SingletonInputs if inputs["type"] == "multimodal": # Multimodal data inputs - if not ("encoder_prompt" in inputs - and "encoder_prompt_token_ids" in inputs): + if "encoder_prompt_token_ids" not in inputs: raise RuntimeError("You should register an encoder-decoder " "multi-modal processor for encoder-decoder " "models.") inputs = cast(MultiModalEncDecInputs, inputs) - encoder_inputs = token_inputs( - prompt=inputs["encoder_prompt"], - prompt_token_ids=inputs["encoder_prompt_token_ids"], - ) + encoder_inputs = token_inputs(inputs["encoder_prompt_token_ids"]) decoder_prompt_inputs = decoder_inputs_to_override or inputs decoder_inputs = MultiModalInputs( type="multimodal", - prompt=decoder_prompt_inputs.get("prompt", ""), prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"], mm_kwargs=inputs["mm_kwargs"], mm_hashes=inputs["mm_hashes"], @@ -498,7 +491,7 @@ class InputPreprocessor: decoder_inputs["cache_salt"] = cache_salt elif inputs["type"] == "token": # Text-only inputs - encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) + encoder_inputs = token_inputs(prompt_token_ids=[]) decoder_inputs = decoder_inputs_to_override or inputs else: assert_never(inputs) # type: ignore[arg-type] @@ -549,12 +542,14 @@ class InputPreprocessor: decoder_inputs: Optional[SingletonInputs] if is_explicit_encoder_decoder_prompt(prompt): + # `cast` is needed for mypy, but not pyright + prompt_ = cast(ExplicitEncoderDecoderPrompt, prompt) encoder_inputs = self._prompt_to_llm_inputs( - prompt["encoder_prompt"], + prompt_["encoder_prompt"], tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) - if (decoder_input := prompt["decoder_prompt"]) is None: + if (decoder_input := prompt_["decoder_prompt"]) is None: decoder_inputs = None else: decoder_inputs = self._prompt_to_llm_inputs(decoder_input) @@ -565,8 +560,9 @@ class InputPreprocessor: self._split_enc_dec_mm_inputs(encoder_inputs, decoder_inputs)) else: + # `cast` is needed for mypy, but not pyright inputs = self._prompt_to_llm_inputs( - prompt, + cast(SingletonPrompt, prompt), tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) @@ -641,8 +637,9 @@ class InputPreprocessor: "to decoder-only models") # Decoder-only operation + # `cast` is needed for mypy, but not pyright return self._process_decoder_only_prompt( - prompt, + cast(SingletonPrompt, prompt), tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 78c413b770516..9f338f2ae3fb8 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -778,7 +778,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ) ], mm_item_counts) - prompt_ids, prompt, _ = self._apply_prompt_updates( + prompt_ids, _ = self._apply_prompt_updates( result["prompt_token_ids"], mantis_mm_repls, ) @@ -798,7 +798,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): return MultiModalInputs( type="multimodal", - prompt=prompt, prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, mm_hashes=mm_hashes, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index d118e6c89ab56..d7108a5e5feb8 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -219,7 +219,6 @@ class PaliGemmaMultiModalProcessor( if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id: prompt_token_ids.append(newline_token_id) mm_inputs["prompt_token_ids"] = prompt_token_ids - mm_inputs["prompt"] += newline_prompt return mm_inputs diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 59977796e2af9..a6baeaa526e73 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -461,7 +461,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): self, token_ids: list[int], mm_prompt_updates: MultiModalPromptUpdates, - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: # align to hf behavior when there are images if len(mm_prompt_updates): tokenizer = self.info.get_tokenizer() @@ -496,14 +496,14 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): for ele in sublist for e in ele ] - token_ids, text, placeholders = super()._apply_prompt_updates( + token_ids, placeholders = super()._apply_prompt_updates( token_ids=token_ids, mm_prompt_updates=mm_prompt_updates, ) # Keep the behavior in line with HF processor - if text.startswith(" <|image|>"): - text = text.replace(" <|image|>", "<|image|>", 1) + if token_ids[:2] == tokenizer.encode(" <|image|>", + add_special_tokens=False): token_ids = [token_ids[0], *token_ids[2:]] placeholders = { modality: [ @@ -518,7 +518,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): for modality, ps in placeholders.items() } - return token_ids, text, placeholders + return token_ids, placeholders @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index b5e82c9b21cd1..af0a97e3c8676 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens +from vllm.transformers_utils.tokenizer import encode_tokens from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -316,7 +316,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( mm_kwargs: MultiModalKwargsItems, mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: """ Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. """ @@ -341,28 +341,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor( self._validate_mm_placeholders( mm_placeholders, mm_item_counts, - use_audio_in_video=use_audio_in_video) - - tokenizer = self.info.get_tokenizer() - prompt = decode_tokens(tokenizer, prompt_ids) + use_audio_in_video=use_audio_in_video, + ) else: - ( - prompt_ids, - prompt, - mm_placeholders, - ) = self._apply_prompt_updates( + prompt_ids, mm_placeholders = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, ) self._validate_mm_placeholders( mm_placeholders, mm_item_counts, - use_audio_in_video=use_audio_in_video) + use_audio_in_video=use_audio_in_video, + ) - tokenizer = self.info.get_tokenizer() - prompt = decode_tokens(tokenizer, prompt_ids) - - return prompt_ids, prompt, mm_placeholders + return prompt_ids, mm_placeholders def _get_prompt_updates( self, diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index 938b02e3e04b3..5082054596d85 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -190,7 +190,6 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor): return MultiModalInputs( type="multimodal", - prompt=prompt, prompt_token_ids=[1], mm_kwargs=mm_kwargs, mm_hashes=mm_hashes, diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 18a0dafd001d8..fffdbd00babbf 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -453,7 +453,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): return MultiModalInputs( type="multimodal", - prompt=prompt, prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, mm_hashes=mm_hashes, diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 3539517ed45ee..14d0c8dda78e0 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -949,9 +949,6 @@ class MultiModalInputs(TypedDict): type: Literal["multimodal"] """The type of inputs.""" - prompt: str - """The processed prompt text.""" - prompt_token_ids: list[int] """The processed token IDs which includes placeholder tokens.""" @@ -980,8 +977,5 @@ class MultiModalEncDecInputs(MultiModalInputs): ready to be passed to vLLM internals. """ - encoder_prompt: str - """The processed encoder prompt text.""" - encoder_prompt_token_ids: list[int] """The processed token IDs of the encoder prompt.""" diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index ce671479b1ae7..bc998dc2785f0 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1878,7 +1878,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self, token_ids: list[int], mm_prompt_updates: MultiModalPromptUpdates, - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: tokenizer = self.info.get_tokenizer() new_token_ids, match_result = self._apply_token_matches( @@ -1896,11 +1896,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # Since it is inefficient to search for all possible tokenizations # of the search text in the prompt, we instead perform string-based # updates on the decoded token IDs, then encode them back. - if all( + if not all( all(update_idx is not None for update_idx in update_idxs) for update_idxs in match_result.values()): - new_text = decode_tokens(tokenizer, new_token_ids) - else: new_text, match_result = self._apply_text_matches( decode_tokens(tokenizer, token_ids), mm_prompt_updates, @@ -1928,7 +1926,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): dict(matched_updates), ) - return new_token_ids, new_text, placeholders + return new_token_ids, placeholders def _validate_mm_kwargs( self, @@ -1976,7 +1974,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_kwargs: MultiModalKwargsOptionalItems, mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) @@ -1986,21 +1984,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_prompt_updates, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) - - tokenizer = self.info.get_tokenizer() - prompt = decode_tokens(tokenizer, prompt_ids) else: - ( - prompt_ids, - prompt, - mm_placeholders, - ) = self._apply_prompt_updates( + prompt_ids, mm_placeholders = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) - return prompt_ids, prompt, mm_placeholders + return prompt_ids, mm_placeholders def apply( self, @@ -2042,7 +2033,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ) # NOTE: tokenization_kwargs are not required to init processor - prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( + prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates( mm_items=mm_items, prompt_ids=prompt_ids, mm_kwargs=mm_info.kwargs, @@ -2057,7 +2048,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return MultiModalInputs( type="multimodal", - prompt=prompt, prompt_token_ids=prompt_ids, mm_kwargs=mm_info.kwargs, mm_hashes=mm_info.hashes, @@ -2100,19 +2090,15 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): tokenizer = self.info.get_tokenizer() decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data) if isinstance(decoder_prompt_raw, str): - decoder_prompt = decoder_prompt_raw decoder_prompt_ids = encode_tokens(tokenizer, decoder_prompt_raw, add_special_tokens=False) else: - decoder_prompt = decode_tokens(tokenizer, decoder_prompt_raw) decoder_prompt_ids = decoder_prompt_raw mm_inputs = MultiModalEncDecInputs( - encoder_prompt=encoder_inputs["prompt"], encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], **encoder_inputs) - mm_inputs["prompt"] = decoder_prompt mm_inputs["prompt_token_ids"] = decoder_prompt_ids return mm_inputs diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 36d0d50bf23db..e88b4c5346c30 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -281,12 +281,16 @@ class AsyncLLM(EngineClient): queue = RequestOutputCollector(output_kind=params.output_kind) # Convert Input --> Request. - prompt_str, request = self.processor.process_inputs( - request_id, prompt, params, arrival_time, lora_request, - tokenization_kwargs, trace_headers, priority, data_parallel_rank) + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + tokenization_kwargs, + trace_headers, priority, + data_parallel_rank) + prompt_text = prompt if isinstance(prompt, + str) else prompt.get("prompt") if is_pooling or params.n == 1: - await self._add_request(request, prompt_str, None, 0, queue) + await self._add_request(request, prompt_text, None, 0, queue) return queue # Get the updated SamplingParams from the request, which @@ -302,7 +306,7 @@ class AsyncLLM(EngineClient): request) child_request.request_id = request_id child_request.sampling_params = child_params - await self._add_request(child_request, prompt_str, parent_request, + await self._add_request(child_request, prompt_text, parent_request, idx, queue) return queue diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 2738776e3d37c..f81427161d7d7 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -227,15 +227,18 @@ class LLMEngine: f"request_id must be a string, got {type(request_id)}") # Process raw inputs into the request. - prompt_str, request = self.processor.process_inputs( - request_id, prompt, params, arrival_time, lora_request, - tokenization_kwargs, trace_headers, priority) + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + tokenization_kwargs, + trace_headers, priority) + prompt_text = prompt if isinstance(prompt, + str) else prompt.get("prompt") n = params.n if isinstance(params, SamplingParams) else 1 if n == 1: # Make a new RequestState and queue. - self.output_processor.add_request(request, prompt_str, None, 0) + self.output_processor.add_request(request, prompt_text, None, 0) # Add the request to EngineCore. self.engine_core.add_request(request) return @@ -249,7 +252,7 @@ class LLMEngine: child_request.sampling_params = params # Make a new RequestState and queue. - self.output_processor.add_request(child_request, prompt_str, + self.output_processor.add_request(child_request, prompt_text, parent_req, idx) # Add the request to EngineCore. self.engine_core.add_request(child_request) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 843ca9ad68e38..c30ceb96a5e07 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -334,9 +334,7 @@ class Processor: trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, - ) -> tuple[Optional[str], EngineCoreRequest]: - - # TODO(woosuk): Support pooling models. + ) -> EngineCoreRequest: self._validate_lora(lora_request) self._validate_params(params) @@ -395,8 +393,6 @@ class Processor: # discriminated unions of TypedDicts, because of how it handles # inheritance of TypedDict. If we explicitly extract the items we want # we can avoid type errors from using `dict.get` later in the method. - prompt_str: Optional[str] = None if decoder_inputs[ - "type"] == "embeds" else decoder_inputs.get("prompt") prompt_token_ids = decoder_inputs[ "prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[ @@ -442,7 +438,7 @@ class Processor: identifier=decoder_mm_hashes[modality][idx], mm_position=decoder_mm_positions[modality][idx])) - return prompt_str, EngineCoreRequest( + return EngineCoreRequest( request_id=request_id, prompt_token_ids=prompt_token_ids, prompt_embeds=prompt_embeds,