[Input] Remove unused prompt field (#26097)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Cyrus Leung 2025-10-03 15:23:21 +08:00 committed by yewentao256
parent 7e4b1861c3
commit ae03f4c010
15 changed files with 67 additions and 101 deletions
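
Illustrative summary (not part of the diff): processed inputs no longer carry the decoded prompt text, so the `token_inputs` call sites touched below shrink as sketched here; the token IDs are made-up values.

    from vllm.inputs import token_inputs

    prompt_token_ids = [1, 2, 3]

    # Before this commit, the text was stored alongside the token IDs:
    #   inputs = token_inputs(prompt=prompt_text, prompt_token_ids=prompt_token_ids)

    # After this commit, only the token IDs (plus an optional cache salt) remain;
    # callers that still need the text keep it from the original request instead.
    inputs = token_inputs(prompt_token_ids)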

View File

@ -37,4 +37,5 @@ def test_multimodal_processor(model_id):
hf_processor_mm_kwargs={},
)
assert str_processed_inputs["prompt"] == ids_processed_inputs["prompt"]
assert (str_processed_inputs["prompt_token_ids"]
== ids_processed_inputs["prompt_token_ids"])

View File

@ -94,10 +94,15 @@ class EngineClient(ABC):
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
prompt_token_ids = prompt.get("prompt_token_ids")
multi_modal_data = prompt.get("multi_modal_data")
if isinstance(prompt, str):
prompt_text = prompt
prompt_token_ids = []
multi_modal_data = None
else:
prompt_text = prompt.get("prompt")
prompt_token_ids = prompt.get("prompt_token_ids", [])
multi_modal_data = prompt.get("multi_modal_data")
prompt_text = processed_inputs.get("prompt")
mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
tokenized_length = len(prompt_token_ids)
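
Illustrative example (not part of the diff): the two prompt shapes the new branch above distinguishes, with made-up values; `TokensPrompt` is the dict-style prompt type exported by `vllm.inputs`.

    from vllm.inputs import TokensPrompt

    str_prompt = "Describe this image."                            # plain text
    dict_prompt = TokensPrompt(prompt_token_ids=[101, 2023, 102])  # pre-tokenized

    for p in (str_prompt, dict_prompt):
        if isinstance(p, str):
            prompt_text, prompt_token_ids, multi_modal_data = p, [], None
        else:
            prompt_text = p.get("prompt")
            prompt_token_ids = p.get("prompt_token_ids", [])
            multi_modal_data = p.get("multi_modal_data")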

View File

@ -205,11 +205,6 @@ class TokenInputs(TypedDict):
prompt_token_ids: list[int]
"""The token IDs of the prompt."""
prompt: NotRequired[str]
"""
The original prompt text corresponding to the token IDs, if available.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
@ -218,15 +213,12 @@ class TokenInputs(TypedDict):
def token_inputs(
prompt_token_ids: list[int],
prompt: Optional[str] = None,
cache_salt: Optional[str] = None,
) -> TokenInputs:
"""Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
values."""
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
if prompt is not None:
inputs["prompt"] = prompt
if cache_salt is not None:
inputs["cache_salt"] = cache_salt

View File

@ -16,9 +16,10 @@ from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, embeds_inputs, token_inputs)
EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
embeds_inputs, token_inputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
logger = init_logger(__name__)
@ -322,7 +323,7 @@ class InputPreprocessor:
mm_uuids=mm_uuids,
)
else:
inputs = token_inputs(prompt_token_ids=prompt_token_ids)
inputs = token_inputs(prompt_token_ids)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
@ -352,10 +353,7 @@ class InputPreprocessor:
prompt_text,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
inputs = token_inputs(prompt_token_ids)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
@ -473,22 +471,17 @@ class InputPreprocessor:
decoder_inputs: SingletonInputs
if inputs["type"] == "multimodal": # Multimodal data inputs
if not ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs):
if "encoder_prompt_token_ids" not in inputs:
raise RuntimeError("You should register an encoder-decoder "
"multi-modal processor for encoder-decoder "
"models.")
inputs = cast(MultiModalEncDecInputs, inputs)
encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"],
)
encoder_inputs = token_inputs(inputs["encoder_prompt_token_ids"])
decoder_prompt_inputs = decoder_inputs_to_override or inputs
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_prompt_inputs.get("prompt", ""),
prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
@ -498,7 +491,7 @@ class InputPreprocessor:
decoder_inputs["cache_salt"] = cache_salt
elif inputs["type"] == "token": # Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
encoder_inputs = token_inputs(prompt_token_ids=[])
decoder_inputs = decoder_inputs_to_override or inputs
else:
assert_never(inputs) # type: ignore[arg-type]
@ -549,12 +542,14 @@ class InputPreprocessor:
decoder_inputs: Optional[SingletonInputs]
if is_explicit_encoder_decoder_prompt(prompt):
# `cast` is needed for mypy, but not pyright
prompt_ = cast(ExplicitEncoderDecoderPrompt, prompt)
encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"],
prompt_["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
if (decoder_input := prompt_["decoder_prompt"]) is None:
decoder_inputs = None
else:
decoder_inputs = self._prompt_to_llm_inputs(decoder_input)
@ -565,8 +560,9 @@ class InputPreprocessor:
self._split_enc_dec_mm_inputs(encoder_inputs,
decoder_inputs))
else:
# `cast` is needed for mypy, but not pyright
inputs = self._prompt_to_llm_inputs(
prompt,
cast(SingletonPrompt, prompt),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
@ -641,8 +637,9 @@ class InputPreprocessor:
"to decoder-only models")
# Decoder-only operation
# `cast` is needed for mypy, but not pyright
return self._process_decoder_only_prompt(
prompt,
cast(SingletonPrompt, prompt),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
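
Illustrative example (not part of the diff): the two prompt shapes behind the casts above, built with the types imported at the top of this file; values are made up.

    from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt

    # Singleton prompt: handled by the decoder-only path, hence the
    # cast(SingletonPrompt, prompt) once the explicit enc/dec case is ruled out.
    singleton = TextPrompt(prompt="Translate to German: hello")

    # Explicit encoder-decoder prompt: encoder and decoder parts are
    # preprocessed separately, and decoder_prompt may be None.
    enc_dec = ExplicitEncoderDecoderPrompt(
        encoder_prompt="Translate to German: hello",
        decoder_prompt=None,
    )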

View File

@ -778,7 +778,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
)
], mm_item_counts)
prompt_ids, prompt, _ = self._apply_prompt_updates(
prompt_ids, _ = self._apply_prompt_updates(
result["prompt_token_ids"],
mantis_mm_repls,
)
@ -798,7 +798,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,

View File

@ -219,7 +219,6 @@ class PaliGemmaMultiModalProcessor(
if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
prompt_token_ids.append(newline_token_id)
mm_inputs["prompt_token_ids"] = prompt_token_ids
mm_inputs["prompt"] += newline_prompt
return mm_inputs

View File

@ -461,7 +461,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
self,
token_ids: list[int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
# align to hf behavior when there are images
if len(mm_prompt_updates):
tokenizer = self.info.get_tokenizer()
@ -496,14 +496,14 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
for ele in sublist for e in ele
]
token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids, placeholders = super()._apply_prompt_updates(
token_ids=token_ids,
mm_prompt_updates=mm_prompt_updates,
)
# Keep the behavior in line with HF processor
if text.startswith("<s> <|image|>"):
text = text.replace("<s> <|image|>", "<s><|image|>", 1)
if token_ids[:2] == tokenizer.encode("<s> <|image|>",
add_special_tokens=False):
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = {
modality: [
@ -518,7 +518,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
for modality, ps in placeholders.items()
}
return token_ids, text, placeholders
return token_ids, placeholders
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,

View File

@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from vllm.transformers_utils.tokenizer import encode_tokens
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@ -316,7 +316,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
mm_kwargs: MultiModalKwargsItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
"""
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
"""
@ -341,28 +341,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
use_audio_in_video=use_audio_in_video)
tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids)
use_audio_in_video=use_audio_in_video,
)
else:
(
prompt_ids,
prompt,
mm_placeholders,
) = self._apply_prompt_updates(
prompt_ids, mm_placeholders = self._apply_prompt_updates(
prompt_ids,
mm_prompt_updates,
)
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
use_audio_in_video=use_audio_in_video)
use_audio_in_video=use_audio_in_video,
)
tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids)
return prompt_ids, prompt, mm_placeholders
return prompt_ids, mm_placeholders
def _get_prompt_updates(
self,

View File

@ -190,7 +190,6 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=[1],
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,

View File

@ -453,7 +453,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,

View File

@ -949,9 +949,6 @@ class MultiModalInputs(TypedDict):
type: Literal["multimodal"]
"""The type of inputs."""
prompt: str
"""The processed prompt text."""
prompt_token_ids: list[int]
"""The processed token IDs which includes placeholder tokens."""
@ -980,8 +977,5 @@ class MultiModalEncDecInputs(MultiModalInputs):
ready to be passed to vLLM internals.
"""
encoder_prompt: str
"""The processed encoder prompt text."""
encoder_prompt_token_ids: list[int]
"""The processed token IDs of the encoder prompt."""

View File

@ -1878,7 +1878,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self,
token_ids: list[int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
tokenizer = self.info.get_tokenizer()
new_token_ids, match_result = self._apply_token_matches(
@ -1896,11 +1896,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# Since it is inefficient to search for all possible tokenizations
# of the search text in the prompt, we instead perform string-based
# updates on the decoded token IDs, then encode them back.
if all(
if not all(
all(update_idx is not None for update_idx in update_idxs)
for update_idxs in match_result.values()):
new_text = decode_tokens(tokenizer, new_token_ids)
else:
new_text, match_result = self._apply_text_matches(
decode_tokens(tokenizer, token_ids),
mm_prompt_updates,
@ -1928,7 +1926,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
dict(matched_updates),
)
return new_token_ids, new_text, placeholders
return new_token_ids, placeholders
def _validate_mm_kwargs(
self,
@ -1976,7 +1974,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs: MultiModalKwargsOptionalItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
@ -1986,21 +1984,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_prompt_updates,
)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids)
else:
(
prompt_ids,
prompt,
mm_placeholders,
) = self._apply_prompt_updates(
prompt_ids, mm_placeholders = self._apply_prompt_updates(
prompt_ids,
mm_prompt_updates,
)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
return prompt_ids, prompt, mm_placeholders
return prompt_ids, mm_placeholders
def apply(
self,
@ -2042,7 +2033,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
# NOTE: tokenization_kwargs are not required to init processor
prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
prompt_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
@ -2057,7 +2048,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
mm_hashes=mm_info.hashes,
@ -2100,19 +2090,15 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
tokenizer = self.info.get_tokenizer()
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data)
if isinstance(decoder_prompt_raw, str):
decoder_prompt = decoder_prompt_raw
decoder_prompt_ids = encode_tokens(tokenizer,
decoder_prompt_raw,
add_special_tokens=False)
else:
decoder_prompt = decode_tokens(tokenizer, decoder_prompt_raw)
decoder_prompt_ids = decoder_prompt_raw
mm_inputs = MultiModalEncDecInputs(
encoder_prompt=encoder_inputs["prompt"],
encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
**encoder_inputs)
mm_inputs["prompt"] = decoder_prompt
mm_inputs["prompt_token_ids"] = decoder_prompt_ids
return mm_inputs
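
Illustrative example (not part of the diff): processors that override these hooks (as Phi-3-V and Qwen2.5-Omni do in this commit) now return a 2-tuple without the decoded text. A minimal override sketch; `MyProcessor` is hypothetical and the import locations are assumed from the hunks above.

    from collections.abc import Mapping

    from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                            MultiModalPromptUpdates,
                                            PlaceholderFeaturesInfo)

    class MyProcessor(BaseMultiModalProcessor):  # BaseMultiModalProcessor[MyInfo] in real code

        def _apply_prompt_updates(
            self,
            token_ids: list[int],
            mm_prompt_updates: MultiModalPromptUpdates,
        ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
            token_ids, placeholders = super()._apply_prompt_updates(
                token_ids=token_ids,
                mm_prompt_updates=mm_prompt_updates,
            )
            # Model-specific token-level adjustments go here; there is no longer
            # a decoded prompt string to patch up.
            return token_ids, placeholders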

View File

@ -281,12 +281,16 @@ class AsyncLLM(EngineClient):
queue = RequestOutputCollector(output_kind=params.output_kind)
# Convert Input --> Request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, priority, data_parallel_rank)
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority,
data_parallel_rank)
prompt_text = prompt if isinstance(prompt,
str) else prompt.get("prompt")
if is_pooling or params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue)
await self._add_request(request, prompt_text, None, 0, queue)
return queue
# Get the updated SamplingParams from the request, which
@ -302,7 +306,7 @@ class AsyncLLM(EngineClient):
request)
child_request.request_id = request_id
child_request.sampling_params = child_params
await self._add_request(child_request, prompt_str, parent_request,
await self._add_request(child_request, prompt_text, parent_request,
idx, queue)
return queue

View File

@ -227,15 +227,18 @@ class LLMEngine:
f"request_id must be a string, got {type(request_id)}")
# Process raw inputs into the request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, priority)
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority)
prompt_text = prompt if isinstance(prompt,
str) else prompt.get("prompt")
n = params.n if isinstance(params, SamplingParams) else 1
if n == 1:
# Make a new RequestState and queue.
self.output_processor.add_request(request, prompt_str, None, 0)
self.output_processor.add_request(request, prompt_text, None, 0)
# Add the request to EngineCore.
self.engine_core.add_request(request)
return
@ -249,7 +252,7 @@ class LLMEngine:
child_request.sampling_params = params
# Make a new RequestState and queue.
self.output_processor.add_request(child_request, prompt_str,
self.output_processor.add_request(child_request, prompt_text,
parent_req, idx)
# Add the request to EngineCore.
self.engine_core.add_request(child_request)
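
Illustrative example (not part of the diff): a caller-side sketch of the new convention shared by the AsyncLLM and LLMEngine hunks above, wrapped in a hypothetical helper so the free names are explicit.

    def add_request_sketch(processor, output_processor, engine_core, request_id,
                           prompt, params, arrival_time=None, lora_request=None,
                           tokenization_kwargs=None, trace_headers=None,
                           priority=0):
        # process_inputs now returns only the EngineCoreRequest, without the
        # prompt text it used to hand back alongside it.
        request = processor.process_inputs(request_id, prompt, params,
                                           arrival_time, lora_request,
                                           tokenization_kwargs, trace_headers,
                                           priority)
        # The text, when still needed for output processing, is taken from the
        # original prompt object rather than from the processed inputs.
        prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")
        output_processor.add_request(request, prompt_text, None, 0)
        engine_core.add_request(request)
        return request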

View File

@ -334,9 +334,7 @@ class Processor:
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> tuple[Optional[str], EngineCoreRequest]:
# TODO(woosuk): Support pooling models.
) -> EngineCoreRequest:
self._validate_lora(lora_request)
self._validate_params(params)
@ -395,8 +393,6 @@ class Processor:
# discriminated unions of TypedDicts, because of how it handles
# inheritance of TypedDict. If we explicitly extract the items we want
# we can avoid type errors from using `dict.get` later in the method.
prompt_str: Optional[str] = None if decoder_inputs[
"type"] == "embeds" else decoder_inputs.get("prompt")
prompt_token_ids = decoder_inputs[
"prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None
prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[
@ -442,7 +438,7 @@ class Processor:
identifier=decoder_mm_hashes[modality][idx],
mm_position=decoder_mm_positions[modality][idx]))
return prompt_str, EngineCoreRequest(
return EngineCoreRequest(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
prompt_embeds=prompt_embeds,