[Misc] Clean up input processing (#17582)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Parent: 3a500cd0b6
Commit: cb234955df
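The cleanup replaces the single `_get_prompt_data` path in `InputPreprocessor` with one handler per prompt type. A minimal standalone sketch of the new dispatch, simplified from the diff below (argument plumbing such as `tokenization_kwargs`, `lora_request` and `return_mm_hashes` is omitted; the `preprocessor` argument stands in for an `InputPreprocessor` instance):

from typing import Any, Union

from vllm.inputs import TextPrompt
from vllm.inputs.parse import parse_singleton_prompt

def prompt_to_llm_inputs_sketch(preprocessor: Any,
                                prompt: Union[str, dict]) -> Any:
    # Parse once, then dispatch on the parsed prompt type, mirroring the
    # new _process_embeds / _process_tokens / _process_text helpers.
    parsed = parse_singleton_prompt(prompt)

    if parsed["type"] == "embeds":
        return preprocessor._process_embeds(parsed["content"])
    if parsed["type"] == "tokens":
        return preprocessor._process_tokens(parsed["content"])
    if parsed["type"] == "text":
        return preprocessor._process_text(parsed["content"])
    if parsed["type"] == "str":
        # Bare strings are wrapped into a TextPrompt and share the text path.
        return preprocessor._process_text(TextPrompt(prompt=parsed["content"]))

    raise AssertionError(f"unreachable prompt type: {parsed['type']}")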
@@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

from vllm.distributed import cleanup_dist_env_and_memory
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from ....conftest import ImageTestAssets

@@ -14,6 +15,7 @@ from ....conftest import ImageTestAssets
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]


@torch.inference_mode()
def run_intern_vit_test(
image_assets: ImageTestAssets,
model_id: str,

@@ -21,11 +23,12 @@ def run_intern_vit_test(
dtype: str,
):
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

img_processor = CLIPImageProcessor.from_pretrained(model)
images = [asset.pil_image for asset in image_assets]
pixel_values = [
img_processor(images, return_tensors='pt').pixel_values.to(dtype)
img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype)
for images in images
]

@@ -34,7 +37,7 @@ def run_intern_vit_test(
config.norm_type = "rms_norm"

hf_model = AutoModel.from_pretrained(model,
torch_dtype=dtype,
torch_dtype=torch_dtype,
trust_remote_code=True).to("cuda")
hf_outputs_per_image = [
hf_model(pixel_value.to("cuda")).last_hidden_state

@@ -48,7 +51,7 @@ def run_intern_vit_test(
del hf_model
cleanup_dist_env_and_memory()

vllm_model = vllm_model.to("cuda", dtype)
vllm_model = vllm_model.to("cuda", torch_dtype)
vllm_outputs_per_image = [
vllm_model(pixel_values=pixel_value.to("cuda"))
for pixel_value in pixel_values

@@ -66,9 +69,8 @@ def run_intern_vit_test(
"OpenGVLab/InternViT-300M-448px",
"OpenGVLab/InternViT-6B-448px-V1-5",
])
@pytest.mark.parametrize("dtype", [torch.half])
@torch.inference_mode()
def test_models(image_assets, model_id, dtype: str) -> None:
@pytest.mark.parametrize("dtype", ["half"])
def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
run_intern_vit_test(
image_assets,
model_id,
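In the test above, the `dtype` parameter is now the string "half", which is mapped to a torch dtype once via `STR_DTYPE_TO_TORCH_DTYPE` before being passed to the processors and models. A small illustrative sketch (standalone, not part of the diff):

import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

# The test parametrizes on the string "half" and converts it once, so both
# the HF reference model and the vLLM module receive a torch.dtype.
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE["half"]
assert torch_dtype is torch.float16  # expected mapping for "half"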
@@ -497,10 +497,6 @@ class _AsyncLLMEngine(LLMEngine):
prompt["prompt_token_ids"] = [0
] * prompt["prompt_embeds"].shape[-2]

if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)

processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,

@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors

@@ -759,11 +759,6 @@ class LLMEngine:
seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len

if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))

processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,

@@ -782,27 +777,6 @@ class LLMEngine:
priority=priority,
)

def _validate_token_prompt(self, prompt: PromptType,
tokenizer: AnyTokenizer):
# Guard against out-of-vocab tokens.
# For some tokenizers, tokenizer.decode will happily return empty text
# for token ids that are out of vocab, and we don't detect token ids
# that are greater than the max token id before running the model.
# However, these token ids will later crash a cuda kernel at runtime
# with an index out of bounds error. This will crash the entire engine.
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
return
max_input_id = max(prompt_ids)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))

def _create_sequence_group_with_sampling(
self,
request_id: str,

@@ -2049,6 +2023,12 @@ class LLMEngine:
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")

if tokenizer is not None:
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
f"Token id {max_input_id} is out of vocabulary")

max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:
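The removed `_validate_token_prompt` and the new check above guard the same condition: token ids beyond the tokenizer's vocabulary decode silently on some tokenizers but later crash the CUDA embedding lookup, so they are rejected up front. A standalone sketch of the guard (function and argument names are illustrative):

def check_token_ids_in_vocab(prompt_ids: list[int], max_token_id: int) -> None:
    # Reject out-of-vocab ids before scheduling instead of crashing at runtime.
    max_input_id = max(prompt_ids, default=0)
    if max_input_id > max_token_id:
        raise ValueError(f"Token id {max_input_id} is out of vocabulary")

# Example: a vocabulary of size 32000 has max_token_id == 31999.
check_token_ids_in_vocab([1, 2, 31999], max_token_id=31999)   # passes
# check_token_ids_in_vocab([1, 64000], max_token_id=31999)    # raises ValueError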
@@ -83,6 +83,9 @@ class EngineClient(ABC):
else:
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)

if processed_inputs["type"] == "embeds":
raise NotImplementedError

prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data")

@@ -27,7 +27,7 @@ from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (

@@ -567,10 +567,12 @@ class LLM:
mm_kwargs["mm_processor_kwargs"] = prompt[
"mm_processor_kwargs"]

if is_token_prompt(prompt):
if "prompt_token_ids" in prompt:
prompt = cast(TokensPrompt, prompt) # Needed for mypy
prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])

instances.append(
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))


@@ -70,6 +70,11 @@ class EmbedsPrompt(TypedDict):
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""

cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""


SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""

@@ -195,13 +200,21 @@ class EmbedsInputs(TypedDict):
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""

cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""

def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:

def embeds_inputs(
    prompt_embeds: torch.Tensor,
    cache_salt: Optional[str] = None,
) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values."""
inputs = EmbedsInputs(
type="embeds",
prompt_embeds=prompt_embeds,
)
inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

if cache_salt is not None:
inputs["cache_salt"] = cache_salt

return inputs
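A hedged usage sketch of the updated `embeds_inputs` constructor introduced above (the tensor shape and salt value are illustrative; `cache_salt` is only attached to the result when provided):

import torch
from vllm.inputs.data import embeds_inputs  # location as of this commit

# (seq_len, hidden_size) prompt embeddings.
prompt_embeds = torch.zeros(8, 4096)

inputs = embeds_inputs(prompt_embeds=prompt_embeds, cache_salt="my-salt")
assert inputs["type"] == "embeds" and inputs["cache_salt"] == "my-salt"

inputs_no_salt = embeds_inputs(prompt_embeds=prompt_embeds)
assert "cache_salt" not in inputs_no_salt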
@@ -6,9 +6,9 @@ from typing_extensions import TypeIs

from vllm.utils import is_list_of

from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokensPrompt)
from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
TokensPrompt)


class ParsedText(TypedDict):

@@ -90,6 +90,10 @@ class ParsedEmbedsPrompt(TypedDict):
content: EmbedsPrompt


ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, ParsedEmbedsPrompt]


@overload
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
...

@@ -110,10 +114,7 @@ def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
...


def parse_singleton_prompt(
prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
ParsedEmbedsPrompt]:
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict):

@@ -131,23 +132,11 @@ def parse_singleton_prompt(
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")


def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt


def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
return isinstance(prompt, dict) and "prompt_embeds" in prompt


def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt


def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
return isinstance(inputs, dict) and inputs["type"] == "embeds"


def split_enc_dec_inputs(
inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]:

@@ -14,14 +14,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, is_embeds_inputs,
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, embeds_inputs, token_inputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt

logger = init_logger(__name__)

@@ -140,13 +140,10 @@ class InputPreprocessor:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.

Based on

https://github.com/huggingface/transformers/blob/
4037a2b5b1278736e566aec12e169100275545ea/
src/transformers/generation/utils.py

specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
Based on:
https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
specifically,
`GenerationMixin._prepare_decoder_input_ids_for_generation()`.

Arguments:

@@ -183,6 +180,23 @@ class InputPreprocessor:

return prompt_token_ids

def _get_tokenization_kw(
self,
overrides: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
kwargs = dict[str, Any]()

if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
kwargs["add_special_tokens"] = False

if overrides:
kwargs.update(overrides)

return kwargs

def _tokenize_prompt(
self,
prompt: str,
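The new `_get_tokenization_kw` helper centralizes the Whisper special-token handling that was previously duplicated in the sync and async tokenize paths. A small standalone sketch of the merge behaviour (the model-type check is replaced by a boolean purely for illustration):

from typing import Any, Optional

def get_tokenization_kw_sketch(is_whisper: bool,
                               overrides: Optional[dict[str, Any]] = None) -> dict[str, Any]:
    kwargs: dict[str, Any] = {}
    if is_whisper:
        # Whisper prompts carry their own task/language special tokens, and an
        # appended EOS would disrupt generation, so skip special tokens here.
        kwargs["add_special_tokens"] = False
    if overrides:
        # Caller-supplied tokenization kwargs win over the defaults above.
        kwargs.update(overrides)
    return kwargs

assert get_tokenization_kw_sketch(True) == {"add_special_tokens": False}
assert get_tokenization_kw_sketch(True, {"add_special_tokens": True}) == {"add_special_tokens": True}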
@@ -194,18 +208,11 @@ class InputPreprocessor:
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()
if tokenization_kwargs is None:
tokenization_kwargs = {}
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
tokenization_kwargs["add_special_tokens"] = False
encoder_config = self.model_config.encoder_config

if (self.model_config.encoder_config is not None
and self.model_config.encoder_config.get(
"do_lower_case", False)):
if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower()

return tokenizer.encode(prompt=prompt,

@@ -220,18 +227,36 @@ class InputPreprocessor:
) -> list[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group()
if tokenization_kwargs is None:
tokenization_kwargs = {}
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
tokenization_kwargs["add_special_tokens"] = False
return await tokenizer.encode_async(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)

def _get_mm_tokenizer(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy

tokenizer_group = self.get_tokenizer_group()
return tokenizer_group.get_lora_tokenizer(lora_request)

async def _get_mm_tokenizer_async(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy

tokenizer_group = self.get_tokenizer_group()
return await tokenizer_group.get_lora_tokenizer_async(lora_request)

def _process_multimodal(
self,
prompt: Union[str, list[int]],

@@ -244,13 +269,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
# At the moment on model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
tokenizer = self._get_mm_tokenizer(lora_request)

mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)

@@ -270,14 +289,7 @@ class InputPreprocessor:
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""Async version of :meth:`_process_multimodal`."""
# At the moment on model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
lora_request)
tokenizer = await self._get_mm_tokenizer_async(lora_request)

mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
@@ -287,28 +299,160 @@ class InputPreprocessor:
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
return_mm_hashes)

def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt,
ParsedTextPrompt,
ParsedTokensPrompt]):
prompt_text = None
prompt_token_ids = None
token_type_ids = None
cache_salt = None
def _process_embeds(
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
if envs.VLLM_USE_V1:
raise ValueError("prompt_embeds is only available in V0.")

if parsed_prompt["type"] == "str":
prompt_text = parsed_prompt["content"]
prompt_embeds = parsed_content["prompt_embeds"]

# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)

if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")

return embeds_inputs(prompt_embeds=prompt_embeds,
cache_salt=parsed_content.get("cache_salt"))

async def _process_embeds_async(
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
return self._process_embeds(parsed_content)

def _process_tokens(
self,
parsed_content: TokensPrompt,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")

inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
cache_salt = parsed_prompt["content"].get("cache_salt")
if parsed_prompt["type"] == "text":
prompt_text = parsed_prompt["content"]["prompt"]
elif parsed_prompt["type"] == "tokens":
prompt_token_ids = parsed_prompt["content"].get(
"prompt_token_ids")
token_type_ids = parsed_prompt["content"].get("token_type_ids")
else:
assert_never(parsed_prompt)
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)

return prompt_text, prompt_token_ids, token_type_ids, cache_salt
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt

return inputs

async def _process_tokens_async(
self,
parsed_content: TokensPrompt,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")

inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)

if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt

return inputs

def _process_text(
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]

inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)

if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt

return inputs

async def _process_text_async(
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]

inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = await self._process_multimodal_async(
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)

if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt

return inputs

def _prompt_to_llm_inputs(
self,
@@ -333,38 +477,27 @@ class InputPreprocessor:
parsed = parse_singleton_prompt(prompt)

if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)

prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)

# If multimodal data is present, process and return immediately
if parsed["type"] != "str" and parsed["content"].get(
"multi_modal_data") is not None:
inputs = self._process_multimodal(
prompt_text if prompt_text is not None else prompt_token_ids,
parsed["content"]["multi_modal_data"],
parsed["content"].get("mm_processor_kwargs"),
return self._process_embeds(parsed["content"])
if parsed["type"] == "tokens":
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs

if prompt_token_ids is None:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)

return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
cache_salt=cache_salt,
)

assert_never(parsed)
@@ -375,79 +508,49 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> SingletonInputs:
"""Async version of :meth:`_extract_prompt_components`."""
"""Async version of :meth:`_prompt_to_llm_inputs`."""
parsed = parse_singleton_prompt(prompt)

if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)

prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)

if parsed["type"] != "str" and parsed["content"].get(
"multi_modal_data") is not None:
inputs = await self._process_multimodal_async(
prompt_token_ids if prompt_text is None else prompt_text,
parsed["content"]["multi_modal_data"],
parsed["content"].get("mm_processor_kwargs"),
return await self._process_embeds_async(parsed["content"])
if parsed["type"] == "tokens":
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs

if prompt_token_ids is None:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)

return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
cache_salt=cache_salt,
)

def _process_prompt_embeds(self,
parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
if envs.VLLM_USE_V1:
raise ValueError("prompt_embeds is only available in V0.")

prompt_embeds_content = parsed["content"]

prompt_embeds = prompt_embeds_content["prompt_embeds"]

# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
prompt_embeds = prompt_embeds.squeeze(dim=0)

if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")

return embeds_inputs(prompt_embeds=prompt_embeds)

assert_never(parsed)

def _build_enc_dec_llm_inputs(
self,
encoder_inputs: Union[TokenInputs, MultiModalInputs],
decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
encoder_inputs: SingletonInputs,
decoder_inputs: Optional[SingletonInputs],
) -> EncoderDecoderInputs:
if (encoder_inputs["type"] == "token"
or encoder_inputs["type"] == "multimodal"):
pass
else:
assert_never(encoder_inputs) # type: ignore[arg-type]
if (encoder_inputs["type"] == "embeds"
or decoder_inputs and decoder_inputs["type"] == "embeds"):
raise ValueError("Embedding inputs are not supported for encoder-"
"decoder models")

# Mypy does not correctly infer that EmbedsInputs is impossible
assert "prompt_token_ids" in encoder_inputs
# Needed for mypy
encoder_inputs = cast(Union[TokenInputs, MultiModalInputs],
encoder_inputs)
decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]],
decoder_inputs)

if decoder_inputs is None:
if self.model_config.hf_config.model_type == "whisper":
@@ -460,74 +563,78 @@ class InputPreprocessor:
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
None)
decoder_inputs = token_inputs(dec_token_ids)
elif (decoder_inputs["type"] == "token"
or decoder_inputs["type"] == "multimodal"):
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"])
decoder_inputs["prompt_token_ids"] = dec_token_ids

else:
if "multi_modal_data" in decoder_inputs:
raise ValueError("Multi-modal decoder inputs of encoder-"
"decoder models are not supported yet")
else:
assert_never(encoder_inputs) # type: ignore[arg-type]

dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"])
decoder_inputs["prompt_token_ids"] = dec_token_ids

return EncoderDecoderInputs(
encoder=encoder_inputs,
decoder=decoder_inputs,
)

def _separate_enc_dec_inputs_from_mm_processor_outputs(
def _split_enc_dec_mm_inputs(
self,
inputs: SingletonInputs,
decoder_inputs_to_override: Optional[Union[TokenInputs,
MultiModalInputs]] = None,
inputs: Union[SingletonInputs, MultiModalEncDecInputs],
decoder_inputs_to_override: Optional[SingletonInputs] = None,
) -> tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
if (inputs["type"] == "embeds" or decoder_inputs_to_override
and decoder_inputs_to_override["type"] == "embeds"):
raise ValueError("Embedding inputs are not supported for encoder-"
"decoder models")

# Needed for mypy
inputs = cast(
Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
inputs,
)
decoder_inputs_to_override = cast(
Optional[Union[TokenInputs, MultiModalInputs]],
decoder_inputs_to_override,
)

encoder_inputs: SingletonInputs
decoder_inputs: SingletonInputs
if inputs["type"] == "multimodal":
# Multimodal data inputs
assert ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs)

if inputs["type"] == "multimodal": # Multimodal data inputs
if not ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs):
raise RuntimeError("You should register an encoder-decoder "
"multi-modal processor for encoder-decoder "
"models.")
inputs = cast(MultiModalEncDecInputs, inputs)

encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"],
)
if decoder_inputs_to_override is not None:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_inputs_to_override.get("prompt", ""),
prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)

cache_salt = inputs.get("cache_salt")
if cache_salt is not None:
decoder_prompt_inputs = decoder_inputs_to_override or inputs
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_prompt_inputs.get("prompt", ""),
prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
if cache_salt := inputs.get("cache_salt"):
decoder_inputs["cache_salt"] = cache_salt

elif inputs["type"] == "token":
# Text-only inputs
elif inputs["type"] == "token": # Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
decoder_inputs = decoder_inputs_to_override or inputs
else:
assert_never(inputs) # type: ignore[arg-type]

return encoder_inputs, decoder_inputs

def _process_encoder_decoder_prompt(
@@ -580,11 +687,9 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
self._split_enc_dec_mm_inputs(encoder_inputs,
decoder_inputs))
else:
inputs = self._prompt_to_llm_inputs(
prompt,

@@ -593,16 +698,11 @@ class InputPreprocessor:
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
self._split_enc_dec_mm_inputs(inputs))
else:
encoder_inputs = inputs
decoder_inputs = None

# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

async def _process_encoder_decoder_prompt_async(

@@ -635,11 +735,9 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
self._split_enc_dec_mm_inputs(encoder_inputs,
decoder_inputs))
else:
inputs = await self._prompt_to_llm_inputs_async(
prompt,

@@ -648,16 +746,11 @@ class InputPreprocessor:
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
self._split_enc_dec_mm_inputs(inputs))
else:
encoder_inputs = inputs
decoder_inputs = None

# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

def _build_decoder_only_llm_inputs(

@@ -665,19 +758,13 @@ class InputPreprocessor:
prompt_inputs: DecoderOnlyInputs,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> DecoderOnlyInputs:
if (prompt_inputs["type"] == "token"
or prompt_inputs["type"] == "multimodal"):
# Mypy does not do type inference well with typedicts and Literal
# values
assert not is_embeds_inputs(prompt_inputs)
if "prompt_token_ids" in prompt_inputs:
prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
prompt_inputs) # Needed for mypy
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request,
)
elif (prompt_inputs["type"] == "embeds"):
pass
else:
assert_never(prompt_inputs) # type: ignore[arg-type]

return prompt_inputs
@@ -1670,15 +1670,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
placeholders = mm_placeholders.get(modality, [])

if len(placeholders) != item_count:
# NOTE: If you are a model developer, this can also arise from
# an inconsistency between `_call_hf_processor` and
# `_get_mm_fields_config` implementations
raise RuntimeError(
f"Expected there to be {item_count} prompt updates "
f"corresponding to {item_count} {modality} items, but "
f"instead found {len(placeholders)} prompt updates! "
"Either the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_updates`).")
"This is likely because you forgot to include input "
"placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
"in the prompt. If the model has a chat template, make "
"sure you have applied it before calling `LLM.generate`.")

def _maybe_apply_prompt_updates(
self,
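The reworded error above points users at missing multi-modal placeholder tokens. A hedged example of a prompt that satisfies the check, for a model whose image placeholder is `<image>` (the model name, prompt format, and image size are assumptions for illustration, not taken from this diff):

from PIL import Image
from vllm import LLM

# Hypothetical multi-modal model; substitute one whose chat template and
# placeholder token you know. The prompt must contain one placeholder per image.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
image = Image.new("RGB", (336, 336))

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is shown here? ASSISTANT:",
    "multi_modal_data": {"image": image},
})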