mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 04:04:32 +08:00
[Misc] Clean up input processing (#17582)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
3a500cd0b6
commit
cb234955df
@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
|
|||||||
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
|
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
|
||||||
|
|
||||||
from vllm.distributed import cleanup_dist_env_and_memory
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
|
|
||||||
from ....conftest import ImageTestAssets
|
from ....conftest import ImageTestAssets
|
||||||
|
|
||||||
@ -14,6 +15,7 @@ from ....conftest import ImageTestAssets
|
|||||||
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
|
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def run_intern_vit_test(
|
def run_intern_vit_test(
|
||||||
image_assets: ImageTestAssets,
|
image_assets: ImageTestAssets,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
@ -21,11 +23,12 @@ def run_intern_vit_test(
|
|||||||
dtype: str,
|
dtype: str,
|
||||||
):
|
):
|
||||||
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
|
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
|
||||||
|
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||||
|
|
||||||
img_processor = CLIPImageProcessor.from_pretrained(model)
|
img_processor = CLIPImageProcessor.from_pretrained(model)
|
||||||
images = [asset.pil_image for asset in image_assets]
|
images = [asset.pil_image for asset in image_assets]
|
||||||
pixel_values = [
|
pixel_values = [
|
||||||
img_processor(images, return_tensors='pt').pixel_values.to(dtype)
|
img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype)
|
||||||
for images in images
|
for images in images
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -34,7 +37,7 @@ def run_intern_vit_test(
|
|||||||
config.norm_type = "rms_norm"
|
config.norm_type = "rms_norm"
|
||||||
|
|
||||||
hf_model = AutoModel.from_pretrained(model,
|
hf_model = AutoModel.from_pretrained(model,
|
||||||
torch_dtype=dtype,
|
torch_dtype=torch_dtype,
|
||||||
trust_remote_code=True).to("cuda")
|
trust_remote_code=True).to("cuda")
|
||||||
hf_outputs_per_image = [
|
hf_outputs_per_image = [
|
||||||
hf_model(pixel_value.to("cuda")).last_hidden_state
|
hf_model(pixel_value.to("cuda")).last_hidden_state
|
||||||
@ -48,7 +51,7 @@ def run_intern_vit_test(
|
|||||||
del hf_model
|
del hf_model
|
||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
vllm_model = vllm_model.to("cuda", dtype)
|
vllm_model = vllm_model.to("cuda", torch_dtype)
|
||||||
vllm_outputs_per_image = [
|
vllm_outputs_per_image = [
|
||||||
vllm_model(pixel_values=pixel_value.to("cuda"))
|
vllm_model(pixel_values=pixel_value.to("cuda"))
|
||||||
for pixel_value in pixel_values
|
for pixel_value in pixel_values
|
||||||
@ -66,9 +69,8 @@ def run_intern_vit_test(
|
|||||||
"OpenGVLab/InternViT-300M-448px",
|
"OpenGVLab/InternViT-300M-448px",
|
||||||
"OpenGVLab/InternViT-6B-448px-V1-5",
|
"OpenGVLab/InternViT-6B-448px-V1-5",
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("dtype", [torch.half])
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
@torch.inference_mode()
|
def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
|
||||||
def test_models(image_assets, model_id, dtype: str) -> None:
|
|
||||||
run_intern_vit_test(
|
run_intern_vit_test(
|
||||||
image_assets,
|
image_assets,
|
||||||
model_id,
|
model_id,
|
||||||
|
|||||||
@ -497,10 +497,6 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
prompt["prompt_token_ids"] = [0
|
prompt["prompt_token_ids"] = [0
|
||||||
] * prompt["prompt_embeds"].shape[-2]
|
] * prompt["prompt_embeds"].shape[-2]
|
||||||
|
|
||||||
if self.tokenizer is not None:
|
|
||||||
tokenizer = await self.get_tokenizer_async(lora_request)
|
|
||||||
self._validate_token_prompt(prompt, tokenizer=tokenizer)
|
|
||||||
|
|
||||||
processed_inputs = await self.input_preprocessor.preprocess_async(
|
processed_inputs = await self.input_preprocessor.preprocess_async(
|
||||||
prompt,
|
prompt,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
|
|||||||
get_logits_processors as get_openai_logits_processors)
|
get_logits_processors as get_openai_logits_processors)
|
||||||
from vllm.executor.executor_base import ExecutorBase
|
from vllm.executor.executor_base import ExecutorBase
|
||||||
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
||||||
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
|
from vllm.inputs.parse import split_enc_dec_inputs
|
||||||
from vllm.inputs.preprocess import InputPreprocessor
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logits_process import get_bad_words_logits_processors
|
from vllm.logits_process import get_bad_words_logits_processors
|
||||||
@ -759,11 +759,6 @@ class LLMEngine:
|
|||||||
seq_len = prompt["prompt_embeds"].shape[0]
|
seq_len = prompt["prompt_embeds"].shape[0]
|
||||||
prompt["prompt_token_ids"] = [0] * seq_len
|
prompt["prompt_token_ids"] = [0] * seq_len
|
||||||
|
|
||||||
if self.tokenizer is not None:
|
|
||||||
self._validate_token_prompt(
|
|
||||||
prompt,
|
|
||||||
tokenizer=self.get_tokenizer(lora_request=lora_request))
|
|
||||||
|
|
||||||
processed_inputs = self.input_preprocessor.preprocess(
|
processed_inputs = self.input_preprocessor.preprocess(
|
||||||
prompt,
|
prompt,
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
@ -782,27 +777,6 @@ class LLMEngine:
|
|||||||
priority=priority,
|
priority=priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_token_prompt(self, prompt: PromptType,
|
|
||||||
tokenizer: AnyTokenizer):
|
|
||||||
# Guard against out-of-vocab tokens.
|
|
||||||
# For some tokenizers, tokenizer.decode will happily return empty text
|
|
||||||
# for token ids that are out of vocab, and we don't detect token ids
|
|
||||||
# that are greater than the max token id before running the model.
|
|
||||||
# However, these token ids will later crash a cuda kernel at runtime
|
|
||||||
# with an index out of bounds error. This will crash the entire engine.
|
|
||||||
# This needs to happen before multimodal input pre-processing, which
|
|
||||||
# may add dummy <image> tokens that aren't part of the tokenizer's
|
|
||||||
# vocabulary.
|
|
||||||
if is_token_prompt(prompt):
|
|
||||||
prompt_ids = prompt["prompt_token_ids"]
|
|
||||||
if len(prompt_ids) == 0:
|
|
||||||
# Empty prompt check is handled later
|
|
||||||
return
|
|
||||||
max_input_id = max(prompt_ids)
|
|
||||||
if max_input_id > tokenizer.max_token_id:
|
|
||||||
raise ValueError(
|
|
||||||
"Token id {} is out of vocabulary".format(max_input_id))
|
|
||||||
|
|
||||||
def _create_sequence_group_with_sampling(
|
def _create_sequence_group_with_sampling(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
@ -2049,6 +2023,12 @@ class LLMEngine:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"The {prompt_type} prompt cannot be empty")
|
raise ValueError(f"The {prompt_type} prompt cannot be empty")
|
||||||
|
|
||||||
|
if tokenizer is not None:
|
||||||
|
max_input_id = max(prompt_ids, default=0)
|
||||||
|
if max_input_id > tokenizer.max_token_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"Token id {max_input_id} is out of vocabulary")
|
||||||
|
|
||||||
max_prompt_len = self.model_config.max_model_len
|
max_prompt_len = self.model_config.max_model_len
|
||||||
if len(prompt_ids) > max_prompt_len:
|
if len(prompt_ids) > max_prompt_len:
|
||||||
if prompt_type == "encoder" and model_config.is_multimodal_model:
|
if prompt_type == "encoder" and model_config.is_multimodal_model:
|
||||||
|
|||||||
@ -83,6 +83,9 @@ class EngineClient(ABC):
|
|||||||
else:
|
else:
|
||||||
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
|
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
|
||||||
|
|
||||||
|
if processed_inputs["type"] == "embeds":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
prompt_token_ids = processed_inputs["prompt_token_ids"]
|
prompt_token_ids = processed_inputs["prompt_token_ids"]
|
||||||
prompt_text = processed_inputs.get("prompt")
|
prompt_text = processed_inputs.get("prompt")
|
||||||
multi_modal_data = processed_inputs.get("multi_modal_data")
|
multi_modal_data = processed_inputs.get("multi_modal_data")
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from vllm.entrypoints.score_utils import (_cosine_similarity,
|
|||||||
_validate_score_input_lens)
|
_validate_score_input_lens)
|
||||||
from vllm.entrypoints.utils import _validate_truncation_size
|
from vllm.entrypoints.utils import _validate_truncation_size
|
||||||
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
|
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
|
||||||
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
|
from vllm.inputs.parse import parse_and_batch_prompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
from vllm.model_executor.guided_decoding.guided_fields import (
|
||||||
@ -567,10 +567,12 @@ class LLM:
|
|||||||
mm_kwargs["mm_processor_kwargs"] = prompt[
|
mm_kwargs["mm_processor_kwargs"] = prompt[
|
||||||
"mm_processor_kwargs"]
|
"mm_processor_kwargs"]
|
||||||
|
|
||||||
if is_token_prompt(prompt):
|
if "prompt_token_ids" in prompt:
|
||||||
|
prompt = cast(TokensPrompt, prompt) # Needed for mypy
|
||||||
prompt_tokens = prompt["prompt_token_ids"]
|
prompt_tokens = prompt["prompt_token_ids"]
|
||||||
else:
|
else:
|
||||||
prompt_tokens = tokenizer.encode(prompt["prompt"])
|
prompt_tokens = tokenizer.encode(prompt["prompt"])
|
||||||
|
|
||||||
instances.append(
|
instances.append(
|
||||||
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
|
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
|
||||||
|
|
||||||
|
|||||||
@ -70,6 +70,11 @@ class EmbedsPrompt(TypedDict):
|
|||||||
prompt_embeds: torch.Tensor
|
prompt_embeds: torch.Tensor
|
||||||
"""The embeddings of the prompt."""
|
"""The embeddings of the prompt."""
|
||||||
|
|
||||||
|
cache_salt: NotRequired[str]
|
||||||
|
"""
|
||||||
|
Optional cache salt to be used for prefix caching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
|
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
|
||||||
"""
|
"""
|
||||||
@ -195,13 +200,21 @@ class EmbedsInputs(TypedDict):
|
|||||||
prompt_embeds: torch.Tensor
|
prompt_embeds: torch.Tensor
|
||||||
"""The embeddings of the prompt."""
|
"""The embeddings of the prompt."""
|
||||||
|
|
||||||
|
cache_salt: NotRequired[str]
|
||||||
|
"""
|
||||||
|
Optional cache salt to be used for prefix caching.
|
||||||
|
"""
|
||||||
|
|
||||||
def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:
|
|
||||||
|
def embeds_inputs(
|
||||||
|
prompt_embeds: torch.Tensor,
|
||||||
|
cache_salt: Optional[str] = None,
|
||||||
|
) -> EmbedsInputs:
|
||||||
"""Construct :class:`EmbedsInputs` from optional values."""
|
"""Construct :class:`EmbedsInputs` from optional values."""
|
||||||
inputs = EmbedsInputs(
|
inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
|
||||||
type="embeds",
|
|
||||||
prompt_embeds=prompt_embeds,
|
if cache_salt is not None:
|
||||||
)
|
inputs["cache_salt"] = cache_salt
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
|||||||
@ -6,9 +6,9 @@ from typing_extensions import TypeIs
|
|||||||
|
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt,
|
from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
|
||||||
ProcessorInputs, PromptType, SingletonInputs,
|
PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
|
||||||
SingletonPrompt, TextPrompt, TokensPrompt)
|
TokensPrompt)
|
||||||
|
|
||||||
|
|
||||||
class ParsedText(TypedDict):
|
class ParsedText(TypedDict):
|
||||||
@ -90,6 +90,10 @@ class ParsedEmbedsPrompt(TypedDict):
|
|||||||
content: EmbedsPrompt
|
content: EmbedsPrompt
|
||||||
|
|
||||||
|
|
||||||
|
ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
|
||||||
|
ParsedTokensPrompt, ParsedEmbedsPrompt]
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
|
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
|
||||||
...
|
...
|
||||||
@ -110,10 +114,7 @@ def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
def parse_singleton_prompt(
|
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
|
||||||
prompt: SingletonPrompt,
|
|
||||||
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
|
|
||||||
ParsedEmbedsPrompt]:
|
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
return ParsedStrPrompt(type="str", content=prompt)
|
return ParsedStrPrompt(type="str", content=prompt)
|
||||||
elif isinstance(prompt, dict):
|
elif isinstance(prompt, dict):
|
||||||
@ -131,23 +132,11 @@ def parse_singleton_prompt(
|
|||||||
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
|
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
|
||||||
|
|
||||||
|
|
||||||
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
|
|
||||||
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
|
||||||
|
|
||||||
|
|
||||||
def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
|
|
||||||
return isinstance(prompt, dict) and "prompt_embeds" in prompt
|
|
||||||
|
|
||||||
|
|
||||||
def is_explicit_encoder_decoder_prompt(
|
def is_explicit_encoder_decoder_prompt(
|
||||||
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
||||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||||
|
|
||||||
|
|
||||||
def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
|
|
||||||
return isinstance(inputs, dict) and inputs["type"] == "embeds"
|
|
||||||
|
|
||||||
|
|
||||||
def split_enc_dec_inputs(
|
def split_enc_dec_inputs(
|
||||||
inputs: ProcessorInputs,
|
inputs: ProcessorInputs,
|
||||||
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
|
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
|
||||||
|
|||||||
@ -14,14 +14,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
|||||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||||
MultiModalInputs)
|
MultiModalInputs)
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||||
|
|
||||||
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
|
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
|
||||||
ProcessorInputs, PromptType, SingletonInputs,
|
EncoderDecoderInputs, ProcessorInputs, PromptType,
|
||||||
SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
|
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
|
||||||
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
|
TokensPrompt, embeds_inputs, token_inputs)
|
||||||
ParsedTokensPrompt, is_embeds_inputs,
|
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
|
||||||
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -140,13 +140,10 @@ class InputPreprocessor:
|
|||||||
"""
|
"""
|
||||||
Prepares `decoder_input_ids` for generation with encoder-decoder models.
|
Prepares `decoder_input_ids` for generation with encoder-decoder models.
|
||||||
|
|
||||||
Based on
|
Based on:
|
||||||
|
https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
|
||||||
https://github.com/huggingface/transformers/blob/
|
specifically,
|
||||||
4037a2b5b1278736e566aec12e169100275545ea/
|
`GenerationMixin._prepare_decoder_input_ids_for_generation()`.
|
||||||
src/transformers/generation/utils.py
|
|
||||||
|
|
||||||
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
|
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
||||||
@ -183,6 +180,23 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
return prompt_token_ids
|
return prompt_token_ids
|
||||||
|
|
||||||
|
def _get_tokenization_kw(
|
||||||
|
self,
|
||||||
|
overrides: Optional[dict[str, Any]] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
kwargs = dict[str, Any]()
|
||||||
|
|
||||||
|
if self.model_config.hf_config.model_type == "whisper":
|
||||||
|
# For Whisper, special tokens should be provided by the user based
|
||||||
|
# on the task and language of their request. Also needed to avoid
|
||||||
|
# appending an EOS token to the prompt which disrupts generation.
|
||||||
|
kwargs["add_special_tokens"] = False
|
||||||
|
|
||||||
|
if overrides:
|
||||||
|
kwargs.update(overrides)
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
def _tokenize_prompt(
|
def _tokenize_prompt(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -194,18 +208,11 @@ class InputPreprocessor:
|
|||||||
corresponding token IDs.
|
corresponding token IDs.
|
||||||
"""
|
"""
|
||||||
tokenizer = self.get_tokenizer_group()
|
tokenizer = self.get_tokenizer_group()
|
||||||
if tokenization_kwargs is None:
|
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
|
||||||
tokenization_kwargs = {}
|
|
||||||
|
|
||||||
if self.model_config.hf_config.model_type == "whisper":
|
encoder_config = self.model_config.encoder_config
|
||||||
# For Whisper, special tokens should be provided by the user based
|
|
||||||
# on the task and language of their request. Also needed to avoid
|
|
||||||
# appending an EOS token to the prompt which disrupts generation.
|
|
||||||
tokenization_kwargs["add_special_tokens"] = False
|
|
||||||
|
|
||||||
if (self.model_config.encoder_config is not None
|
if encoder_config and encoder_config.get("do_lower_case", False):
|
||||||
and self.model_config.encoder_config.get(
|
|
||||||
"do_lower_case", False)):
|
|
||||||
prompt = prompt.lower()
|
prompt = prompt.lower()
|
||||||
|
|
||||||
return tokenizer.encode(prompt=prompt,
|
return tokenizer.encode(prompt=prompt,
|
||||||
@ -220,18 +227,36 @@ class InputPreprocessor:
|
|||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
"""Async version of :meth:`_tokenize_prompt`."""
|
"""Async version of :meth:`_tokenize_prompt`."""
|
||||||
tokenizer = self.get_tokenizer_group()
|
tokenizer = self.get_tokenizer_group()
|
||||||
if tokenization_kwargs is None:
|
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
|
||||||
tokenization_kwargs = {}
|
|
||||||
|
|
||||||
if self.model_config.hf_config.model_type == "whisper":
|
|
||||||
# For Whisper, special tokens should be provided by the user based
|
|
||||||
# on the task and language of their request. Also needed to avoid
|
|
||||||
# appending an EOS token to the prompt which disrupts generation.
|
|
||||||
tokenization_kwargs["add_special_tokens"] = False
|
|
||||||
return await tokenizer.encode_async(prompt=prompt,
|
return await tokenizer.encode_async(prompt=prompt,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
**tokenization_kwargs)
|
**tokenization_kwargs)
|
||||||
|
|
||||||
|
def _get_mm_tokenizer(
|
||||||
|
self,
|
||||||
|
lora_request: Optional[LoRARequest],
|
||||||
|
) -> AnyTokenizer:
|
||||||
|
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
|
||||||
|
# while using also multi-modal input
|
||||||
|
if not self.tokenizer:
|
||||||
|
return cast(AnyTokenizer, object()) # Dummy
|
||||||
|
|
||||||
|
tokenizer_group = self.get_tokenizer_group()
|
||||||
|
return tokenizer_group.get_lora_tokenizer(lora_request)
|
||||||
|
|
||||||
|
async def _get_mm_tokenizer_async(
|
||||||
|
self,
|
||||||
|
lora_request: Optional[LoRARequest],
|
||||||
|
) -> AnyTokenizer:
|
||||||
|
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
|
||||||
|
# while using also multi-modal input
|
||||||
|
if not self.tokenizer:
|
||||||
|
return cast(AnyTokenizer, object()) # Dummy
|
||||||
|
|
||||||
|
tokenizer_group = self.get_tokenizer_group()
|
||||||
|
return await tokenizer_group.get_lora_tokenizer_async(lora_request)
|
||||||
|
|
||||||
def _process_multimodal(
|
def _process_multimodal(
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, list[int]],
|
prompt: Union[str, list[int]],
|
||||||
@ -244,13 +269,7 @@ class InputPreprocessor:
|
|||||||
Apply the model's multi-modal processor to a multi-modal prompt,
|
Apply the model's multi-modal processor to a multi-modal prompt,
|
||||||
returning the corresponding token IDs and metadata.
|
returning the corresponding token IDs and metadata.
|
||||||
"""
|
"""
|
||||||
# At the moment on model (PrithviGeoSpatialMAE) requires to be
|
tokenizer = self._get_mm_tokenizer(lora_request)
|
||||||
# initialized without a tokenizer while using also multi-modal input
|
|
||||||
if not self.tokenizer:
|
|
||||||
tokenizer = object() # Dummy
|
|
||||||
else:
|
|
||||||
tokenizer_group = self.get_tokenizer_group()
|
|
||||||
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
|
|
||||||
|
|
||||||
mm_processor = self.mm_registry.create_processor(self.model_config,
|
mm_processor = self.mm_registry.create_processor(self.model_config,
|
||||||
tokenizer=tokenizer)
|
tokenizer=tokenizer)
|
||||||
@ -270,14 +289,7 @@ class InputPreprocessor:
|
|||||||
return_mm_hashes: bool = False,
|
return_mm_hashes: bool = False,
|
||||||
) -> MultiModalInputs:
|
) -> MultiModalInputs:
|
||||||
"""Async version of :meth:`_process_multimodal`."""
|
"""Async version of :meth:`_process_multimodal`."""
|
||||||
# At the moment on model (PrithviGeoSpatialMAE) requires to be
|
tokenizer = await self._get_mm_tokenizer_async(lora_request)
|
||||||
# initialized without a tokenizer while using also multi-modal input
|
|
||||||
if not self.tokenizer:
|
|
||||||
tokenizer = object() # Dummy
|
|
||||||
else:
|
|
||||||
tokenizer_group = self.get_tokenizer_group()
|
|
||||||
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
|
|
||||||
lora_request)
|
|
||||||
|
|
||||||
mm_processor = self.mm_registry.create_processor(self.model_config,
|
mm_processor = self.mm_registry.create_processor(self.model_config,
|
||||||
tokenizer=tokenizer)
|
tokenizer=tokenizer)
|
||||||
@ -287,28 +299,160 @@ class InputPreprocessor:
|
|||||||
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
|
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
|
||||||
return_mm_hashes)
|
return_mm_hashes)
|
||||||
|
|
||||||
def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt,
|
def _process_embeds(
|
||||||
ParsedTextPrompt,
|
self,
|
||||||
ParsedTokensPrompt]):
|
parsed_content: EmbedsPrompt,
|
||||||
prompt_text = None
|
) -> EmbedsInputs:
|
||||||
prompt_token_ids = None
|
if envs.VLLM_USE_V1:
|
||||||
token_type_ids = None
|
raise ValueError("prompt_embeds is only available in V0.")
|
||||||
cache_salt = None
|
|
||||||
|
|
||||||
if parsed_prompt["type"] == "str":
|
prompt_embeds = parsed_content["prompt_embeds"]
|
||||||
prompt_text = parsed_prompt["content"]
|
|
||||||
|
# prompt_embeds must be (seq_len, hidden_size), but if the user
|
||||||
|
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
|
||||||
|
# we can unambiguously process the intent by squeezing the batch
|
||||||
|
# dimension.
|
||||||
|
if prompt_embeds.ndim == 3:
|
||||||
|
prompt_embeds = prompt_embeds.squeeze(dim=0)
|
||||||
|
|
||||||
|
if prompt_embeds.ndim != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"prompt_embeds must be of shape (seq_len, hidden_size).")
|
||||||
|
|
||||||
|
return embeds_inputs(prompt_embeds=prompt_embeds,
|
||||||
|
cache_salt=parsed_content.get("cache_salt"))
|
||||||
|
|
||||||
|
async def _process_embeds_async(
|
||||||
|
self,
|
||||||
|
parsed_content: EmbedsPrompt,
|
||||||
|
) -> EmbedsInputs:
|
||||||
|
return self._process_embeds(parsed_content)
|
||||||
|
|
||||||
|
def _process_tokens(
|
||||||
|
self,
|
||||||
|
parsed_content: TokensPrompt,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
return_mm_hashes: bool = False,
|
||||||
|
) -> Union[TokenInputs, MultiModalInputs]:
|
||||||
|
prompt_token_ids = parsed_content["prompt_token_ids"]
|
||||||
|
token_type_ids = parsed_content.get("token_type_ids")
|
||||||
|
|
||||||
|
inputs: Union[TokenInputs, MultiModalInputs]
|
||||||
|
if multi_modal_data := parsed_content.get("multi_modal_data"):
|
||||||
|
inputs = self._process_multimodal(
|
||||||
|
prompt_token_ids,
|
||||||
|
multi_modal_data,
|
||||||
|
parsed_content.get("mm_processor_kwargs"),
|
||||||
|
lora_request=lora_request,
|
||||||
|
return_mm_hashes=return_mm_hashes,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
cache_salt = parsed_prompt["content"].get("cache_salt")
|
inputs = token_inputs(
|
||||||
if parsed_prompt["type"] == "text":
|
prompt_token_ids=prompt_token_ids,
|
||||||
prompt_text = parsed_prompt["content"]["prompt"]
|
token_type_ids=token_type_ids,
|
||||||
elif parsed_prompt["type"] == "tokens":
|
)
|
||||||
prompt_token_ids = parsed_prompt["content"].get(
|
|
||||||
"prompt_token_ids")
|
|
||||||
token_type_ids = parsed_prompt["content"].get("token_type_ids")
|
|
||||||
else:
|
|
||||||
assert_never(parsed_prompt)
|
|
||||||
|
|
||||||
return prompt_text, prompt_token_ids, token_type_ids, cache_salt
|
if cache_salt := parsed_content.get("cache_salt"):
|
||||||
|
inputs["cache_salt"] = cache_salt
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
async def _process_tokens_async(
|
||||||
|
self,
|
||||||
|
parsed_content: TokensPrompt,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
return_mm_hashes: bool = False,
|
||||||
|
) -> Union[TokenInputs, MultiModalInputs]:
|
||||||
|
prompt_token_ids = parsed_content["prompt_token_ids"]
|
||||||
|
token_type_ids = parsed_content.get("token_type_ids")
|
||||||
|
|
||||||
|
inputs: Union[TokenInputs, MultiModalInputs]
|
||||||
|
if multi_modal_data := parsed_content.get("multi_modal_data"):
|
||||||
|
inputs = await self._process_multimodal_async(
|
||||||
|
prompt_token_ids,
|
||||||
|
multi_modal_data,
|
||||||
|
parsed_content.get("mm_processor_kwargs"),
|
||||||
|
lora_request=lora_request,
|
||||||
|
return_mm_hashes=return_mm_hashes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inputs = token_inputs(
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cache_salt := parsed_content.get("cache_salt"):
|
||||||
|
inputs["cache_salt"] = cache_salt
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def _process_text(
|
||||||
|
self,
|
||||||
|
parsed_content: TextPrompt,
|
||||||
|
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
return_mm_hashes: bool = False,
|
||||||
|
) -> Union[TokenInputs, MultiModalInputs]:
|
||||||
|
prompt_text = parsed_content["prompt"]
|
||||||
|
|
||||||
|
inputs: Union[TokenInputs, MultiModalInputs]
|
||||||
|
if multi_modal_data := parsed_content.get("multi_modal_data"):
|
||||||
|
inputs = self._process_multimodal(
|
||||||
|
prompt_text,
|
||||||
|
multi_modal_data,
|
||||||
|
parsed_content.get("mm_processor_kwargs"),
|
||||||
|
lora_request=lora_request,
|
||||||
|
return_mm_hashes=return_mm_hashes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt_token_ids = self._tokenize_prompt(
|
||||||
|
prompt_text,
|
||||||
|
lora_request=lora_request,
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
|
)
|
||||||
|
inputs = token_inputs(
|
||||||
|
prompt=prompt_text,
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cache_salt := parsed_content.get("cache_salt"):
|
||||||
|
inputs["cache_salt"] = cache_salt
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
async def _process_text_async(
|
||||||
|
self,
|
||||||
|
parsed_content: TextPrompt,
|
||||||
|
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
return_mm_hashes: bool = False,
|
||||||
|
) -> Union[TokenInputs, MultiModalInputs]:
|
||||||
|
prompt_text = parsed_content["prompt"]
|
||||||
|
|
||||||
|
inputs: Union[TokenInputs, MultiModalInputs]
|
||||||
|
if multi_modal_data := parsed_content.get("multi_modal_data"):
|
||||||
|
inputs = await self._process_multimodal_async(
|
||||||
|
prompt_text,
|
||||||
|
multi_modal_data,
|
||||||
|
parsed_content.get("mm_processor_kwargs"),
|
||||||
|
lora_request=lora_request,
|
||||||
|
return_mm_hashes=return_mm_hashes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt_token_ids = await self._tokenize_prompt_async(
|
||||||
|
prompt_text,
|
||||||
|
lora_request=lora_request,
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
|
)
|
||||||
|
inputs = token_inputs(
|
||||||
|
prompt=prompt_text,
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cache_salt := parsed_content.get("cache_salt"):
|
||||||
|
inputs["cache_salt"] = cache_salt
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
def _prompt_to_llm_inputs(
|
def _prompt_to_llm_inputs(
|
||||||
self,
|
self,
|
||||||
@ -333,38 +477,27 @@ class InputPreprocessor:
|
|||||||
parsed = parse_singleton_prompt(prompt)
|
parsed = parse_singleton_prompt(prompt)
|
||||||
|
|
||||||
if parsed["type"] == "embeds":
|
if parsed["type"] == "embeds":
|
||||||
return self._process_prompt_embeds(parsed)
|
return self._process_embeds(parsed["content"])
|
||||||
|
if parsed["type"] == "tokens":
|
||||||
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
|
return self._process_tokens(
|
||||||
self._get_prompt_data(parsed)
|
parsed["content"],
|
||||||
|
|
||||||
# If multimodal data is present, process and return immediately
|
|
||||||
if parsed["type"] != "str" and parsed["content"].get(
|
|
||||||
"multi_modal_data") is not None:
|
|
||||||
inputs = self._process_multimodal(
|
|
||||||
prompt_text if prompt_text is not None else prompt_token_ids,
|
|
||||||
parsed["content"]["multi_modal_data"],
|
|
||||||
parsed["content"].get("mm_processor_kwargs"),
|
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
return_mm_hashes=return_mm_hashes,
|
return_mm_hashes=return_mm_hashes,
|
||||||
)
|
)
|
||||||
if cache_salt is not None:
|
if parsed["type"] == "text":
|
||||||
inputs["cache_salt"] = cache_salt
|
return self._process_text(
|
||||||
return inputs
|
parsed["content"],
|
||||||
|
|
||||||
if prompt_token_ids is None:
|
|
||||||
prompt_token_ids = self._tokenize_prompt(
|
|
||||||
prompt_text,
|
|
||||||
lora_request=lora_request,
|
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
|
lora_request=lora_request,
|
||||||
|
return_mm_hashes=return_mm_hashes,
|
||||||
|
)
|
||||||
|
if parsed["type"] == "str":
|
||||||
|
return self._process_text(
|
||||||
|
TextPrompt(prompt=parsed["content"]),
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
|
lora_request=lora_request,
|
||||||
|
return_mm_hashes=return_mm_hashes,
|
||||||
)
|
)
|
||||||
|
|
||||||
return token_inputs(
|
|
||||||
prompt=prompt_text,
|
|
||||||
prompt_token_ids=prompt_token_ids,
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
cache_salt=cache_salt,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert_never(parsed)
|
assert_never(parsed)
|
||||||
|
|
||||||
@ -375,79 +508,49 @@ class InputPreprocessor:
|
|||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
return_mm_hashes: bool = False,
|
return_mm_hashes: bool = False,
|
||||||
) -> SingletonInputs:
|
) -> SingletonInputs:
|
||||||
"""Async version of :meth:`_extract_prompt_components`."""
|
"""Async version of :meth:`_prompt_to_llm_inputs`."""
|
||||||
parsed = parse_singleton_prompt(prompt)
|
parsed = parse_singleton_prompt(prompt)
|
||||||
|
|
||||||
if parsed["type"] == "embeds":
|
if parsed["type"] == "embeds":
|
||||||
return self._process_prompt_embeds(parsed)
|
return await self._process_embeds_async(parsed["content"])
|
||||||
|
if parsed["type"] == "tokens":
|
||||||
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
|
return await self._process_tokens_async(
|
||||||
self._get_prompt_data(parsed)
|
parsed["content"],
|
||||||
|
|
||||||
if parsed["type"] != "str" and parsed["content"].get(
|
|
||||||
"multi_modal_data") is not None:
|
|
||||||
inputs = await self._process_multimodal_async(
|
|
||||||
prompt_token_ids if prompt_text is None else prompt_text,
|
|
||||||
parsed["content"]["multi_modal_data"],
|
|
||||||
parsed["content"].get("mm_processor_kwargs"),
|
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
return_mm_hashes=return_mm_hashes,
|
return_mm_hashes=return_mm_hashes,
|
||||||
)
|
)
|
||||||
if cache_salt is not None:
|
if parsed["type"] == "text":
|
||||||
inputs["cache_salt"] = cache_salt
|
return await self._process_text_async(
|
||||||
return inputs
|
parsed["content"],
|
||||||
|
|
||||||
if prompt_token_ids is None:
|
|
||||||
prompt_token_ids = await self._tokenize_prompt_async(
|
|
||||||
prompt_text,
|
|
||||||
lora_request=lora_request,
|
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
|
lora_request=lora_request,
|
||||||
|
return_mm_hashes=return_mm_hashes,
|
||||||
|
)
|
||||||
|
if parsed["type"] == "str":
|
||||||
|
return await self._process_text_async(
|
||||||
|
TextPrompt(prompt=parsed["content"]),
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
|
lora_request=lora_request,
|
||||||
|
return_mm_hashes=return_mm_hashes,
|
||||||
)
|
)
|
||||||
|
|
||||||
return token_inputs(
|
|
||||||
prompt=prompt_text,
|
|
||||||
prompt_token_ids=prompt_token_ids,
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
cache_salt=cache_salt,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _process_prompt_embeds(self,
|
|
||||||
parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
|
|
||||||
if envs.VLLM_USE_V1:
|
|
||||||
raise ValueError("prompt_embeds is only available in V0.")
|
|
||||||
|
|
||||||
prompt_embeds_content = parsed["content"]
|
|
||||||
|
|
||||||
prompt_embeds = prompt_embeds_content["prompt_embeds"]
|
|
||||||
|
|
||||||
# prompt_embeds must be (seq_len, hidden_size), but if the user
|
|
||||||
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
|
|
||||||
# we can unambiguously process the intent by squeezing the batch
|
|
||||||
# dimension.
|
|
||||||
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
|
|
||||||
prompt_embeds = prompt_embeds.squeeze(dim=0)
|
|
||||||
|
|
||||||
if prompt_embeds.ndim != 2:
|
|
||||||
raise ValueError(
|
|
||||||
"prompt_embeds must be of shape (seq_len, hidden_size).")
|
|
||||||
|
|
||||||
return embeds_inputs(prompt_embeds=prompt_embeds)
|
|
||||||
|
|
||||||
assert_never(parsed)
|
assert_never(parsed)
|
||||||
|
|
||||||
def _build_enc_dec_llm_inputs(
|
def _build_enc_dec_llm_inputs(
|
||||||
self,
|
self,
|
||||||
encoder_inputs: Union[TokenInputs, MultiModalInputs],
|
encoder_inputs: SingletonInputs,
|
||||||
decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
|
decoder_inputs: Optional[SingletonInputs],
|
||||||
) -> EncoderDecoderInputs:
|
) -> EncoderDecoderInputs:
|
||||||
if (encoder_inputs["type"] == "token"
|
if (encoder_inputs["type"] == "embeds"
|
||||||
or encoder_inputs["type"] == "multimodal"):
|
or decoder_inputs and decoder_inputs["type"] == "embeds"):
|
||||||
pass
|
raise ValueError("Embedding inputs are not supported for encoder-"
|
||||||
else:
|
"decoder models")
|
||||||
assert_never(encoder_inputs) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
# Mypy does not correctly infer that EmbedsInputs is impossible
|
# Needed for mypy
|
||||||
assert "prompt_token_ids" in encoder_inputs
|
encoder_inputs = cast(Union[TokenInputs, MultiModalInputs],
|
||||||
|
encoder_inputs)
|
||||||
|
decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]],
|
||||||
|
decoder_inputs)
|
||||||
|
|
||||||
if decoder_inputs is None:
|
if decoder_inputs is None:
|
||||||
if self.model_config.hf_config.model_type == "whisper":
|
if self.model_config.hf_config.model_type == "whisper":
|
||||||
@ -460,74 +563,78 @@ class InputPreprocessor:
|
|||||||
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
|
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
|
||||||
None)
|
None)
|
||||||
decoder_inputs = token_inputs(dec_token_ids)
|
decoder_inputs = token_inputs(dec_token_ids)
|
||||||
elif (decoder_inputs["type"] == "token"
|
else:
|
||||||
or decoder_inputs["type"] == "multimodal"):
|
|
||||||
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
|
|
||||||
decoder_inputs["prompt_token_ids"])
|
|
||||||
decoder_inputs["prompt_token_ids"] = dec_token_ids
|
|
||||||
|
|
||||||
if "multi_modal_data" in decoder_inputs:
|
if "multi_modal_data" in decoder_inputs:
|
||||||
raise ValueError("Multi-modal decoder inputs of encoder-"
|
raise ValueError("Multi-modal decoder inputs of encoder-"
|
||||||
"decoder models are not supported yet")
|
"decoder models are not supported yet")
|
||||||
else:
|
|
||||||
assert_never(encoder_inputs) # type: ignore[arg-type]
|
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
|
||||||
|
decoder_inputs["prompt_token_ids"])
|
||||||
|
decoder_inputs["prompt_token_ids"] = dec_token_ids
|
||||||
|
|
||||||
return EncoderDecoderInputs(
|
return EncoderDecoderInputs(
|
||||||
encoder=encoder_inputs,
|
encoder=encoder_inputs,
|
||||||
decoder=decoder_inputs,
|
decoder=decoder_inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _separate_enc_dec_inputs_from_mm_processor_outputs(
|
def _split_enc_dec_mm_inputs(
|
||||||
self,
|
self,
|
||||||
inputs: SingletonInputs,
|
inputs: Union[SingletonInputs, MultiModalEncDecInputs],
|
||||||
decoder_inputs_to_override: Optional[Union[TokenInputs,
|
decoder_inputs_to_override: Optional[SingletonInputs] = None,
|
||||||
MultiModalInputs]] = None,
|
|
||||||
) -> tuple[SingletonInputs, SingletonInputs]:
|
) -> tuple[SingletonInputs, SingletonInputs]:
|
||||||
"""
|
"""
|
||||||
For encoder/decoder models only:
|
For encoder/decoder models only:
|
||||||
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
|
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
|
||||||
"""
|
"""
|
||||||
|
if (inputs["type"] == "embeds" or decoder_inputs_to_override
|
||||||
|
and decoder_inputs_to_override["type"] == "embeds"):
|
||||||
|
raise ValueError("Embedding inputs are not supported for encoder-"
|
||||||
|
"decoder models")
|
||||||
|
|
||||||
|
# Needed for mypy
|
||||||
|
inputs = cast(
|
||||||
|
Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
|
||||||
|
inputs,
|
||||||
|
)
|
||||||
|
decoder_inputs_to_override = cast(
|
||||||
|
Optional[Union[TokenInputs, MultiModalInputs]],
|
||||||
|
decoder_inputs_to_override,
|
||||||
|
)
|
||||||
|
|
||||||
encoder_inputs: SingletonInputs
|
encoder_inputs: SingletonInputs
|
||||||
decoder_inputs: SingletonInputs
|
decoder_inputs: SingletonInputs
|
||||||
if inputs["type"] == "multimodal":
|
|
||||||
# Multimodal data inputs
|
if inputs["type"] == "multimodal": # Multimodal data inputs
|
||||||
assert ("encoder_prompt" in inputs
|
if not ("encoder_prompt" in inputs
|
||||||
and "encoder_prompt_token_ids" in inputs)
|
and "encoder_prompt_token_ids" in inputs):
|
||||||
|
raise RuntimeError("You should register an encoder-decoder "
|
||||||
|
"multi-modal processor for encoder-decoder "
|
||||||
|
"models.")
|
||||||
inputs = cast(MultiModalEncDecInputs, inputs)
|
inputs = cast(MultiModalEncDecInputs, inputs)
|
||||||
|
|
||||||
encoder_inputs = token_inputs(
|
encoder_inputs = token_inputs(
|
||||||
prompt=inputs["encoder_prompt"],
|
prompt=inputs["encoder_prompt"],
|
||||||
prompt_token_ids=inputs["encoder_prompt_token_ids"],
|
prompt_token_ids=inputs["encoder_prompt_token_ids"],
|
||||||
)
|
)
|
||||||
if decoder_inputs_to_override is not None:
|
|
||||||
decoder_inputs = MultiModalInputs(
|
|
||||||
type="multimodal",
|
|
||||||
prompt=decoder_inputs_to_override.get("prompt", ""),
|
|
||||||
prompt_token_ids=decoder_inputs_to_override[
|
|
||||||
"prompt_token_ids"],
|
|
||||||
mm_kwargs=inputs["mm_kwargs"],
|
|
||||||
mm_hashes=inputs["mm_hashes"],
|
|
||||||
mm_placeholders=inputs["mm_placeholders"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
decoder_inputs = MultiModalInputs(
|
|
||||||
type="multimodal",
|
|
||||||
prompt=inputs["prompt"],
|
|
||||||
prompt_token_ids=inputs["prompt_token_ids"],
|
|
||||||
mm_kwargs=inputs["mm_kwargs"],
|
|
||||||
mm_hashes=inputs["mm_hashes"],
|
|
||||||
mm_placeholders=inputs["mm_placeholders"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cache_salt = inputs.get("cache_salt")
|
decoder_prompt_inputs = decoder_inputs_to_override or inputs
|
||||||
if cache_salt is not None:
|
decoder_inputs = MultiModalInputs(
|
||||||
|
type="multimodal",
|
||||||
|
prompt=decoder_prompt_inputs.get("prompt", ""),
|
||||||
|
prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"],
|
||||||
|
mm_kwargs=inputs["mm_kwargs"],
|
||||||
|
mm_hashes=inputs["mm_hashes"],
|
||||||
|
mm_placeholders=inputs["mm_placeholders"],
|
||||||
|
)
|
||||||
|
if cache_salt := inputs.get("cache_salt"):
|
||||||
decoder_inputs["cache_salt"] = cache_salt
|
decoder_inputs["cache_salt"] = cache_salt
|
||||||
|
|
||||||
elif inputs["type"] == "token":
|
elif inputs["type"] == "token": # Text-only inputs
|
||||||
# Text-only inputs
|
|
||||||
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
|
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
|
||||||
decoder_inputs = decoder_inputs_to_override or inputs
|
decoder_inputs = decoder_inputs_to_override or inputs
|
||||||
else:
|
else:
|
||||||
assert_never(inputs) # type: ignore[arg-type]
|
assert_never(inputs) # type: ignore[arg-type]
|
||||||
|
|
||||||
return encoder_inputs, decoder_inputs
|
return encoder_inputs, decoder_inputs
|
||||||
|
|
||||||
def _process_encoder_decoder_prompt(
|
def _process_encoder_decoder_prompt(
|
||||||
@ -580,11 +687,9 @@ class InputPreprocessor:
|
|||||||
# For multimodal model, override decoder prompt from processor
|
# For multimodal model, override decoder prompt from processor
|
||||||
# with explicit decoder prompt.
|
# with explicit decoder prompt.
|
||||||
if self.model_config.is_multimodal_model:
|
if self.model_config.is_multimodal_model:
|
||||||
assert decoder_inputs is None or not is_embeds_inputs(
|
|
||||||
decoder_inputs)
|
|
||||||
encoder_inputs, decoder_inputs = (
|
encoder_inputs, decoder_inputs = (
|
||||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
self._split_enc_dec_mm_inputs(encoder_inputs,
|
||||||
encoder_inputs, decoder_inputs))
|
decoder_inputs))
|
||||||
else:
|
else:
|
||||||
inputs = self._prompt_to_llm_inputs(
|
inputs = self._prompt_to_llm_inputs(
|
||||||
prompt,
|
prompt,
|
||||||
@ -593,16 +698,11 @@ class InputPreprocessor:
|
|||||||
if self.model_config.is_multimodal_model:
|
if self.model_config.is_multimodal_model:
|
||||||
# Encoder-Decoder Multimodal model
|
# Encoder-Decoder Multimodal model
|
||||||
encoder_inputs, decoder_inputs = (
|
encoder_inputs, decoder_inputs = (
|
||||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
self._split_enc_dec_mm_inputs(inputs))
|
||||||
inputs))
|
|
||||||
else:
|
else:
|
||||||
encoder_inputs = inputs
|
encoder_inputs = inputs
|
||||||
decoder_inputs = None
|
decoder_inputs = None
|
||||||
|
|
||||||
# Mypy does not do type inference well with TypedDicts with Literal
|
|
||||||
# values.
|
|
||||||
assert not is_embeds_inputs(encoder_inputs)
|
|
||||||
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
|
|
||||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||||
|
|
||||||
async def _process_encoder_decoder_prompt_async(
|
async def _process_encoder_decoder_prompt_async(
|
||||||
@ -635,11 +735,9 @@ class InputPreprocessor:
|
|||||||
# For multimodal model, override decoder prompt from processor
|
# For multimodal model, override decoder prompt from processor
|
||||||
# with explicit decoder prompt.
|
# with explicit decoder prompt.
|
||||||
if self.model_config.is_multimodal_model:
|
if self.model_config.is_multimodal_model:
|
||||||
assert decoder_inputs is None or not is_embeds_inputs(
|
|
||||||
decoder_inputs)
|
|
||||||
encoder_inputs, decoder_inputs = (
|
encoder_inputs, decoder_inputs = (
|
||||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
self._split_enc_dec_mm_inputs(encoder_inputs,
|
||||||
encoder_inputs, decoder_inputs))
|
decoder_inputs))
|
||||||
else:
|
else:
|
||||||
inputs = await self._prompt_to_llm_inputs_async(
|
inputs = await self._prompt_to_llm_inputs_async(
|
||||||
prompt,
|
prompt,
|
||||||
@ -648,16 +746,11 @@ class InputPreprocessor:
|
|||||||
if self.model_config.is_multimodal_model:
|
if self.model_config.is_multimodal_model:
|
||||||
# Encoder-Decoder Multimodal model
|
# Encoder-Decoder Multimodal model
|
||||||
encoder_inputs, decoder_inputs = (
|
encoder_inputs, decoder_inputs = (
|
||||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
self._split_enc_dec_mm_inputs(inputs))
|
||||||
inputs))
|
|
||||||
else:
|
else:
|
||||||
encoder_inputs = inputs
|
encoder_inputs = inputs
|
||||||
decoder_inputs = None
|
decoder_inputs = None
|
||||||
|
|
||||||
# Mypy does not do type inference well with TypedDicts with Literal
|
|
||||||
# values.
|
|
||||||
assert not is_embeds_inputs(encoder_inputs)
|
|
||||||
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
|
|
||||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||||
|
|
||||||
def _build_decoder_only_llm_inputs(
|
def _build_decoder_only_llm_inputs(
|
||||||
@ -665,19 +758,13 @@ class InputPreprocessor:
|
|||||||
prompt_inputs: DecoderOnlyInputs,
|
prompt_inputs: DecoderOnlyInputs,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||||
) -> DecoderOnlyInputs:
|
) -> DecoderOnlyInputs:
|
||||||
if (prompt_inputs["type"] == "token"
|
if "prompt_token_ids" in prompt_inputs:
|
||||||
or prompt_inputs["type"] == "multimodal"):
|
prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
|
||||||
# Mypy does not do type inference well with typedicts and Literal
|
prompt_inputs) # Needed for mypy
|
||||||
# values
|
|
||||||
assert not is_embeds_inputs(prompt_inputs)
|
|
||||||
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
|
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
|
||||||
prompt_inputs["prompt_token_ids"],
|
prompt_inputs["prompt_token_ids"],
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
)
|
)
|
||||||
elif (prompt_inputs["type"] == "embeds"):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
assert_never(prompt_inputs) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
return prompt_inputs
|
return prompt_inputs
|
||||||
|
|
||||||
|
|||||||
@ -1670,15 +1670,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
|||||||
placeholders = mm_placeholders.get(modality, [])
|
placeholders = mm_placeholders.get(modality, [])
|
||||||
|
|
||||||
if len(placeholders) != item_count:
|
if len(placeholders) != item_count:
|
||||||
|
# NOTE: If you are a model developer, this can also arise from
|
||||||
|
# an inconsistency between `_call_hf_processor` and
|
||||||
|
# `_get_mm_fields_config` implementations
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Expected there to be {item_count} prompt updates "
|
f"Expected there to be {item_count} prompt updates "
|
||||||
f"corresponding to {item_count} {modality} items, but "
|
f"corresponding to {item_count} {modality} items, but "
|
||||||
f"instead found {len(placeholders)} prompt updates! "
|
f"instead found {len(placeholders)} prompt updates! "
|
||||||
"Either the prompt text has missing/incorrect tokens for "
|
"This is likely because you forgot to include input "
|
||||||
"multi-modal inputs, or there is a problem with your "
|
"placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
|
||||||
"implementation of merged multi-modal processor for this "
|
"in the prompt. If the model has a chat template, make "
|
||||||
"model (usually arising from an inconsistency between "
|
"sure you have applied it before calling `LLM.generate`.")
|
||||||
"`_call_hf_processor` and `_get_prompt_updates`).")
|
|
||||||
|
|
||||||
def _maybe_apply_prompt_updates(
|
def _maybe_apply_prompt_updates(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user