[Core] Support serving encoder/decoder models (#7258)

Cyrus Leung 2024-08-09 10:39:41 +08:00 committed by GitHub
parent 0fa14907da
commit 7eb4a51c5f
25 changed files with 603 additions and 464 deletions

View File

@ -25,7 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install mypy==1.9.0
pip install mypy==1.11.1
pip install types-setuptools
pip install types-PyYAML
pip install types-requests

View File

@ -4,8 +4,8 @@ encoder/decoder models, specifically BART
'''
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.utils import zip_enc_dec_prompt_lists
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
TokensPrompt, zip_enc_dec_prompts)
dtype = "float"
@ -61,9 +61,9 @@ enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
)
# - Finally, here's a useful helper function for zipping encoder and
# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
# instances
zipped_prompt_list = zip_enc_dec_prompt_lists(
zipped_prompt_list = zip_enc_dec_prompts(
['An encoder prompt', 'Another encoder prompt'],
['A decoder prompt', 'Another decoder prompt'])
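
The zipped result is an ordinary list of ExplicitEncoderDecoderPrompt dicts, so it can be fed straight to LLM.generate. A minimal self-contained sketch of what the helper produces (model name and sampling settings here are illustrative, not part of this diff):

from vllm import LLM, SamplingParams
from vllm.inputs import zip_enc_dec_prompts

llm = LLM(model="facebook/bart-large-cnn", dtype="float")
sampling_params = SamplingParams(temperature=0, max_tokens=20)

# zip_enc_dec_prompts pairs same-index entries into explicit prompts:
# [{'encoder_prompt': 'An encoder prompt', 'decoder_prompt': 'A decoder prompt'},
#  {'encoder_prompt': 'Another encoder prompt', 'decoder_prompt': 'Another decoder prompt'}]
zipped = zip_enc_dec_prompts(
    ['An encoder prompt', 'Another encoder prompt'],
    ['A decoder prompt', 'Another decoder prompt'])

for output in llm.generate(zipped, sampling_params):
    print(output.outputs[0].text)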

View File

@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.3
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
gguf == 0.9.1

View File

@ -8,7 +8,7 @@ isort==5.13.2
clang-format==18.1.5
# type checking
mypy==1.9.0
mypy==1.11.1
types-PyYAML
types-requests
types-setuptools

View File

@ -3,6 +3,7 @@ import gc
import os
import sys
from collections import UserList
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
import pytest
@ -14,20 +15,19 @@ from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
BatchFeature)
from tests.models.utils import DecoderPromptType
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu, to_enc_dec_tuple_list,
zip_enc_dec_prompt_lists)
is_cpu)
logger = init_logger(__name__)
@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
return prompts
class DecoderPromptType(Enum):
"""For encoder/decoder models only."""
CUSTOM = 1
NONE = 2
EMPTY_STR = 3
@pytest.fixture
def example_encoder_decoder_prompts() \
-> Dict[DecoderPromptType,
Tuple[List[str], List[Optional[str]]]]:
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
'''
Returns an encoder prompt list and a decoder prompt list, wherein each pair
of same-index entries in both lists corresponds to an (encoder prompt,
@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
# NONE decoder prompt type
return {
DecoderPromptType.NONE:
zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
DecoderPromptType.EMPTY_STR:
zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
DecoderPromptType.CUSTOM:
zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
}
@ -444,7 +450,7 @@ class HfRunner:
def generate_encoder_decoder_greedy_logprobs_limit(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
**kwargs: Any,
@ -608,7 +614,7 @@ class VllmRunner:
def generate_encoder_decoder_w_logprobs(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
sampling_params: SamplingParams,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
'''
@ -653,7 +659,7 @@ class VllmRunner:
def generate_encoder_decoder_greedy_logprobs(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:

View File

@ -11,9 +11,9 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
import pytest
from tests.models.utils import DecoderPromptType
from vllm.utils import cuda_device_count_stateless
from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close
from ..utils import fork_new_process_for_each_test

View File

@ -0,0 +1,50 @@
import openai
import pytest
from ...utils import RemoteOpenAIServer
MODEL_NAME = "facebook/bart-base"
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--enforce-eager",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=2, total_tokens=7)
# test using token IDs
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 1
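
The RemoteOpenAIServer fixture above is roughly equivalent to launching the OpenAI-compatible server by hand with python -m vllm.entrypoints.openai.api_server --model facebook/bart-base --dtype bfloat16 --enforce-eager. A synchronous client sketch against such a server (port and api_key are the usual defaults, assumed here):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(model="facebook/bart-base",
                                        prompt="Hello, my name is",
                                        max_tokens=5,
                                        temperature=0.0)
print(completion.choices[0].text)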

View File

@ -2,6 +2,8 @@
Run `pytest tests/models/test_bart.py`.
"""
from typing import List, Optional, Tuple
from vllm.utils import is_cpu
if not is_cpu():
@ -11,22 +13,31 @@ if not is_cpu():
import pytest
from tests.models.utils import DecoderPromptType
from vllm.sequence import SampleLogprobs
from ..conftest import DecoderPromptType
from .utils import check_logprobs_close
MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
DECODER_PROMPT_TYPES = ([
DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
DecoderPromptType.NONE
])
def vllm_to_hf_output(
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
decoder_prompt_type: DecoderPromptType,
):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
hf_output_str = output_str + "</s>"
if decoder_prompt_type == DecoderPromptType.NONE:
hf_output_str = "<s>" + hf_output_str
return output_ids, hf_output_str, out_logprobs
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
def test_models(
hf_runner,
vllm_runner,
@ -146,8 +157,13 @@ if not is_cpu():
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
else 0)
check_logprobs_close(outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, decoder_prompt_type)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens,
)
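
The sanitizer above mirrors what HF keeps in its decoded text: vLLM's output gets "</s>" appended, and when no decoder prompt was supplied a leading "<s>" as well. For example (token ids are illustrative):

output_ids = [0, 100, 2]
assert vllm_to_hf_output((output_ids, "A summary", None),
                         DecoderPromptType.CUSTOM) == (output_ids, "A summary</s>", None)
assert vllm_to_hf_output((output_ids, "A summary", None),
                         DecoderPromptType.NONE) == (output_ids, "<s>A summary</s>", None)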

View File

@ -1,5 +1,4 @@
import warnings
from enum import Enum
from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import SampleLogprobs
@ -136,13 +135,3 @@ def check_logprobs_close(
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
class DecoderPromptType(Enum):
'''
For encoder/decoder models only -
'''
CUSTOM = 1
NONE = 2
EMPTY_STR = 3

View File

@ -2,7 +2,7 @@ from typing import List
import pytest
from vllm.inputs import parse_and_batch_prompt
from vllm.inputs.parse import parse_and_batch_prompt
STRING_INPUTS = [
'',

View File

@ -464,6 +464,16 @@ class ModelConfig:
if t != "attention"
])
@property
def is_encoder_decoder_model(self) -> bool:
"""Extract the HF encoder/decoder model flag."""
return getattr(self.hf_config, "is_encoder_decoder", False)
@property
def is_embedding_model(self) -> bool:
"""Extract the embedding model flag."""
return self.embedding_mode
class CacheConfig:
"""Configuration for the KV cache.

View File

@ -5,6 +5,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@ -12,11 +13,14 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
PromptComponents)
from vllm.engine.metrics import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@ -293,38 +297,138 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def process_model_inputs_async(
async def _tokenize_prompt_async(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
async def _extract_prompt_components_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
prompt = None
prompt_token_ids = inputs["prompt_token_ids"]
else:
# NOTE: This extra assignment is required to pass mypy
prompt = parsed_prompt = inputs["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
parsed_prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = inputs.get("multi_modal_data")
else:
assert_never(inputs)
return prompt, prompt_token_ids, multi_modal_data
async def _process_encoder_decoder_prompt_async(
self,
inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
encoder_task = self._extract_prompt_components_async(
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None
else:
decoder_task = self._extract_prompt_components_async(
decoder_input,
request_id=request_id,
)
encoder_comps, decoder_comps = await asyncio.gather(
encoder_task, decoder_task)
else:
encoder_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
)
decoder_comps = None, None, None
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
lora_request=lora_request,
)
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
prompt_token_ids = await tokenizer.encode_async(
async def process_model_inputs_async(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
"""Async version of :meth:`process_model_inputs`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs = await self._process_encoder_decoder_prompt_async(
inputs,
request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
)
else:
prompt_token_ids = inputs["prompt_token_ids"]
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
if prompt_adapter_request:
prompt_token_ids = [
0
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
prompt_token_ids
# Decoder-only operation
model_inputs = await self._process_decoder_only_prompt_async(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
return self.input_processor(llm_inputs)
return self.input_processor(model_inputs)
async def add_request_async(
self,
@ -336,6 +440,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Async version of :meth:`add_request`."""
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
@ -343,10 +448,11 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time()
processed_inputs = await self.process_model_inputs_async(
inputs,
request_id=request_id,
inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request(
request_id=request_id,

View File

@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, TypeVar, Union
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@ -22,10 +24,12 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs,
get_prompt_type)
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs,
PromptInputs, SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
@ -43,8 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import (Counter, is_embedding_model_config,
is_encoder_decoder_model_config)
from vllm.utils import Counter
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -66,6 +69,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
PromptComponents = Tuple[Optional[str], List[int],
Optional[MultiModalDataDict]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional[MultiModalDataDict]]
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
@ -524,7 +532,7 @@ class LLMEngine:
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _get_decoder_start_token_id(self, ) -> Optional[int]:
def _get_decoder_start_token_id(self) -> Optional[int]:
'''
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
@ -553,7 +561,7 @@ class LLMEngine:
def _add_processed_request(
self,
request_id: str,
processed_inputs: LLMInputs,
processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
@ -613,11 +621,11 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()
_LLMInputComponentsType = Tuple[str, List[int], ]
_LLMInputComponentsType = Tuple[str, List[int]]
def _prepare_decoder_input_ids_for_generation(
self,
decoder_input_ids: Optional[List[int]] = None,
decoder_input_ids: Optional[List[int]],
) -> List[int]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
@ -639,14 +647,13 @@ class LLMEngine:
* Processed token list
"""
decoder_start_token_id: Optional[int] = (
self._get_decoder_start_token_id())
decoder_start_token_id = self._get_decoder_start_token_id()
assert decoder_start_token_id is not None
if decoder_input_ids is None:
# no decoder prompt input ->
# use decoder_start_token_id as decoder_input_ids
(decoder_input_ids) = self._get_default_enc_dec_decoder_prompt()
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
@ -657,12 +664,11 @@ class LLMEngine:
def _tokenize_prompt(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[str] = None,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
'''
Wrapper around application of the model's
tokenizer.
Wrapper around application of the model's tokenizer.
Arguments:
@ -678,87 +684,72 @@ class LLMEngine:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
def _extract_single_prompt_for_enc_dec_input(
def _extract_prompt_components(
self,
inputs: Optional[PromptInputs],
request_id: Optional[str] = None,
ptype: Optional[str] = None,
is_encoder_prompt: bool = False,
) -> Tuple[Optional[str], List[int]]:
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
'''
Only for encoder/decoder models:
Extract prompt & prompt_token_ids from any single
encoder or decoder input prompt. For encoder input prompts
in particular, also extract multi-modal data.
This function handles the following scenarios:
1. The user supplied a singleton encoder prompt
& the prompt/prompt-token-ids must be extracted.
2. The user supplied an explicit encoder/decoder
prompt & the prompt/prompt-token-ids must be
extracted from either the encoder and decoder prompts.
For decoder prompts in particular (scenario 2), special
processing is applied to the returned decoder token ids.
Extract the components of any single encoder or decoder input prompt.
Arguments:
* request_id
* ptype: str representation of the input prompt type.
If `ptype` is `None`, assume that the prompt
type is unknown and must be inferred. This is the
case for ExplicitEncoderDecoder sub-prompts.
* inputs: single encoder or decoder input prompt
* is_encoder_prompt: True if encoder input prompt.
If False, decoder prompt tokens
are preprocessed.
* lora_request: this is only valid for decoder prompts
Returns:
* prompt
* prompt_token_ids
* multi_modal_data
'''
prompt_token_ids = None
ptype = (get_prompt_type(inputs) if ptype is None else ptype)
if inputs is None:
prompt = None
elif ptype == 'str':
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
)
elif ptype == 'TokensPrompt':
prompt = None
prompt_token_ids = inputs['prompt_token_ids']
multi_modal_data = None
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
prompt = None
prompt_token_ids = inputs["prompt_token_ids"]
else:
# NOTE: This extra assignment is required to pass mypy
prompt = parsed_prompt = inputs["prompt"]
prompt_token_ids = self._tokenize_prompt(
parsed_prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = inputs.get("multi_modal_data")
else:
prompt = inputs['prompt']
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
)
assert_never(inputs)
if not is_encoder_prompt:
# Apply special pre-processing to
# decoder prompts
prompt_token_ids = (self._prepare_decoder_input_ids_for_generation(
prompt_token_ids, ))
return prompt, prompt_token_ids, multi_modal_data
assert prompt_token_ids is not None
def _apply_prompt_adapter(
self,
prompt_token_ids: List[int],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> List[int]:
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
return (
prompt,
prompt_token_ids,
)
return prompt_token_ids
def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]:
def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
'''
Specifically for encoder/decoder models:
generate a default decoder prompt for when
@ -792,18 +783,39 @@ class LLMEngine:
bos_token_id = self._get_bos_token_id()
assert bos_token_id is not None
prompt_token_ids: List[int] = [bos_token_id]
return prompt_token_ids
return [bos_token_id]
def _build_enc_dec_llm_inputs(
self,
encoder_comps: PromptComponents,
decoder_comps: DecoderPromptComponents,
) -> EncoderDecoderLLMInputs:
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
if encoder_mm_data is not None or decoder_mm_data is not None:
raise ValueError("Multi-modal encoder-decoder models are "
"not supported yet")
decoder_prompt_ids = (
self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
return EncoderDecoderLLMInputs(
prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt,
encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt,
)
def _process_encoder_decoder_prompt(
self,
inputs: PromptInputs,
request_id: Optional[str] = None,
) -> LLMInputs:
request_id: str,
) -> EncoderDecoderLLMInputs:
'''
For encoder/decoder models only:
Process an input prompt
into an `LLMInputs` instance.
Process an input prompt into an
:class:`EncoderDecoderLLMInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
@ -830,136 +842,103 @@ class LLMEngine:
Returns:
* `LLMInputs` instance
* :class:`EncoderDecoderLLMInputs` instance
'''
ptype = get_prompt_type(inputs)
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
# Obtain encoder and decoder prompt tokens. Note
# that, no matter what, the decoder
# prompt type is unknown.
if ptype == "ExplicitEncoderDecoder":
# If input is explicit encoder/decoder prompt,
# then it remains to be determined what type
# of encoder prompt we have
extracted_encoder_prompt = inputs.get('encoder_prompt')
encoder_ptype = None
# Extract decoder prompt from explicit
# encoder/decoder prompt
extracted_decoder_prompt = inputs.get('decoder_prompt')
if is_explicit_encoder_decoder_prompt(inputs):
encoder_comps = self._extract_prompt_components(
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
decoder_comps = None, None, None
else:
decoder_comps = self._extract_prompt_components(
decoder_input,
request_id=request_id,
)
else:
# If input is singleton encoder prompt, then
# we know the encoder prompt type
extracted_encoder_prompt = inputs
encoder_ptype = ptype
# Decoder prompt is always unknown if
# encoder/decoder prompt is not explicit
extracted_decoder_prompt = None
encoder_comps = self._extract_prompt_components(
inputs,
request_id=request_id,
)
# Invoke helper function to obtain encoder
# prompt and prompt token ids, either from
# singleton encoder prompt or from the
# encoder sub-prompt of an explicit
# encoder/decode scenario 2), special
# processing is applied to the returned decoder token ids
(
encoder_prompt,
encoder_prompt_token_ids,
) = self._extract_single_prompt_for_enc_dec_input(
extracted_encoder_prompt,
request_id=request_id,
ptype=encoder_ptype,
is_encoder_prompt=True,
)
decoder_comps = None, None, None
# Invoke helper method to obtain
# decoder prompt and prompt token ids.
#
# The helper method will detect the decoder
# prompt type.
#
# Helper method will also apply special
# preprocessing unique to decoder prompts.
(
decoder_prompt,
decoder_prompt_token_ids,
) = self._extract_single_prompt_for_enc_dec_input(
extracted_decoder_prompt,
request_id=request_id,
ptype=None,
is_encoder_prompt=False,
)
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
return LLMInputs(
prompt_token_ids=decoder_prompt_token_ids,
prompt=decoder_prompt,
encoder_prompt_token_ids=encoder_prompt_token_ids,
encoder_prompt=encoder_prompt,
)
def _build_decoder_only_llm_inputs(
self,
prompt_comps: PromptComponents,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs:
prompt, prompt_token_ids, multi_modal_data = prompt_comps
prompt_token_ids = self._apply_prompt_adapter(
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data)
def _process_decoder_only_prompt(
self,
inputs: PromptInputs,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
request_id: Optional[str] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
'''
For decoder-only models:
Process an input prompt
into an `LLMInputs` instance.
Process an input prompt into an :class:`LLMInputs` instance.
Arguments:
* inputs: input prompt
* lora_request
* request_id
* lora_request
* prompt_adapter_request
Returns:
* `LLMInputs` instance
* :class:`LLMInputs` instance
'''
if isinstance(inputs, str):
inputs = {"prompt": inputs}
prompt = inputs.get("prompt")
prompt_comps = self._extract_prompt_components(
inputs,
request_id=request_id,
lora_request=lora_request,
)
if "prompt_token_ids" not in inputs:
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
)
else:
prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=inputs.get("multi_modal_data"))
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
def process_model_inputs(
self,
request_id: str,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs = self._process_encoder_decoder_prompt(
inputs,
request_id=request_id,
)
else:
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
model_inputs = self._process_decoder_only_prompt(
inputs,
@ -1029,10 +1008,11 @@ class LLMEngine:
arrival_time = time.time()
processed_inputs = self.process_model_inputs(
inputs,
request_id=request_id,
inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request(
request_id=request_id,
@ -1597,7 +1577,7 @@ class LLMEngine:
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
def is_encoder_decoder_model(self):
return is_encoder_decoder_model_config(self.model_config)
return self.model_config.is_encoder_decoder_model
def is_embedding_model(self):
return is_embedding_model_config(self.model_config)
return self.model_config.is_embedding_model
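
For encoder/decoder requests, the decoder-side token ids pass through _prepare_decoder_input_ids_for_generation before being stored. Reduced to plain Python, the rule it applies looks roughly like the sketch below (assumed behaviour for the branch cut off by the hunk above: the decoder start token is prepended when missing):

def prepare_decoder_input_ids(decoder_input_ids, decoder_start_token_id, default_prompt):
    # No decoder prompt supplied: fall back to the default decoder prompt.
    if decoder_input_ids is None:
        decoder_input_ids = list(default_prompt)
    # Make sure generation starts at the decoder start token.
    if not decoder_input_ids or decoder_input_ids[0] != decoder_start_token_id:
        decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
    return decoder_input_ids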

View File

@ -2,8 +2,7 @@ import codecs
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
cast, final)
from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast
# yapf conflicts with isort for this block
# yapf: disable
@ -59,7 +58,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam]
@final # So that it should be compatible with Dict[str, str]
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict):
role: str
content: str

View File

@ -6,8 +6,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (

View File

@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor(
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(logit_bias: Dict[str,
float], token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
def logit_bias_logits_processor(
logit_bias: Dict[int, float],
token_ids: List[int],
logits: torch.Tensor,
) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits

View File

@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeCompletionRequest,
TokenizeRequest)
# yapf: enable
from vllm.inputs import parse_and_batch_prompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (

View File

@ -1,7 +1,7 @@
from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText,
ParsedTokens, PromptInputs, SingletonPromptInputs,
TextPrompt, TokensPrompt, get_prompt_type,
is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt)
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry()
@ -14,18 +14,17 @@ See also:
"""
__all__ = [
"ParsedText",
"ParsedTokens",
"parse_and_batch_prompt",
"TextPrompt",
"TokensPrompt",
"PromptInputs",
"SingletonPromptInputs",
"ExplicitEncoderDecoderPrompt",
"LLMInputs",
"EncoderDecoderLLMInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"INPUT_REGISTRY",
"InputContext",
"InputRegistry",
"get_prompt_type",
"is_valid_encoder_decoder_llm_inputs",
"ExplicitEncoderDecoderPrompt",
"SingletonPromptInputs",
]

View File

@ -1,71 +1,12 @@
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)
from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
Union)
from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0], str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False)
for elem in cast(List[str], prompt)
]
if isinstance(prompt[0], int):
# case 3: array of tokens
elem = cast(List[int], prompt)
return [ParsedTokens(content=elem, is_tokens=True)]
if isinstance(prompt[0], list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0][0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in cast(List[List[int]], prompt)
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
class TextPrompt(TypedDict):
"""Schema for a text prompt."""
@ -103,39 +44,49 @@ Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type SingletonPromptInputs may be employed
A prompt of type :class:`SingletonPromptInputs` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. ExplicitEncoderDecoderPrompt
more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
_T1_co = TypeVar("_T1_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
_T2_co = TypeVar("_T2_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
class ExplicitEncoderDecoderPrompt(TypedDict):
# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
"""Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
The encoder and decoder prompts, respectively,
may formatted according to any of the
SingletonPromptInputs schemas, and are not
:class:`SingletonPromptInputs` schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Note that an ExplicitEncoderDecoderPrompt may not
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure may not themselves
must be SingletonPromptInputs instances.
fields of this data structure themselves must be
:class:`SingletonPromptInputs` instances.
"""
encoder_prompt: SingletonPromptInputs
encoder_prompt: _T1_co
decoder_prompt: SingletonPromptInputs
decoder_prompt: Optional[_T2_co]
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
@ -150,60 +101,12 @@ both decoder-only and encoder/decoder input types:
"""
def _has_required_keys(
d: dict,
required_keys: set,
) -> bool:
return required_keys.issubset(d.keys())
def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
"""
Get the type-name of the prompt argument instance, given that
isinstance() cannot apply to TypedDict subclasses directly.
If the prompt is None, return 'None' as the type name.
Arguments:
* prompt: LLM input prompt or None
Returns:
* String representation of prompt type
"""
if prompt is None:
return 'None'
required_keys_dict = {
'TextPrompt': {'prompt'},
'TokensPrompt': {'prompt_token_ids'},
'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
}
if isinstance(prompt, dict):
for (ptype, required_keys) in required_keys_dict.items():
# Ignore type checking in the conditional below because type
# checker does not understand that is_dict(prompt) narrows
# down the possible types
if _has_required_keys(
prompt, # type: ignore
required_keys):
return ptype
raise ValueError(f"Invalid prompt {prompt}, valid types are "
"required_keys_dict={required_keys_dict}")
if isinstance(prompt, str):
return "str"
raise ValueError(f"Invalid prompt {prompt}")
class LLMInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
@ -213,7 +116,21 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""
encoder_prompt_token_ids: NotRequired[List[int]]
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class EncoderDecoderLLMInputs(LLMInputs):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder_prompt_token_ids: List[int]
"""The token IDs of the encoder prompt."""
encoder_prompt: NotRequired[Optional[str]]
@ -222,20 +139,40 @@ class LLMInputs(TypedDict):
available.
"""
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
_T1 = TypeVar("_T1",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
_T2 = TypeVar("_T2",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool:
"""
Return True if the LLMInputs instance has the correct configuration
for encoder/decoder.
"""
def build_explicit_enc_dec_prompt(
encoder_prompt: _T1,
decoder_prompt: Optional[_T2],
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt)
# True if encoder prompt token ids field exists &
# is not None
return ('encoder_prompt_token_ids' in inputs
and inputs['encoder_prompt_token_ids'] is not None)
def zip_enc_dec_prompts(
enc_prompts: Iterable[_T1],
dec_prompts: Iterable[Optional[_T2]],
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
"""
Zip encoder and decoder prompts together into a list of
:class:`ExplicitEncoderDecoderPrompt` instances.
"""
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
]
def to_enc_dec_tuple_list(
enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> List[Tuple[_T1, Optional[_T2]]]:
return [(enc_dec_prompt["encoder_prompt"],
enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts]
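
Because the encoder and decoder sides are independent type parameters, an explicit prompt can mix schemas (e.g. a TextPrompt encoder prompt with a TokensPrompt decoder prompt), and the decoder side may be None. A short sketch of the helpers defined above (token ids are illustrative):

from vllm.inputs import (TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)

mixed = build_explicit_enc_dec_prompt(
    encoder_prompt=TextPrompt(prompt="An encoder prompt"),
    decoder_prompt=TokensPrompt(prompt_token_ids=[2, 0]),
)

# zip_enc_dec_prompts builds one explicit prompt per (encoder, decoder) pair;
# to_enc_dec_tuple_list turns the list back into (encoder, decoder) tuples.
prompts = zip_enc_dec_prompts(["enc 1", "enc 2"], ["dec 1", None])
assert to_enc_dec_tuple_list(prompts) == [("enc 1", "dec 1"), ("enc 2", None)]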

vllm/inputs/parse.py (new file, 75 lines)
View File

@ -0,0 +1,75 @@
from typing import List, Literal, Sequence, TypedDict, Union, overload
from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs)
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if is_list_of(prompt, str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False) for elem in prompt
]
if is_list_of(prompt, int):
# case 3: array of tokens
return [ParsedTokens(content=prompt, is_tokens=True)]
if is_list_of(prompt, list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if is_list_of(prompt[0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in prompt
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
def is_explicit_encoder_decoder_prompt(
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(inputs, dict) and "encoder_prompt" in inputs
def is_valid_encoder_decoder_llm_inputs(
inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
) -> TypeIs[EncoderDecoderLLMInputs]:
return "encoder_prompt_token_ids" in inputs

View File

@ -1,7 +1,7 @@
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable)
from typing_extensions import TypeGuard
from typing_extensions import TypeIs
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger
@ -37,18 +37,18 @@ class _SupportsVisionType(Protocol):
@overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]:
...
@overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]:
def supports_vision(model: object) -> TypeIs[SupportsVision]:
...
def supports_vision(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]:
if isinstance(model, type):
return isinstance(model, _SupportsVisionType)
@ -94,18 +94,18 @@ class _SupportsLoRAType(Protocol):
@overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
...
@overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
...
def supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
result = _supports_lora(model)
if not result:
@ -137,7 +137,7 @@ def supports_lora(
def _supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)
@ -172,18 +172,18 @@ class _HasInnerStateType(Protocol):
@overload
def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
...
@overload
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
...
def has_inner_state(
model: Union[Type[object], object]
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
if isinstance(model, type):
return isinstance(model, _HasInnerStateType)
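
Switching these predicates from TypeGuard to TypeIs changes how type checkers treat the negative branch: TypeIs narrows the argument in both branches of an if, while TypeGuard narrows only the positive one. A small illustration independent of vLLM:

from typing import Union

from typing_extensions import TypeIs


def is_str(value: Union[str, int]) -> TypeIs[str]:
    return isinstance(value, str)


def demo(value: Union[str, int]) -> None:
    if is_str(value):
        print(value.upper())  # narrowed to str
    else:
        print(value + 1)  # with TypeIs, narrowed to int; TypeGuard would keep Union[str, int]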

View File

@ -10,6 +10,7 @@ from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of
from .base import MultiModalInputs, MultiModalPlugin
@ -113,7 +114,8 @@ class ImagePlugin(MultiModalPlugin):
def _default_input_mapper(self, ctx: InputContext,
data: object) -> MultiModalInputs:
model_config = ctx.model_config
if isinstance(data, (Image.Image, list)):
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
@ -127,7 +129,7 @@ class ImagePlugin(MultiModalPlugin):
raise
return MultiModalInputs(batch_data)
elif isinstance(data, torch.Tensor):
elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
raise TypeError(f"Invalid image type: {type(data)}")

View File

@ -11,7 +11,7 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
import torch
from vllm.inputs import is_valid_encoder_decoder_llm_inputs
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest

View File

@ -17,8 +17,8 @@ from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
Union, overload)
Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
Type, TypeVar, Union, overload)
from uuid import uuid4
import numpy as np
@ -26,12 +26,10 @@ import numpy.typing as npt
import psutil
import torch
import torch.types
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, TypeIs, assert_never
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
SingletonPromptInputs)
from vllm.logger import enable_trace_function_call, init_logger
logger = init_logger(__name__)
@ -812,6 +810,24 @@ def get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size()
# `collections` helpers
def is_list_of(
value: object,
typ: Type[T],
*,
check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]:
if not isinstance(value, list):
return False
if check == "first":
return len(value) == 0 or isinstance(value[0], typ)
elif check == "all":
return all(isinstance(v, typ) for v in value)
assert_never(check)
def merge_dicts(dict1: Dict[K, List[T]],
dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
"""Merge 2 dicts that have key -> List of items.
@ -959,6 +975,7 @@ def enable_trace_function_call_for_thread() -> None:
enable_trace_function_call(log_path)
# `functools` helpers
def identity(value: T) -> T:
return value
@ -1080,50 +1097,3 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)
def is_encoder_decoder_model_config(model_config) -> bool:
'''
Extract the HF encoder/decoder model flag from the ModelConfig instance.
Return False if model_config is None.
'''
return model_config is not None and \
getattr(model_config.hf_config,
"is_encoder_decoder",
False)
def is_embedding_model_config(model_config) -> bool:
'''
Extract the embedding model flag from the ModelConfig instance.
Return False if model_config is None.
'''
return model_config is not None and \
model_config.embedding_mode
def build_explicit_enc_dec_prompt(
encoder_prompt: SingletonPromptInputs,
decoder_prompt: SingletonPromptInputs,
) -> ExplicitEncoderDecoderPrompt:
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt)
def zip_enc_dec_prompt_lists(
enc_prompt_list: List[SingletonPromptInputs],
dec_prompt_list: List[SingletonPromptInputs],
) -> List[ExplicitEncoderDecoderPrompt]:
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
for (encoder_prompt,
decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
]
def to_enc_dec_tuple_list(
enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
) -> List[Tuple[PromptInputs, PromptInputs]]:
return [(enc_dec_prompt['encoder_prompt'],
enc_dec_prompt['decoder_prompt'])
for enc_dec_prompt in enc_dec_prompts]
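
The new is_list_of helper added earlier in this file defaults to checking only the first element, which keeps hot paths cheap; check="all" scans the whole list. A short sketch:

from vllm.utils import is_list_of

values = ["a", "b", "c"]
assert is_list_of(values, str)                      # only values[0] is checked by default
assert is_list_of(values, str, check="all")         # every element is checked
assert is_list_of(["a", 1], str)                    # passes the cheap first-element check...
assert not is_list_of(["a", 1], str, check="all")   # ...but fails the full scan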

View File

@ -19,8 +19,6 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (is_embedding_model_config,
is_encoder_decoder_model_config)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
@ -113,10 +111,10 @@ class Worker(LocalOrDistributedWorkerBase):
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
def _is_encoder_decoder_model(self):
return is_encoder_decoder_model_config(self.model_config)
return self.model_config.is_encoder_decoder_model
def _is_embedding_model(self):
return is_embedding_model_config(self.model_config)
return self.model_config.is_embedding_model
def init_device(self) -> None:
if self.device_config.device.type == "cuda":