[Misc] Clean up input processing (#17582)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-05-02 23:11:53 +08:00 committed by GitHub
parent 3a500cd0b6
commit cb234955df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 359 additions and 285 deletions

View File

@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import ImageTestAssets
@ -14,6 +15,7 @@ from ....conftest import ImageTestAssets
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
@torch.inference_mode()
def run_intern_vit_test(
image_assets: ImageTestAssets,
model_id: str,
@ -21,11 +23,12 @@ def run_intern_vit_test(
dtype: str,
):
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
img_processor = CLIPImageProcessor.from_pretrained(model)
images = [asset.pil_image for asset in image_assets]
pixel_values = [
img_processor(images, return_tensors='pt').pixel_values.to(dtype)
img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype)
for images in images
]
@ -34,7 +37,7 @@ def run_intern_vit_test(
config.norm_type = "rms_norm"
hf_model = AutoModel.from_pretrained(model,
torch_dtype=dtype,
torch_dtype=torch_dtype,
trust_remote_code=True).to("cuda")
hf_outputs_per_image = [
hf_model(pixel_value.to("cuda")).last_hidden_state
@ -48,7 +51,7 @@ def run_intern_vit_test(
del hf_model
cleanup_dist_env_and_memory()
vllm_model = vllm_model.to("cuda", dtype)
vllm_model = vllm_model.to("cuda", torch_dtype)
vllm_outputs_per_image = [
vllm_model(pixel_values=pixel_value.to("cuda"))
for pixel_value in pixel_values
@ -66,9 +69,8 @@ def run_intern_vit_test(
"OpenGVLab/InternViT-300M-448px",
"OpenGVLab/InternViT-6B-448px-V1-5",
])
@pytest.mark.parametrize("dtype", [torch.half])
@torch.inference_mode()
def test_models(image_assets, model_id, dtype: str) -> None:
@pytest.mark.parametrize("dtype", ["half"])
def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
run_intern_vit_test(
image_assets,
model_id,

View File

@ -497,10 +497,6 @@ class _AsyncLLMEngine(LLMEngine):
prompt["prompt_token_ids"] = [0
] * prompt["prompt_embeds"].shape[-2]
if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,

View File

@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
@ -759,11 +759,6 @@ class LLMEngine:
seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len
if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))
processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
@ -782,27 +777,6 @@ class LLMEngine:
priority=priority,
)
def _validate_token_prompt(self, prompt: PromptType,
tokenizer: AnyTokenizer):
# Guard against out-of-vocab tokens.
# For some tokenizers, tokenizer.decode will happily return empty text
# for token ids that are out of vocab, and we don't detect token ids
# that are greater than the max token id before running the model.
# However, these token ids will later crash a cuda kernel at runtime
# with an index out of bounds error. This will crash the entire engine.
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
return
max_input_id = max(prompt_ids)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
def _create_sequence_group_with_sampling(
self,
request_id: str,
@ -2049,6 +2023,12 @@ class LLMEngine:
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
if tokenizer is not None:
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:

View File

@ -83,6 +83,9 @@ class EngineClient(ABC):
else:
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
if processed_inputs["type"] == "embeds":
raise NotImplementedError
prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data")

View File

@ -27,7 +27,7 @@ from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
@ -567,10 +567,12 @@ class LLM:
mm_kwargs["mm_processor_kwargs"] = prompt[
"mm_processor_kwargs"]
if is_token_prompt(prompt):
if "prompt_token_ids" in prompt:
prompt = cast(TokensPrompt, prompt) # Needed for mypy
prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append(
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

View File

@ -70,6 +70,11 @@ class EmbedsPrompt(TypedDict):
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
@ -195,13 +200,21 @@ class EmbedsInputs(TypedDict):
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:
def embeds_inputs(
prompt_embeds: torch.Tensor,
cache_salt: Optional[str] = None,
) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values."""
inputs = EmbedsInputs(
type="embeds",
prompt_embeds=prompt_embeds,
)
inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs

View File

@ -6,9 +6,9 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokensPrompt)
from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
TokensPrompt)
class ParsedText(TypedDict):
@ -90,6 +90,10 @@ class ParsedEmbedsPrompt(TypedDict):
content: EmbedsPrompt
ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, ParsedEmbedsPrompt]
@overload
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
...
@ -110,10 +114,7 @@ def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
...
def parse_singleton_prompt(
prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
ParsedEmbedsPrompt]:
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict):
@ -131,23 +132,11 @@ def parse_singleton_prompt(
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
return isinstance(prompt, dict) and "prompt_embeds" in prompt
def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
return isinstance(inputs, dict) and inputs["type"] == "embeds"
def split_enc_dec_inputs(
inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]:

View File

@ -14,14 +14,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, is_embeds_inputs,
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, embeds_inputs, token_inputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
logger = init_logger(__name__)
@ -140,13 +140,10 @@ class InputPreprocessor:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
Based on
https://github.com/huggingface/transformers/blob/
4037a2b5b1278736e566aec12e169100275545ea/
src/transformers/generation/utils.py
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
Based on:
https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
specifically,
`GenerationMixin._prepare_decoder_input_ids_for_generation()`.
Arguments:
@ -183,6 +180,23 @@ class InputPreprocessor:
return prompt_token_ids
def _get_tokenization_kw(
self,
overrides: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
kwargs = dict[str, Any]()
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
kwargs["add_special_tokens"] = False
if overrides:
kwargs.update(overrides)
return kwargs
def _tokenize_prompt(
self,
prompt: str,
@ -194,18 +208,11 @@ class InputPreprocessor:
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()
if tokenization_kwargs is None:
tokenization_kwargs = {}
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
tokenization_kwargs["add_special_tokens"] = False
encoder_config = self.model_config.encoder_config
if (self.model_config.encoder_config is not None
and self.model_config.encoder_config.get(
"do_lower_case", False)):
if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower()
return tokenizer.encode(prompt=prompt,
@ -220,18 +227,36 @@ class InputPreprocessor:
) -> list[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group()
if tokenization_kwargs is None:
tokenization_kwargs = {}
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
tokenization_kwargs["add_special_tokens"] = False
return await tokenizer.encode_async(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
def _get_mm_tokenizer(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return tokenizer_group.get_lora_tokenizer(lora_request)
async def _get_mm_tokenizer_async(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return await tokenizer_group.get_lora_tokenizer_async(lora_request)
def _process_multimodal(
self,
prompt: Union[str, list[int]],
@ -244,13 +269,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
# At the moment on model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
tokenizer = self._get_mm_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
@ -270,14 +289,7 @@ class InputPreprocessor:
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""Async version of :meth:`_process_multimodal`."""
# At the moment on model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
lora_request)
tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
@ -287,28 +299,160 @@ class InputPreprocessor:
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
return_mm_hashes)
def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt,
ParsedTextPrompt,
ParsedTokensPrompt]):
prompt_text = None
prompt_token_ids = None
token_type_ids = None
cache_salt = None
def _process_embeds(
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
if envs.VLLM_USE_V1:
raise ValueError("prompt_embeds is only available in V0.")
if parsed_prompt["type"] == "str":
prompt_text = parsed_prompt["content"]
prompt_embeds = parsed_content["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")
return embeds_inputs(prompt_embeds=prompt_embeds,
cache_salt=parsed_content.get("cache_salt"))
async def _process_embeds_async(
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
return self._process_embeds(parsed_content)
def _process_tokens(
self,
parsed_content: TokensPrompt,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
cache_salt = parsed_prompt["content"].get("cache_salt")
if parsed_prompt["type"] == "text":
prompt_text = parsed_prompt["content"]["prompt"]
elif parsed_prompt["type"] == "tokens":
prompt_token_ids = parsed_prompt["content"].get(
"prompt_token_ids")
token_type_ids = parsed_prompt["content"].get("token_type_ids")
else:
assert_never(parsed_prompt)
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
return prompt_text, prompt_token_ids, token_type_ids, cache_salt
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
async def _process_tokens_async(
self,
parsed_content: TokensPrompt,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _process_text(
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
async def _process_text_async(
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = await self._process_multimodal_async(
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _prompt_to_llm_inputs(
self,
@ -333,38 +477,27 @@ class InputPreprocessor:
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
# If multimodal data is present, process and return immediately
if parsed["type"] != "str" and parsed["content"].get(
"multi_modal_data") is not None:
inputs = self._process_multimodal(
prompt_text if prompt_text is not None else prompt_token_ids,
parsed["content"]["multi_modal_data"],
parsed["content"].get("mm_processor_kwargs"),
return self._process_embeds(parsed["content"])
if parsed["type"] == "tokens":
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs
if prompt_token_ids is None:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
cache_salt=cache_salt,
)
assert_never(parsed)
@ -375,79 +508,49 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> SingletonInputs:
"""Async version of :meth:`_extract_prompt_components`."""
"""Async version of :meth:`_prompt_to_llm_inputs`."""
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
if parsed["type"] != "str" and parsed["content"].get(
"multi_modal_data") is not None:
inputs = await self._process_multimodal_async(
prompt_token_ids if prompt_text is None else prompt_text,
parsed["content"]["multi_modal_data"],
parsed["content"].get("mm_processor_kwargs"),
return await self._process_embeds_async(parsed["content"])
if parsed["type"] == "tokens":
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs
if prompt_token_ids is None:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
cache_salt=cache_salt,
)
def _process_prompt_embeds(self,
parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
if envs.VLLM_USE_V1:
raise ValueError("prompt_embeds is only available in V0.")
prompt_embeds_content = parsed["content"]
prompt_embeds = prompt_embeds_content["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")
return embeds_inputs(prompt_embeds=prompt_embeds)
assert_never(parsed)
def _build_enc_dec_llm_inputs(
self,
encoder_inputs: Union[TokenInputs, MultiModalInputs],
decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
encoder_inputs: SingletonInputs,
decoder_inputs: Optional[SingletonInputs],
) -> EncoderDecoderInputs:
if (encoder_inputs["type"] == "token"
or encoder_inputs["type"] == "multimodal"):
pass
else:
assert_never(encoder_inputs) # type: ignore[arg-type]
if (encoder_inputs["type"] == "embeds"
or decoder_inputs and decoder_inputs["type"] == "embeds"):
raise ValueError("Embedding inputs are not supported for encoder-"
"decoder models")
# Mypy does not correctly infer that EmbedsInputs is impossible
assert "prompt_token_ids" in encoder_inputs
# Needed for mypy
encoder_inputs = cast(Union[TokenInputs, MultiModalInputs],
encoder_inputs)
decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]],
decoder_inputs)
if decoder_inputs is None:
if self.model_config.hf_config.model_type == "whisper":
@ -460,74 +563,78 @@ class InputPreprocessor:
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
None)
decoder_inputs = token_inputs(dec_token_ids)
elif (decoder_inputs["type"] == "token"
or decoder_inputs["type"] == "multimodal"):
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"])
decoder_inputs["prompt_token_ids"] = dec_token_ids
else:
if "multi_modal_data" in decoder_inputs:
raise ValueError("Multi-modal decoder inputs of encoder-"
"decoder models are not supported yet")
else:
assert_never(encoder_inputs) # type: ignore[arg-type]
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"])
decoder_inputs["prompt_token_ids"] = dec_token_ids
return EncoderDecoderInputs(
encoder=encoder_inputs,
decoder=decoder_inputs,
)
def _separate_enc_dec_inputs_from_mm_processor_outputs(
def _split_enc_dec_mm_inputs(
self,
inputs: SingletonInputs,
decoder_inputs_to_override: Optional[Union[TokenInputs,
MultiModalInputs]] = None,
inputs: Union[SingletonInputs, MultiModalEncDecInputs],
decoder_inputs_to_override: Optional[SingletonInputs] = None,
) -> tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
if (inputs["type"] == "embeds" or decoder_inputs_to_override
and decoder_inputs_to_override["type"] == "embeds"):
raise ValueError("Embedding inputs are not supported for encoder-"
"decoder models")
# Needed for mypy
inputs = cast(
Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
inputs,
)
decoder_inputs_to_override = cast(
Optional[Union[TokenInputs, MultiModalInputs]],
decoder_inputs_to_override,
)
encoder_inputs: SingletonInputs
decoder_inputs: SingletonInputs
if inputs["type"] == "multimodal":
# Multimodal data inputs
assert ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs)
if inputs["type"] == "multimodal": # Multimodal data inputs
if not ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs):
raise RuntimeError("You should register an encoder-decoder "
"multi-modal processor for encoder-decoder "
"models.")
inputs = cast(MultiModalEncDecInputs, inputs)
encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"],
)
if decoder_inputs_to_override is not None:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_inputs_to_override.get("prompt", ""),
prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
cache_salt = inputs.get("cache_salt")
if cache_salt is not None:
decoder_prompt_inputs = decoder_inputs_to_override or inputs
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_prompt_inputs.get("prompt", ""),
prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
if cache_salt := inputs.get("cache_salt"):
decoder_inputs["cache_salt"] = cache_salt
elif inputs["type"] == "token":
# Text-only inputs
elif inputs["type"] == "token": # Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
decoder_inputs = decoder_inputs_to_override or inputs
else:
assert_never(inputs) # type: ignore[arg-type]
return encoder_inputs, decoder_inputs
def _process_encoder_decoder_prompt(
@ -580,11 +687,9 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
self._split_enc_dec_mm_inputs(encoder_inputs,
decoder_inputs))
else:
inputs = self._prompt_to_llm_inputs(
prompt,
@ -593,16 +698,11 @@ class InputPreprocessor:
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
self._split_enc_dec_mm_inputs(inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
async def _process_encoder_decoder_prompt_async(
@ -635,11 +735,9 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
self._split_enc_dec_mm_inputs(encoder_inputs,
decoder_inputs))
else:
inputs = await self._prompt_to_llm_inputs_async(
prompt,
@ -648,16 +746,11 @@ class InputPreprocessor:
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
self._split_enc_dec_mm_inputs(inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
def _build_decoder_only_llm_inputs(
@ -665,19 +758,13 @@ class InputPreprocessor:
prompt_inputs: DecoderOnlyInputs,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> DecoderOnlyInputs:
if (prompt_inputs["type"] == "token"
or prompt_inputs["type"] == "multimodal"):
# Mypy does not do type inference well with typedicts and Literal
# values
assert not is_embeds_inputs(prompt_inputs)
if "prompt_token_ids" in prompt_inputs:
prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
prompt_inputs) # Needed for mypy
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request,
)
elif (prompt_inputs["type"] == "embeds"):
pass
else:
assert_never(prompt_inputs) # type: ignore[arg-type]
return prompt_inputs

View File

@ -1670,15 +1670,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
placeholders = mm_placeholders.get(modality, [])
if len(placeholders) != item_count:
# NOTE: If you are a model developer, this can also arise from
# an inconsistency between `_call_hf_processor` and
# `_get_mm_fields_config` implementations
raise RuntimeError(
f"Expected there to be {item_count} prompt updates "
f"corresponding to {item_count} {modality} items, but "
f"instead found {len(placeholders)} prompt updates! "
"Either the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_updates`).")
"This is likely because you forgot to include input "
"placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
"in the prompt. If the model has a chat template, make "
"sure you have applied it before calling `LLM.generate`.")
def _maybe_apply_prompt_updates(
self,