Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 11:41:58 +08:00)
[Optimization] Streamline InputPreprocessor (#25702)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 6b0fcbbf43
commit 3d54bdcb73
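In short: the diff below replaces the per-call construction of the multi-modal processor in _process_multimodal with a new _get_mm_processor helper that builds the processor once and caches it on the instance, and it deletes the duplicated async variants of the preprocessing methods (_tokenize_prompt_async, _get_mm_tokenizer_async, _process_multimodal_async, _process_embeds_async, _process_tokens_async, _process_text_async, _prompt_to_llm_inputs_async, _process_encoder_decoder_prompt_async, _process_decoder_only_prompt_async, and preprocess_async), along with the now-unused import of asyncio.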
vllm/inputs/preprocess.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import asyncio
 from collections.abc import Mapping
 from typing import Any, Optional, Union, cast
 
@@ -13,6 +12,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                     MultiModalInputs, MultiModalUUIDDict)
+from vllm.multimodal.processing import BaseMultiModalProcessor
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
@@ -200,20 +200,6 @@ class InputPreprocessor:
 
         return tokenizer.encode(prompt, **tokenization_kwargs)
 
-    async def _tokenize_prompt_async(
-        self,
-        prompt: str,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-    ) -> list[int]:
-        """
-        Async version of
-        [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
-        """
-        tokenizer = self.get_tokenizer()
-        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
-
-        return tokenizer.encode(prompt, **tokenization_kwargs)
-
     def _get_mm_tokenizer(self) -> AnyTokenizer:
         # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
         # while using also multi-modal input
@@ -223,14 +209,17 @@ class InputPreprocessor:
         tokenizer = self.get_tokenizer()
         return tokenizer
 
-    async def _get_mm_tokenizer_async(self) -> AnyTokenizer:
-        # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
-        # while using also multi-modal input
-        if not self.tokenizer:
-            return cast(AnyTokenizer, object())  # Dummy
-
-        tokenizer = self.get_tokenizer()
-        return tokenizer
+    def _get_mm_processor(self) -> BaseMultiModalProcessor:
+        if not hasattr(self, "_mm_processor"):
+            tokenizer = self._get_mm_tokenizer()
+
+            self._mm_processor = self.mm_registry.create_processor(
+                self.model_config,
+                tokenizer=tokenizer,
+                cache=self.mm_processor_cache,
+            )
+
+        return self._mm_processor
 
     def _process_multimodal(
         self,
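The new _get_mm_processor builds the multi-modal processor on first use and memoizes it on the instance, so later calls reuse one processor instead of re-running mm_registry.create_processor on every request. Below is a minimal, self-contained sketch of that memoization pattern; the class and registry names are invented for illustration and are not vLLM code.

class DummyRegistry:
    def create_processor(self, model_config, cache=None):
        # Stand-in for an expensive construction step.
        return object()


class Preprocessor:
    def __init__(self, registry, model_config, cache=None):
        self.registry = registry
        self.model_config = model_config
        self.cache = cache

    def _get_processor(self):
        # Build once, then reuse: the same hasattr-based memoization
        # used by the new _get_mm_processor helper.
        if not hasattr(self, "_processor"):
            self._processor = self.registry.create_processor(
                self.model_config, cache=self.cache)
        return self._processor


pre = Preprocessor(DummyRegistry(), model_config="toy-config")
assert pre._get_processor() is pre._get_processor()  # constructed only once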
@@ -245,55 +234,7 @@ class InputPreprocessor:
         Apply the model's multi-modal processor to a multi-modal prompt,
         returning the corresponding token IDs and metadata.
         """
-        tokenizer = self._get_mm_tokenizer()
+        mm_processor = self._get_mm_processor()
 
-        mm_processor = self.mm_registry.create_processor(
-            self.model_config,
-            tokenizer=tokenizer,
-            cache=self.mm_processor_cache,
-        )
-
-        if mm_processor_kwargs is None:
-            mm_processor_kwargs = {}
-
-        mm_input = mm_processor.apply(
-            prompt,
-            mm_data,
-            hf_processor_mm_kwargs=mm_processor_kwargs,
-            tokenization_kwargs=tokenization_kwargs,
-            mm_uuids=mm_uuids,
-        )
-        mm_hashes = mm_input["mm_hashes"]
-
-        # Validate that all mm items have a string as their hash
-        if not contains_only_strings(mm_hashes):
-            raise ValueError(
-                f"mm_hashes must contain only strings, got: {mm_hashes}. "
-                "This is likely due to an incorrect custom implementation of "
-                "MultiModalProcessor.apply method.")
-
-        return mm_input
-
-    async def _process_multimodal_async(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        mm_processor_kwargs: Optional[Mapping[str, object]],
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> MultiModalInputs:
-        """
-        Async version of
-        [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
-        """
-        tokenizer = await self._get_mm_tokenizer_async()
-
-        mm_processor = self.mm_registry.create_processor(
-            self.model_config,
-            tokenizer=tokenizer,
-            cache=self.mm_processor_cache,
-        )
-
         if mm_processor_kwargs is None:
             mm_processor_kwargs = {}
@@ -340,12 +281,6 @@ class InputPreprocessor:
         return embeds_inputs(prompt_embeds=prompt_embeds,
                              cache_salt=parsed_content.get("cache_salt"))
 
-    async def _process_embeds_async(
-        self,
-        parsed_content: EmbedsPrompt,
-    ) -> EmbedsInputs:
-        return self._process_embeds(parsed_content)
-
     def _truncate_inputs(
         self,
         inputs: list[int],
@@ -389,33 +324,6 @@ class InputPreprocessor:
 
         return inputs
 
-    async def _process_tokens_async(
-        self,
-        parsed_content: TokensPrompt,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> Union[TokenInputs, MultiModalInputs]:
-        prompt_token_ids = self._truncate_inputs(
-            parsed_content["prompt_token_ids"], tokenization_kwargs)
-
-        inputs: Union[TokenInputs, MultiModalInputs]
-        if multi_modal_data := parsed_content.get("multi_modal_data"):
-            inputs = await self._process_multimodal_async(
-                prompt_token_ids,
-                multi_modal_data,
-                parsed_content.get("mm_processor_kwargs"),
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-        else:
-            inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
-
-        if cache_salt := parsed_content.get("cache_salt"):
-            inputs["cache_salt"] = cache_salt
-
-        return inputs
-
     def _process_text(
         self,
         parsed_content: TextPrompt,
@@ -449,39 +357,6 @@ class InputPreprocessor:
 
         return inputs
 
-    async def _process_text_async(
-        self,
-        parsed_content: TextPrompt,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> Union[TokenInputs, MultiModalInputs]:
-        prompt_text = parsed_content["prompt"]
-
-        inputs: Union[TokenInputs, MultiModalInputs]
-        if multi_modal_data := parsed_content.get("multi_modal_data"):
-            inputs = await self._process_multimodal_async(
-                prompt_text,
-                multi_modal_data,
-                parsed_content.get("mm_processor_kwargs"),
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-        else:
-            prompt_token_ids = await self._tokenize_prompt_async(
-                prompt_text,
-                tokenization_kwargs=tokenization_kwargs,
-            )
-            inputs = token_inputs(
-                prompt=prompt_text,
-                prompt_token_ids=prompt_token_ids,
-            )
-
-        if cache_salt := parsed_content.get("cache_salt"):
-            inputs["cache_salt"] = cache_salt
-
-        return inputs
-
     def _prompt_to_llm_inputs(
         self,
         prompt: SingletonPrompt,
@@ -524,41 +399,6 @@ class InputPreprocessor:
 
         assert_never(parsed)
 
-    async def _prompt_to_llm_inputs_async(
-        self,
-        prompt: SingletonPrompt,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> SingletonInputs:
-        """
-        Async version of
-        [`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].
-        """
-        parsed = parse_singleton_prompt(prompt)
-
-        if parsed["type"] == "embeds":
-            return await self._process_embeds_async(parsed["content"])
-        if parsed["type"] == "tokens":
-            return await self._process_tokens_async(
-                parsed["content"],
-                mm_uuids=mm_uuids,
-            )
-        if parsed["type"] == "text":
-            return await self._process_text_async(
-                parsed["content"],
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-        if parsed["type"] == "str":
-            return await self._process_text_async(
-                TextPrompt(prompt=parsed["content"]),
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-
-        assert_never(parsed)
-
     def _build_enc_dec_llm_inputs(
         self,
         encoder_inputs: SingletonInputs,
@@ -735,62 +575,6 @@ class InputPreprocessor:
 
         return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
 
-    async def _process_encoder_decoder_prompt_async(
-        self,
-        prompt: PromptType,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> EncoderDecoderInputs:
-        """
-        Async version of
-        [`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].
-        """
-        encoder_inputs: SingletonInputs
-        decoder_inputs: Optional[SingletonInputs]
-
-        if is_explicit_encoder_decoder_prompt(prompt):
-            encoder_task = self._prompt_to_llm_inputs_async(
-                prompt["encoder_prompt"],
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-
-            if (decoder_input := prompt["decoder_prompt"]) is None:
-                encoder_inputs = await encoder_task
-                decoder_inputs = None
-            else:
-                decoder_task = self._prompt_to_llm_inputs_async(
-                    decoder_input,
-                    tokenization_kwargs=tokenization_kwargs,
-                    mm_uuids=mm_uuids,
-                )
-
-                encoder_inputs, decoder_inputs = await asyncio.gather(
-                    encoder_task, decoder_task)
-
-            # For multimodal model, override decoder prompt from processor
-            # with explicit decoder prompt.
-            if self.model_config.is_multimodal_model:
-                encoder_inputs, decoder_inputs = (
-                    self._split_enc_dec_mm_inputs(encoder_inputs,
-                                                  decoder_inputs))
-        else:
-            inputs = await self._prompt_to_llm_inputs_async(
-                prompt,
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-            if self.model_config.is_multimodal_model:
-                # Encoder-Decoder Multimodal model
-                encoder_inputs, decoder_inputs = (
-                    self._split_enc_dec_mm_inputs(inputs))
-            else:
-                encoder_inputs = inputs
-                decoder_inputs = None
-
-        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
-
     def _build_decoder_only_llm_inputs(
         self,
         prompt_inputs: DecoderOnlyInputs,
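Among the removed *_async variants, this is the only one in the diff that did genuinely concurrent work (the asyncio.gather over the encoder and decoder tasks); the others awaited their sub-steps one after another and so duplicated their sync counterparts line for line. Below is a toy sketch of that duplication and of the consolidated shape the commit moves to; the names are invented for illustration and none of this is vLLM code.

import asyncio


class Duplicated:
    # Pattern being removed: an async twin that repeats the sync body
    # without awaiting anything that actually runs concurrently.
    def tokenize(self, text: str) -> list[int]:
        return [ord(c) for c in text]

    async def tokenize_async(self, text: str) -> list[int]:
        return [ord(c) for c in text]  # same body, copied


class Streamlined:
    # Pattern after the change: a single sync implementation.
    def tokenize(self, text: str) -> list[int]:
        return [ord(c) for c in text]


async def demo() -> None:
    # An async caller can invoke the sync method directly, or push it to a
    # worker thread if it ever became heavy enough to block the event loop.
    direct = Streamlined().tokenize("hi")
    offloaded = await asyncio.to_thread(Streamlined().tokenize, "hi")
    assert direct == offloaded


asyncio.run(demo())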
@@ -830,25 +614,6 @@ class InputPreprocessor:
 
         return self._build_decoder_only_llm_inputs(prompt_comps)
 
-    async def _process_decoder_only_prompt_async(
-        self,
-        prompt: SingletonPrompt,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> DecoderOnlyInputs:
-        """
-        Async version of
-        [`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].
-        """
-        prompt_comps = await self._prompt_to_llm_inputs_async(
-            prompt,
-            tokenization_kwargs=tokenization_kwargs,
-            mm_uuids=mm_uuids,
-        )
-
-        return self._build_decoder_only_llm_inputs(prompt_comps)
-
     def preprocess(
         self,
         prompt: PromptType,
@@ -877,37 +642,6 @@ class InputPreprocessor:
             mm_uuids=mm_uuids,
         )
 
-    async def preprocess_async(
-        self,
-        prompt: PromptType,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> ProcessorInputs:
-        """
-        Async version of
-        [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
-        """
-        if self.model_config.is_encoder_decoder:
-            # Encoder-decoder model requires special mapping of
-            # input prompts to encoder & decoder.
-            return await self._process_encoder_decoder_prompt_async(
-                prompt,
-                tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-
-        if is_explicit_encoder_decoder_prompt(prompt):
-            raise ValueError("Cannot pass encoder-decoder prompt "
-                             "to decoder-only models")
-
-        # Decoder-only operation
-        return await self._process_decoder_only_prompt_async(
-            prompt,
-            tokenization_kwargs=tokenization_kwargs,
-            mm_uuids=mm_uuids,
-        )
-
     def clear_cache(self) -> None:
         if self.mm_processor_cache is not None:
             self.mm_processor_cache.clear_cache()
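The removed preprocess_async mirrored the dispatch that the surviving synchronous preprocess performs: encoder-decoder models go through the encoder/decoder mapping path, explicit encoder-decoder prompts are rejected for decoder-only models, and everything else takes the decoder-only path. A compressed, free-standing sketch of that dispatch follows, with simplified signatures rather than the exact vLLM ones.

def preprocess_dispatch(preprocessor, prompt, *, is_encoder_decoder: bool,
                        is_explicit_enc_dec_prompt: bool):
    # Simplified mirror of InputPreprocessor.preprocess; the helpers named
    # here are the synchronous methods that remain after this commit.
    if is_encoder_decoder:
        # Encoder-decoder models need prompts mapped to encoder & decoder.
        return preprocessor._process_encoder_decoder_prompt(prompt)

    if is_explicit_enc_dec_prompt:
        raise ValueError("Cannot pass encoder-decoder prompt "
                         "to decoder-only models")

    # Decoder-only operation
    return preprocessor._process_decoder_only_prompt(prompt)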