[Misc] Clean up input processing (#17582)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-05-02 23:11:53 +08:00 committed by GitHub
parent 3a500cd0b6
commit cb234955df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 359 additions and 285 deletions

View File

@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import ImageTestAssets from ....conftest import ImageTestAssets
@ -14,6 +15,7 @@ from ....conftest import ImageTestAssets
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
@torch.inference_mode()
def run_intern_vit_test( def run_intern_vit_test(
image_assets: ImageTestAssets, image_assets: ImageTestAssets,
model_id: str, model_id: str,
@ -21,11 +23,12 @@ def run_intern_vit_test(
dtype: str, dtype: str,
): ):
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN) model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
img_processor = CLIPImageProcessor.from_pretrained(model) img_processor = CLIPImageProcessor.from_pretrained(model)
images = [asset.pil_image for asset in image_assets] images = [asset.pil_image for asset in image_assets]
pixel_values = [ pixel_values = [
img_processor(images, return_tensors='pt').pixel_values.to(dtype) img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype)
for images in images for images in images
] ]
@ -34,7 +37,7 @@ def run_intern_vit_test(
config.norm_type = "rms_norm" config.norm_type = "rms_norm"
hf_model = AutoModel.from_pretrained(model, hf_model = AutoModel.from_pretrained(model,
torch_dtype=dtype, torch_dtype=torch_dtype,
trust_remote_code=True).to("cuda") trust_remote_code=True).to("cuda")
hf_outputs_per_image = [ hf_outputs_per_image = [
hf_model(pixel_value.to("cuda")).last_hidden_state hf_model(pixel_value.to("cuda")).last_hidden_state
@ -48,7 +51,7 @@ def run_intern_vit_test(
del hf_model del hf_model
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
vllm_model = vllm_model.to("cuda", dtype) vllm_model = vllm_model.to("cuda", torch_dtype)
vllm_outputs_per_image = [ vllm_outputs_per_image = [
vllm_model(pixel_values=pixel_value.to("cuda")) vllm_model(pixel_values=pixel_value.to("cuda"))
for pixel_value in pixel_values for pixel_value in pixel_values
@ -66,9 +69,8 @@ def run_intern_vit_test(
"OpenGVLab/InternViT-300M-448px", "OpenGVLab/InternViT-300M-448px",
"OpenGVLab/InternViT-6B-448px-V1-5", "OpenGVLab/InternViT-6B-448px-V1-5",
]) ])
@pytest.mark.parametrize("dtype", [torch.half]) @pytest.mark.parametrize("dtype", ["half"])
@torch.inference_mode() def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
def test_models(image_assets, model_id, dtype: str) -> None:
run_intern_vit_test( run_intern_vit_test(
image_assets, image_assets,
model_id, model_id,

View File

@ -497,10 +497,6 @@ class _AsyncLLMEngine(LLMEngine):
prompt["prompt_token_ids"] = [0 prompt["prompt_token_ids"] = [0
] * prompt["prompt_embeds"].shape[-2] ] * prompt["prompt_embeds"].shape[-2]
if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)
processed_inputs = await self.input_preprocessor.preprocess_async( processed_inputs = await self.input_preprocessor.preprocess_async(
prompt, prompt,
lora_request=lora_request, lora_request=lora_request,

View File

@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors) get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors from vllm.logits_process import get_bad_words_logits_processors
@ -759,11 +759,6 @@ class LLMEngine:
seq_len = prompt["prompt_embeds"].shape[0] seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len prompt["prompt_token_ids"] = [0] * seq_len
if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))
processed_inputs = self.input_preprocessor.preprocess( processed_inputs = self.input_preprocessor.preprocess(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
@ -782,27 +777,6 @@ class LLMEngine:
priority=priority, priority=priority,
) )
def _validate_token_prompt(self, prompt: PromptType,
tokenizer: AnyTokenizer):
# Guard against out-of-vocab tokens.
# For some tokenizers, tokenizer.decode will happily return empty text
# for token ids that are out of vocab, and we don't detect token ids
# that are greater than the max token id before running the model.
# However, these token ids will later crash a cuda kernel at runtime
# with an index out of bounds error. This will crash the entire engine.
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
return
max_input_id = max(prompt_ids)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
def _create_sequence_group_with_sampling( def _create_sequence_group_with_sampling(
self, self,
request_id: str, request_id: str,
@ -2049,6 +2023,12 @@ class LLMEngine:
else: else:
raise ValueError(f"The {prompt_type} prompt cannot be empty") raise ValueError(f"The {prompt_type} prompt cannot be empty")
if tokenizer is not None:
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len: if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model: if prompt_type == "encoder" and model_config.is_multimodal_model:

View File

@ -83,6 +83,9 @@ class EngineClient(ABC):
else: else:
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt) processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
if processed_inputs["type"] == "embeds":
raise NotImplementedError
prompt_token_ids = processed_inputs["prompt_token_ids"] prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt") prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data") multi_modal_data = processed_inputs.get("multi_modal_data")

View File

@ -27,7 +27,7 @@ from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens) _validate_score_input_lens)
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import ( from vllm.model_executor.guided_decoding.guided_fields import (
@ -567,10 +567,12 @@ class LLM:
mm_kwargs["mm_processor_kwargs"] = prompt[ mm_kwargs["mm_processor_kwargs"] = prompt[
"mm_processor_kwargs"] "mm_processor_kwargs"]
if is_token_prompt(prompt): if "prompt_token_ids" in prompt:
prompt = cast(TokensPrompt, prompt) # Needed for mypy
prompt_tokens = prompt["prompt_token_ids"] prompt_tokens = prompt["prompt_token_ids"]
else: else:
prompt_tokens = tokenizer.encode(prompt["prompt"]) prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append( instances.append(
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs)) BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

View File

@ -70,6 +70,11 @@ class EmbedsPrompt(TypedDict):
prompt_embeds: torch.Tensor prompt_embeds: torch.Tensor
"""The embeddings of the prompt.""" """The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
""" """
@ -195,13 +200,21 @@ class EmbedsInputs(TypedDict):
prompt_embeds: torch.Tensor prompt_embeds: torch.Tensor
"""The embeddings of the prompt.""" """The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:
def embeds_inputs(
prompt_embeds: torch.Tensor,
cache_salt: Optional[str] = None,
) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values.""" """Construct :class:`EmbedsInputs` from optional values."""
inputs = EmbedsInputs( inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
type="embeds",
prompt_embeds=prompt_embeds, if cache_salt is not None:
) inputs["cache_salt"] = cache_salt
return inputs return inputs

View File

@ -6,9 +6,9 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt, from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
ProcessorInputs, PromptType, SingletonInputs, PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
SingletonPrompt, TextPrompt, TokensPrompt) TokensPrompt)
class ParsedText(TypedDict): class ParsedText(TypedDict):
@ -90,6 +90,10 @@ class ParsedEmbedsPrompt(TypedDict):
content: EmbedsPrompt content: EmbedsPrompt
ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, ParsedEmbedsPrompt]
@overload @overload
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt: def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
... ...
@ -110,10 +114,7 @@ def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
... ...
def parse_singleton_prompt( def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
ParsedEmbedsPrompt]:
if isinstance(prompt, str): if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt) return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict): elif isinstance(prompt, dict):
@ -131,23 +132,11 @@ def parse_singleton_prompt(
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
return isinstance(prompt, dict) and "prompt_embeds" in prompt
def is_explicit_encoder_decoder_prompt( def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
return isinstance(inputs, dict) and inputs["type"] == "embeds"
def split_enc_dec_inputs( def split_enc_dec_inputs(
inputs: ProcessorInputs, inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]: ) -> tuple[Optional[SingletonInputs], SingletonInputs]:

View File

@ -14,14 +14,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs) MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs, from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
ProcessorInputs, PromptType, SingletonInputs, EncoderDecoderInputs, ProcessorInputs, PromptType,
SingletonPrompt, TokenInputs, embeds_inputs, token_inputs) SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt, TokensPrompt, embeds_inputs, token_inputs)
ParsedTokensPrompt, is_embeds_inputs, from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -140,13 +140,10 @@ class InputPreprocessor:
""" """
Prepares `decoder_input_ids` for generation with encoder-decoder models. Prepares `decoder_input_ids` for generation with encoder-decoder models.
Based on Based on:
https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
https://github.com/huggingface/transformers/blob/ specifically,
4037a2b5b1278736e566aec12e169100275545ea/ `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
src/transformers/generation/utils.py
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
Arguments: Arguments:
@ -183,6 +180,23 @@ class InputPreprocessor:
return prompt_token_ids return prompt_token_ids
def _get_tokenization_kw(
self,
overrides: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
kwargs = dict[str, Any]()
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
kwargs["add_special_tokens"] = False
if overrides:
kwargs.update(overrides)
return kwargs
def _tokenize_prompt( def _tokenize_prompt(
self, self,
prompt: str, prompt: str,
@ -194,18 +208,11 @@ class InputPreprocessor:
corresponding token IDs. corresponding token IDs.
""" """
tokenizer = self.get_tokenizer_group() tokenizer = self.get_tokenizer_group()
if tokenization_kwargs is None: tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
tokenization_kwargs = {}
if self.model_config.hf_config.model_type == "whisper": encoder_config = self.model_config.encoder_config
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
tokenization_kwargs["add_special_tokens"] = False
if (self.model_config.encoder_config is not None if encoder_config and encoder_config.get("do_lower_case", False):
and self.model_config.encoder_config.get(
"do_lower_case", False)):
prompt = prompt.lower() prompt = prompt.lower()
return tokenizer.encode(prompt=prompt, return tokenizer.encode(prompt=prompt,
@ -220,18 +227,36 @@ class InputPreprocessor:
) -> list[int]: ) -> list[int]:
"""Async version of :meth:`_tokenize_prompt`.""" """Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group() tokenizer = self.get_tokenizer_group()
if tokenization_kwargs is None: tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
tokenization_kwargs = {}
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
tokenization_kwargs["add_special_tokens"] = False
return await tokenizer.encode_async(prompt=prompt, return await tokenizer.encode_async(prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
**tokenization_kwargs) **tokenization_kwargs)
def _get_mm_tokenizer(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return tokenizer_group.get_lora_tokenizer(lora_request)
async def _get_mm_tokenizer_async(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return await tokenizer_group.get_lora_tokenizer_async(lora_request)
def _process_multimodal( def _process_multimodal(
self, self,
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
@ -244,13 +269,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata. returning the corresponding token IDs and metadata.
""" """
# At the moment on model (PrithviGeoSpatialMAE) requires to be tokenizer = self._get_mm_tokenizer(lora_request)
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config, mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer) tokenizer=tokenizer)
@ -270,14 +289,7 @@ class InputPreprocessor:
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> MultiModalInputs: ) -> MultiModalInputs:
"""Async version of :meth:`_process_multimodal`.""" """Async version of :meth:`_process_multimodal`."""
# At the moment on model (PrithviGeoSpatialMAE) requires to be tokenizer = await self._get_mm_tokenizer_async(lora_request)
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config, mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer) tokenizer=tokenizer)
@ -287,28 +299,160 @@ class InputPreprocessor:
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
return_mm_hashes) return_mm_hashes)
def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt, def _process_embeds(
ParsedTextPrompt, self,
ParsedTokensPrompt]): parsed_content: EmbedsPrompt,
prompt_text = None ) -> EmbedsInputs:
prompt_token_ids = None if envs.VLLM_USE_V1:
token_type_ids = None raise ValueError("prompt_embeds is only available in V0.")
cache_salt = None
if parsed_prompt["type"] == "str": prompt_embeds = parsed_content["prompt_embeds"]
prompt_text = parsed_prompt["content"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")
return embeds_inputs(prompt_embeds=prompt_embeds,
cache_salt=parsed_content.get("cache_salt"))
async def _process_embeds_async(
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
return self._process_embeds(parsed_content)
def _process_tokens(
self,
parsed_content: TokensPrompt,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else: else:
cache_salt = parsed_prompt["content"].get("cache_salt") inputs = token_inputs(
if parsed_prompt["type"] == "text": prompt_token_ids=prompt_token_ids,
prompt_text = parsed_prompt["content"]["prompt"] token_type_ids=token_type_ids,
elif parsed_prompt["type"] == "tokens": )
prompt_token_ids = parsed_prompt["content"].get(
"prompt_token_ids")
token_type_ids = parsed_prompt["content"].get("token_type_ids")
else:
assert_never(parsed_prompt)
return prompt_text, prompt_token_ids, token_type_ids, cache_salt if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
async def _process_tokens_async(
self,
parsed_content: TokensPrompt,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _process_text(
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
async def _process_text_async(
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = await self._process_multimodal_async(
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _prompt_to_llm_inputs( def _prompt_to_llm_inputs(
self, self,
@ -333,38 +477,27 @@ class InputPreprocessor:
parsed = parse_singleton_prompt(prompt) parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds": if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed) return self._process_embeds(parsed["content"])
if parsed["type"] == "tokens":
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ return self._process_tokens(
self._get_prompt_data(parsed) parsed["content"],
# If multimodal data is present, process and return immediately
if parsed["type"] != "str" and parsed["content"].get(
"multi_modal_data") is not None:
inputs = self._process_multimodal(
prompt_text if prompt_text is not None else prompt_token_ids,
parsed["content"]["multi_modal_data"],
parsed["content"].get("mm_processor_kwargs"),
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
if cache_salt is not None: if parsed["type"] == "text":
inputs["cache_salt"] = cache_salt return self._process_text(
return inputs parsed["content"],
if prompt_token_ids is None:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
) )
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
cache_salt=cache_salt,
)
assert_never(parsed) assert_never(parsed)
@ -375,79 +508,49 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> SingletonInputs: ) -> SingletonInputs:
"""Async version of :meth:`_extract_prompt_components`.""" """Async version of :meth:`_prompt_to_llm_inputs`."""
parsed = parse_singleton_prompt(prompt) parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds": if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed) return await self._process_embeds_async(parsed["content"])
if parsed["type"] == "tokens":
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ return await self._process_tokens_async(
self._get_prompt_data(parsed) parsed["content"],
if parsed["type"] != "str" and parsed["content"].get(
"multi_modal_data") is not None:
inputs = await self._process_multimodal_async(
prompt_token_ids if prompt_text is None else prompt_text,
parsed["content"]["multi_modal_data"],
parsed["content"].get("mm_processor_kwargs"),
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
if cache_salt is not None: if parsed["type"] == "text":
inputs["cache_salt"] = cache_salt return await self._process_text_async(
return inputs parsed["content"],
if prompt_token_ids is None:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
) )
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
cache_salt=cache_salt,
)
def _process_prompt_embeds(self,
parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
if envs.VLLM_USE_V1:
raise ValueError("prompt_embeds is only available in V0.")
prompt_embeds_content = parsed["content"]
prompt_embeds = prompt_embeds_content["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")
return embeds_inputs(prompt_embeds=prompt_embeds)
assert_never(parsed) assert_never(parsed)
def _build_enc_dec_llm_inputs( def _build_enc_dec_llm_inputs(
self, self,
encoder_inputs: Union[TokenInputs, MultiModalInputs], encoder_inputs: SingletonInputs,
decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]], decoder_inputs: Optional[SingletonInputs],
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
if (encoder_inputs["type"] == "token" if (encoder_inputs["type"] == "embeds"
or encoder_inputs["type"] == "multimodal"): or decoder_inputs and decoder_inputs["type"] == "embeds"):
pass raise ValueError("Embedding inputs are not supported for encoder-"
else: "decoder models")
assert_never(encoder_inputs) # type: ignore[arg-type]
# Mypy does not correctly infer that EmbedsInputs is impossible # Needed for mypy
assert "prompt_token_ids" in encoder_inputs encoder_inputs = cast(Union[TokenInputs, MultiModalInputs],
encoder_inputs)
decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]],
decoder_inputs)
if decoder_inputs is None: if decoder_inputs is None:
if self.model_config.hf_config.model_type == "whisper": if self.model_config.hf_config.model_type == "whisper":
@ -460,74 +563,78 @@ class InputPreprocessor:
dec_token_ids = self._prepare_decoder_input_ids_for_generation( dec_token_ids = self._prepare_decoder_input_ids_for_generation(
None) None)
decoder_inputs = token_inputs(dec_token_ids) decoder_inputs = token_inputs(dec_token_ids)
elif (decoder_inputs["type"] == "token" else:
or decoder_inputs["type"] == "multimodal"):
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"])
decoder_inputs["prompt_token_ids"] = dec_token_ids
if "multi_modal_data" in decoder_inputs: if "multi_modal_data" in decoder_inputs:
raise ValueError("Multi-modal decoder inputs of encoder-" raise ValueError("Multi-modal decoder inputs of encoder-"
"decoder models are not supported yet") "decoder models are not supported yet")
else:
assert_never(encoder_inputs) # type: ignore[arg-type] dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"])
decoder_inputs["prompt_token_ids"] = dec_token_ids
return EncoderDecoderInputs( return EncoderDecoderInputs(
encoder=encoder_inputs, encoder=encoder_inputs,
decoder=decoder_inputs, decoder=decoder_inputs,
) )
def _separate_enc_dec_inputs_from_mm_processor_outputs( def _split_enc_dec_mm_inputs(
self, self,
inputs: SingletonInputs, inputs: Union[SingletonInputs, MultiModalEncDecInputs],
decoder_inputs_to_override: Optional[Union[TokenInputs, decoder_inputs_to_override: Optional[SingletonInputs] = None,
MultiModalInputs]] = None,
) -> tuple[SingletonInputs, SingletonInputs]: ) -> tuple[SingletonInputs, SingletonInputs]:
""" """
For encoder/decoder models only: For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
""" """
if (inputs["type"] == "embeds" or decoder_inputs_to_override
and decoder_inputs_to_override["type"] == "embeds"):
raise ValueError("Embedding inputs are not supported for encoder-"
"decoder models")
# Needed for mypy
inputs = cast(
Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
inputs,
)
decoder_inputs_to_override = cast(
Optional[Union[TokenInputs, MultiModalInputs]],
decoder_inputs_to_override,
)
encoder_inputs: SingletonInputs encoder_inputs: SingletonInputs
decoder_inputs: SingletonInputs decoder_inputs: SingletonInputs
if inputs["type"] == "multimodal":
# Multimodal data inputs if inputs["type"] == "multimodal": # Multimodal data inputs
assert ("encoder_prompt" in inputs if not ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs) and "encoder_prompt_token_ids" in inputs):
raise RuntimeError("You should register an encoder-decoder "
"multi-modal processor for encoder-decoder "
"models.")
inputs = cast(MultiModalEncDecInputs, inputs) inputs = cast(MultiModalEncDecInputs, inputs)
encoder_inputs = token_inputs( encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"], prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"], prompt_token_ids=inputs["encoder_prompt_token_ids"],
) )
if decoder_inputs_to_override is not None:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_inputs_to_override.get("prompt", ""),
prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
cache_salt = inputs.get("cache_salt") decoder_prompt_inputs = decoder_inputs_to_override or inputs
if cache_salt is not None: decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_prompt_inputs.get("prompt", ""),
prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
if cache_salt := inputs.get("cache_salt"):
decoder_inputs["cache_salt"] = cache_salt decoder_inputs["cache_salt"] = cache_salt
elif inputs["type"] == "token": elif inputs["type"] == "token": # Text-only inputs
# Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
decoder_inputs = decoder_inputs_to_override or inputs decoder_inputs = decoder_inputs_to_override or inputs
else: else:
assert_never(inputs) # type: ignore[arg-type] assert_never(inputs) # type: ignore[arg-type]
return encoder_inputs, decoder_inputs return encoder_inputs, decoder_inputs
def _process_encoder_decoder_prompt( def _process_encoder_decoder_prompt(
@ -580,11 +687,9 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor # For multimodal model, override decoder prompt from processor
# with explicit decoder prompt. # with explicit decoder prompt.
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = ( encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs( self._split_enc_dec_mm_inputs(encoder_inputs,
encoder_inputs, decoder_inputs)) decoder_inputs))
else: else:
inputs = self._prompt_to_llm_inputs( inputs = self._prompt_to_llm_inputs(
prompt, prompt,
@ -593,16 +698,11 @@ class InputPreprocessor:
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = ( encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs( self._split_enc_dec_mm_inputs(inputs))
inputs))
else: else:
encoder_inputs = inputs encoder_inputs = inputs
decoder_inputs = None decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
async def _process_encoder_decoder_prompt_async( async def _process_encoder_decoder_prompt_async(
@ -635,11 +735,9 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor # For multimodal model, override decoder prompt from processor
# with explicit decoder prompt. # with explicit decoder prompt.
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = ( encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs( self._split_enc_dec_mm_inputs(encoder_inputs,
encoder_inputs, decoder_inputs)) decoder_inputs))
else: else:
inputs = await self._prompt_to_llm_inputs_async( inputs = await self._prompt_to_llm_inputs_async(
prompt, prompt,
@ -648,16 +746,11 @@ class InputPreprocessor:
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = ( encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs( self._split_enc_dec_mm_inputs(inputs))
inputs))
else: else:
encoder_inputs = inputs encoder_inputs = inputs
decoder_inputs = None decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
def _build_decoder_only_llm_inputs( def _build_decoder_only_llm_inputs(
@ -665,19 +758,13 @@ class InputPreprocessor:
prompt_inputs: DecoderOnlyInputs, prompt_inputs: DecoderOnlyInputs,
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
if (prompt_inputs["type"] == "token" if "prompt_token_ids" in prompt_inputs:
or prompt_inputs["type"] == "multimodal"): prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
# Mypy does not do type inference well with typedicts and Literal prompt_inputs) # Needed for mypy
# values
assert not is_embeds_inputs(prompt_inputs)
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"], prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
elif (prompt_inputs["type"] == "embeds"):
pass
else:
assert_never(prompt_inputs) # type: ignore[arg-type]
return prompt_inputs return prompt_inputs

View File

@ -1670,15 +1670,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
placeholders = mm_placeholders.get(modality, []) placeholders = mm_placeholders.get(modality, [])
if len(placeholders) != item_count: if len(placeholders) != item_count:
# NOTE: If you are a model developer, this can also arise from
# an inconsistency between `_call_hf_processor` and
# `_get_mm_fields_config` implementations
raise RuntimeError( raise RuntimeError(
f"Expected there to be {item_count} prompt updates " f"Expected there to be {item_count} prompt updates "
f"corresponding to {item_count} {modality} items, but " f"corresponding to {item_count} {modality} items, but "
f"instead found {len(placeholders)} prompt updates! " f"instead found {len(placeholders)} prompt updates! "
"Either the prompt text has missing/incorrect tokens for " "This is likely because you forgot to include input "
"multi-modal inputs, or there is a problem with your " "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
"implementation of merged multi-modal processor for this " "in the prompt. If the model has a chat template, make "
"model (usually arising from an inconsistency between " "sure you have applied it before calling `LLM.generate`.")
"`_call_hf_processor` and `_get_prompt_updates`).")
def _maybe_apply_prompt_updates( def _maybe_apply_prompt_updates(
self, self,